├── .gitignore ├── LICENSE ├── README.md ├── Town01_robust_rl_paths.png ├── carle_gym ├── __init__.py ├── envs │ ├── __init__.py │ └── carle_env.py └── setup.py ├── config ├── config_A2C_Town01_mlp.json ├── config_A2C_Town02_mlp.json ├── config_A2C_Town03_mlp.json ├── config_DQN_Town01_mlp.json ├── config_DQN_Town02_mlp.json ├── config_DQN_Town03_mlp.json ├── config_PPO_Town01_mlp.json ├── config_PPO_Town02_mlp.json ├── config_PPO_Town03_mlp.json ├── config_QRDQN_Town01_mlp.json ├── config_QRDQN_Town01_robust_rl_greedy.json ├── config_QRDQN_Town01_robust_rl_ssd.json ├── config_QRDQN_Town01_robust_rl_thres_ssd.json ├── config_QRDQN_Town02_mlp.json ├── config_QRDQN_Town02_robust_rl_greedy.json ├── config_QRDQN_Town02_robust_rl_ssd.json ├── config_QRDQN_Town02_robust_rl_thres_ssd.json └── config_QRDQN_Town03_mlp.json ├── observation.png ├── parameters.md ├── run_stable_baselines3.py ├── scripts ├── carla_docker.sh └── extract_maps.py └── thirdparty ├── sb3_contrib ├── __init__.py ├── ars │ ├── __init__.py │ ├── ars.py │ └── policies.py ├── common │ ├── __init__.py │ ├── envs │ │ ├── __init__.py │ │ └── invalid_actions_env.py │ ├── maskable │ │ ├── __init__.py │ │ ├── buffers.py │ │ ├── callbacks.py │ │ ├── distributions.py │ │ ├── evaluation.py │ │ ├── policies.py │ │ └── utils.py │ ├── utils.py │ ├── vec_env │ │ ├── __init__.py │ │ └── async_eval.py │ └── wrappers │ │ ├── __init__.py │ │ ├── action_masker.py │ │ └── time_feature.py ├── local_modifications.txt ├── ppo_mask │ ├── __init__.py │ ├── policies.py │ └── ppo_mask.py ├── py.typed ├── qrdqn │ ├── __init__.py │ ├── policies.py │ └── qrdqn.py ├── tqc │ ├── __init__.py │ ├── policies.py │ └── tqc.py ├── trpo │ ├── __init__.py │ ├── policies.py │ └── trpo.py └── version.txt └── stable_baselines3 ├── __init__.py ├── a2c ├── __init__.py ├── a2c.py └── policies.py ├── common ├── __init__.py ├── atari_wrappers.py ├── base_class.py ├── buffers.py ├── callbacks.py ├── distributions.py ├── env_checker.py ├── env_util.py ├── envs │ ├── __init__.py │ ├── bit_flipping_env.py │ ├── identity_env.py │ └── multi_input_envs.py ├── evaluation.py ├── logger.py ├── monitor.py ├── noise.py ├── off_policy_algorithm.py ├── on_policy_algorithm.py ├── policies.py ├── preprocessing.py ├── results_plotter.py ├── running_mean_std.py ├── save_util.py ├── sb2_compat │ ├── __init__.py │ └── rmsprop_tf_like.py ├── torch_layers.py ├── type_aliases.py ├── utils.py └── vec_env │ ├── __init__.py │ ├── base_vec_env.py │ ├── dummy_vec_env.py │ ├── stacked_observations.py │ ├── subproc_vec_env.py │ ├── util.py │ ├── vec_check_nan.py │ ├── vec_extract_dict_obs.py │ ├── vec_frame_stack.py │ ├── vec_monitor.py │ ├── vec_normalize.py │ ├── vec_transpose.py │ └── vec_video_recorder.py ├── ddpg ├── __init__.py ├── ddpg.py └── policies.py ├── dqn ├── __init__.py ├── dqn.py └── policies.py ├── her ├── __init__.py ├── goal_selection_strategy.py └── her_replay_buffer.py ├── local_modifications.txt ├── ppo ├── __init__.py ├── policies.py └── ppo.py ├── py.typed ├── sac ├── __init__.py ├── policies.py └── sac.py ├── td3 ├── __init__.py ├── policies.py └── td3.py └── version.txt /.gitignore: -------------------------------------------------------------------------------- 1 | carla_data 2 | Baselines 3 | carle_gym/carle.egg-info 4 | __pycache__ 5 | scripts 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Robust Field 
Autonomy Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributional RL for Route Planning 2 | 3 | This repository provides the code for our UR 2023 paper, available [here](https://arxiv.org/abs/2304.09996). We developed the Stochastic Road Network Environment based on the autonomous vehicle simulator [CARLA](https://github.com/carla-simulator/carla), and proposed a Distributional RL based route planner that plans the shortest routes while minimizing stochasticity in travel time. 4 | 5 |

6 | 7 |

8 | 9 | If you find this repository useful, please cite our paper: 10 | ``` 11 | @INPROCEEDINGS{10202222, 12 | author={Lin, Xi and Szenher, Paul and Martin, John D. and Englot, Brendan}, 13 | booktitle={2023 20th International Conference on Ubiquitous Robots (UR)}, 14 | title={Robust Route Planning with Distributional Reinforcement Learning in a Stochastic Road Network Environment}, 15 | year={2023}, 16 | volume={}, 17 | number={}, 18 | pages={287-294}, 19 | doi={10.1109/UR57808.2023.10202222}} 20 | ``` 21 | 22 | ## The Stochastic Road Network Environment 23 | 24 | The Stochastic Road Network Environment is built upon the map structure and simulated sensor data originating from CARLA version 0.9.6. Five such maps are available by default, named Town01 through Town05. The graph topology of Town01 and an example observation provided to the learning system are shown below. For detailed information about the Stochastic Road Network Environment, please refer to our paper. 25 | 26 |

27 | 28 |

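The environment is registered with Gym as `carle-v0` (see `carle_gym/__init__.py`), so it can also be constructed directly rather than only through `run_stable_baselines3.py`. The snippet below is a minimal usage sketch, not part of the original instructions: the constructor arguments follow `carle_gym/envs/carle_env.py`, the numeric values mirror `config/config_QRDQN_Town01_mlp.json`, and the `carla_data/Town01` dataset path is an assumption about where the generated map data (see the next section) is placed.
```
# Minimal usage sketch -- assumes the Town01 map data has already been
# generated or downloaded into carla_data/Town01 (adjust the path as needed).
import gym
import carle_gym  # importing the package registers the 'carle-v0' environment

env = gym.make(
    "carle-v0",
    is_eval_env=False,
    seed=0,
    dataset_dir="carla_data/Town01",  # folder holding the exported CSV files
    goal_states=[112, 113],           # values taken from config_QRDQN_Town01_mlp.json
    reset_state=209,
    discount=0.99,
    crosswalk_states=[261],
    agent="QRDQN",
    network="MlpPolicy",
    r_base=1,
    r_loopback=0,
)

obs = env.reset()                     # flattened 255x255 ground image (MlpPolicy)
obs, reward, done, info = env.step(env.action_space.sample())
print(info["state_id"], reward, done)
```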
29 | 30 | ## Map Data Generation 31 | 32 | The map data needed to run experiments on Town01 to Town05 can be downloaded from [here](https://stevens0-my.sharepoint.com/:f:/g/personal/xlin26_stevens_edu/EioIeHjcj_xNnJJc7ziMAUMBmz6fLFFxblYV2JWNHvAcyQ?e=R1UAjR), or you can generate it yourself with the provided scripts by following the steps below. 33 | 34 | 1. Install [NVIDIA Container Runtime](https://nvidia.github.io/nvidia-container-runtime/) 35 | ``` 36 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/gpgkey | \ 37 | sudo apt-key add - 38 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 39 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/$distribution/nvidia-container-runtime.list | \ 40 | sudo tee /etc/apt/sources.list.d/nvidia-container-runtime.list 41 | sudo apt-get update 42 | sudo apt install nvidia-container-runtime 43 | 44 | sudo systemctl daemon-reload 45 | sudo systemctl restart docker 46 | ``` 47 | 48 | 2. Clone this git repo and enter the directory: 49 | ``` 50 | git clone git@github.com:RobustFieldAutonomyLab/Stochastic_Road_Network.git 51 | cd Stochastic_Road_Network 52 | ``` 53 | 54 | 3. Install the system dependencies required by the CARLA Python library: 55 | ``` 56 | sudo apt install libpng16-16 libjpeg8 libtiff5 57 | ``` 58 | 59 | 4. Run the Docker script (initializes a headless CARLA server under Docker and copies the CARLA Python egg file): 60 | ``` 61 | sudo scripts/carla_docker.sh -oe 62 | ``` 63 | 64 | 5. Run the data generation script: 65 | ``` 66 | python scripts/extract_maps.py 67 | ``` 68 | 69 | ## Train Route Planning RL Agents 70 | 71 | Our proposed planner uses QR-DQN, and we select A2C, PPO and DQN as traditional RL baselines. We provide configuration files in the **config** directory for training RL agents on different maps. 72 | ``` 73 | $ python run_stable_baselines3.py -C [config file (required)] -P [number of processes (optional)] -D [cuda device (optional)] 74 | ``` 75 | 76 | ## Experiment Parameterization 77 | Example configuration files are provided in the **config** directory; see [parameters.md](parameters.md) for detailed explanations of the common parameters. 78 | 79 | ## Third Party Libraries 80 | This project uses the implementations of the A2C, PPO, DQN and QR-DQN agents from [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) and [stable-baselines3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib), with some modifications to support the proposed environment. Some agent-specific parameters appear in the provided configuration files; please refer to [on_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/on_policy_algorithm.py) (A2C and PPO) and [off_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/off_policy_algorithm.py) (DQN and QR-DQN) for further information.
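For orientation, the sketch below illustrates roughly how the agent block of a QR-DQN configuration file could map onto the QR-DQN constructor of the bundled `sb3_contrib` copy. It is illustrative only and is not the actual logic of `run_stable_baselines3.py`: the keyword mapping (e.g. `epsilon` to `exploration_final_eps`, `eps_fraction` to `exploration_fraction`) and the reuse of the `env` object from the earlier snippet are assumptions, and keys such as `eval_policy` are presumably consumed by the repository's modified evaluation code rather than by the constructor.
```
# Illustrative sketch only -- not the actual run_stable_baselines3.py logic.
import json
from sb3_contrib import QRDQN  # modified copy bundled under thirdparty/

with open("config/config_QRDQN_Town01_mlp.json") as f:
    config = json.load(f)
agent_cfg = config["agent"]

model = QRDQN(
    policy=config["policy"],                        # "MlpPolicy" in this config
    env=env,                                        # e.g. the CarleEnv instance created above
    learning_rate=agent_cfg["alpha"][0],
    gamma=agent_cfg["discount"],
    buffer_size=agent_cfg["buffer_size"],
    batch_size=agent_cfg["batch_size"],
    exploration_final_eps=agent_cfg["epsilon"],     # assumed mapping of "epsilon"
    exploration_fraction=agent_cfg["eps_fraction"], # assumed mapping of "eps_fraction"
    policy_kwargs={"n_quantiles": agent_cfg["n_quantiles"]},
)
model.learn(total_timesteps=config["base"]["num_timesteps"])
```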
81 | -------------------------------------------------------------------------------- /Town01_robust_rl_paths.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/Town01_robust_rl_paths.png -------------------------------------------------------------------------------- /carle_gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='carle-v0', 5 | entry_point='carle_gym.envs:CarleEnv', 6 | ) 7 | -------------------------------------------------------------------------------- /carle_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from carle_gym.envs.carle_env import CarleEnv 2 | -------------------------------------------------------------------------------- /carle_gym/envs/carle_env.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence, Tuple, Union 3 | import numpy as np 4 | 5 | import gym 6 | 7 | class CarleEnv(gym.Env): 8 | 9 | def __init__( 10 | self, 11 | is_eval_env: bool, 12 | seed: int, 13 | dataset_dir: Union[str, Path], 14 | goal_states: Sequence[int], 15 | reset_state: int, 16 | discount: float, 17 | crosswalk_states: Sequence[int], 18 | agent: str, 19 | network: str, 20 | r_base: float, 21 | r_loopback: float 22 | ): 23 | # PRNG for random rewards. 24 | self._rand = np.random.RandomState(seed) 25 | 26 | # Count timestep and record stochastic reward 27 | self.count = 0 28 | 29 | # Record path if the environment is for evaluation 30 | self.record = True if is_eval_env else False 31 | self.curr_path = [] 32 | self.all_paths = [] 33 | 34 | # Record quantiles of all state action pair (for QR-DQN agent) 35 | self.agent = agent 36 | self.quantiles = [] 37 | 38 | # Set state information for goals and resets 39 | self.goal_states = goal_states 40 | self.reset_state = reset_state 41 | self.state = self.reset_state 42 | self.prev_state = None # used for checking self-transition 43 | self.crosswalk_states = crosswalk_states 44 | 45 | # Initialize environment parameters and data 46 | dataset_path = Path(dataset_dir) 47 | self.waypoint_locations = np.loadtxt( 48 | fname=dataset_path / "waypoint_locations.csv", delimiter="," 49 | ) 50 | self.transition_matrix = np.loadtxt( 51 | fname=dataset_path / "transition_matrix.csv", dtype=int, delimiter="," 52 | ) 53 | self.observations = np.loadtxt( 54 | fname=dataset_path / "observations.csv", delimiter="," 55 | ) 56 | 57 | # Define action space and observation space 58 | self.network = network 59 | self.action_space = gym.spaces.Discrete(np.shape(self.transition_matrix)[1]) 60 | if self.network == "MlpPolicy": 61 | self.observation_space = gym.spaces.Box(np.zeros((255*255,)),np.ones((255*255,)),dtype=np.float32) 62 | elif self.network == "CnnPolicy": 63 | self.observation_shape = (1,255,255) 64 | self.observation_space = gym.spaces.Box(np.zeros(self.observation_shape),np.ones(self.observation_shape),dtype=np.float32) 65 | else: 66 | raise RuntimeError("The network strucutre is not available") 67 | 68 | # Initialize transition counter 69 | self.transition_counts = np.zeros_like(self.transition_matrix) 70 | 71 | # Set discount factor 72 | self.discount = discount 73 | 74 | # Set rewards values 75 | self.r_base = r_base 76 | self.r_loopback 
= r_loopback 77 | self.rewards = -self.r_base * np.ones(len(self.transition_matrix)) 78 | self.rewards[np.array(goal_states)] = 0 79 | 80 | def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]: 81 | """Apply the given action and transition to the next state.""" 82 | self.prev_state = self.state 83 | self.transition_counts[self.state, action] += 1 84 | self.state = self.transition_matrix[self.state, action] 85 | self.count += 1 86 | 87 | if self.record: 88 | self.curr_path.append(self.state) 89 | 90 | return self.get_obs(), self.get_reward(), self.get_done(), self.get_state() 91 | 92 | def reset(self) -> np.ndarray: 93 | """Reset the environment and save path if the environment is for evaluation""" 94 | self.transition_counts = np.zeros_like(self.transition_matrix) 95 | self.state = self.reset_state 96 | self.prev_state = None 97 | self.count = 0 98 | 99 | if self.record: 100 | if np.shape(self.curr_path)[0] != 0: 101 | self.all_paths.append(self.curr_path) 102 | self.curr_path = [] 103 | 104 | return self.get_obs() 105 | 106 | def get_obs_at_state(self, state:int) -> np.ndarray: 107 | """Returns the observation image (ground) for a given state.""" 108 | scans = np.reshape(self.observations[state],(255,255,2)) 109 | ground = scans[:,:,1] 110 | if self.network == "MlpPolicy": 111 | return np.array(ground.flatten()) 112 | elif self.network == "CnnPolicy": 113 | return np.array([ground]) 114 | else: 115 | raise RuntimeError("The network strucutre is not available") 116 | 117 | def get_obs(self) -> np.ndarray: 118 | """Returns the observation image (ground) for the current state.""" 119 | #scans = np.reshape(self.observations[self.state],(255,255,2)) 120 | #ground = scans[:,:,1] 121 | #return np.array(ground.flatten()) 122 | return self.get_obs_at_state(self.state) 123 | 124 | def save_quantiles(self, quantiles:np.ndarray) -> np.ndarray: 125 | """Save quantiles of all state action pair (for QR-DQN agent)""" 126 | assert self.agent == "QRDQN", "save_quantiles is only avaible to the QR-DQN agent" 127 | self.quantiles.append(quantiles) 128 | 129 | def get_quantiles(self) -> np.ndarray: 130 | """Get quantiles of all state action pair (for QR-DQN agent)""" 131 | assert self.agent == "QRDQN", "get_quantiles is only avaible to the QR-DQN agent" 132 | return np.array(self.quantiles) 133 | 134 | def get_reward(self, ssd_thres:int=15) -> float: 135 | """Returns the reward for reaching the current state.""" 136 | # Penalize the self-transition action 137 | if self.prev_state == self.state: 138 | # return self.rewards[self.state] - ssd_thres - 3 139 | return self.rewards[self.state] - self.r_loopback 140 | 141 | # Add noise at the simulated cross walks. 142 | if self.state in self.crosswalk_states: 143 | # Use vonmises distribution as stand-in for wrapped gaussian 144 | # - Interval is bounded from -2*r_base to 0 with below parameters 145 | # - kappa parameter is inversely proportional to variance 146 | # - see: 147 | # https://numpy.org/devdocs/reference/random/generated/numpy.random.vonmises.html 148 | return self.r_base*(self._rand.vonmises(mu=0, kappa=1) / np.pi) - self.r_base 149 | 150 | # Deterministic traffic penalty otherwise. 
151 | return self.rewards[self.state] 152 | 153 | def get_done(self) -> bool: 154 | """Returns a done flag if the goal is reached.""" 155 | return self.state in self.goal_states 156 | 157 | def get_state(self) -> dict: 158 | """Return current state id""" 159 | info = {"state_id":self.state} 160 | return info 161 | 162 | def get_count(self) -> list: 163 | """Return count since last reset""" 164 | return self.count 165 | 166 | def get_all_paths(self) -> list: 167 | """Return all paths in evaluation""" 168 | self.reset() 169 | return self.all_paths 170 | 171 | def get_num_of_states(self) -> int: 172 | return np.shape(self.transition_matrix)[0] 173 | -------------------------------------------------------------------------------- /carle_gym/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='carle', 4 | version='0.0.1', 5 | install_requires=['gym'] 6 | ) 7 | -------------------------------------------------------------------------------- /config/config_A2C_Town01_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "A2C", 9 | "discount": 0.99, 10 | "alpha": [1e-04], 11 | "buffer_size": 64 12 | }, 13 | "environment":{ 14 | "map_name": "Town01", 15 | "data_dir": "carla_data", 16 | "start_state": 209, 17 | "goal_states": [112, 113], 18 | "crosswalk_states": [261], 19 | "r_base": 1, 20 | "r_loopback": 0 21 | }, 22 | "policy": "MlpPolicy", 23 | "save_dir": "Baselines" 24 | } 25 | -------------------------------------------------------------------------------- /config/config_A2C_Town02_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "A2C", 9 | "discount": 0.99, 10 | "alpha": [8e-05], 11 | "buffer_size": 64 12 | }, 13 | "environment":{ 14 | "map_name": "Town02", 15 | "data_dir": "carla_data", 16 | "start_state": 16, 17 | "goal_states": [89,90], 18 | "crosswalk_states": [96], 19 | "r_base": 1, 20 | "r_loopback": 0 21 | }, 22 | "policy": "MlpPolicy", 23 | "save_dir": "Baselines" 24 | } 25 | -------------------------------------------------------------------------------- /config/config_A2C_Town03_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 2000000 6 | }, 7 | "agent":{ 8 | "name": "A2C", 9 | "discount": 0.99, 10 | "alpha": [1e-04], 11 | "buffer_size": 64 12 | }, 13 | "environment":{ 14 | "map_name": "Town03", 15 | "data_dir": "carla_data", 16 | "start_state": 546, 17 | "goal_states": [641, 642], 18 | "crosswalk_states": [585], 19 | "r_base": 1, 20 | "r_loopback": 0 21 | }, 22 | "policy": "MlpPolicy", 23 | "save_dir": "Baselines" 24 | } 25 | -------------------------------------------------------------------------------- /config/config_DQN_Town01_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "DQN", 9 | "discount": 0.99, 10 | "alpha": [1e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64 13 | }, 14 | "environment":{ 15 | "map_name": "Town01", 16 | "data_dir": "carla_data", 17 | "start_state": 209, 18 | 
"goal_states": [112, 113], 19 | "crosswalk_states": [261], 20 | "r_base": 1, 21 | "r_loopback": 0 22 | }, 23 | "policy": "MlpPolicy", 24 | "save_dir": "Baselines" 25 | } 26 | -------------------------------------------------------------------------------- /config/config_DQN_Town02_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "DQN", 9 | "discount": 0.99, 10 | "alpha": [5e-06], 11 | "buffer_size": 2048, 12 | "batch_size": 64 13 | }, 14 | "environment":{ 15 | "map_name": "Town02", 16 | "data_dir": "carla_data", 17 | "start_state": 16, 18 | "goal_states": [89,90], 19 | "crosswalk_states": [96], 20 | "r_base": 1, 21 | "r_loopback": 0 22 | }, 23 | "policy": "MlpPolicy", 24 | "save_dir": "Baselines" 25 | } 26 | -------------------------------------------------------------------------------- /config/config_DQN_Town03_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 2000000 6 | }, 7 | "agent":{ 8 | "name": "DQN", 9 | "discount": 0.99, 10 | "alpha": [3e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64 13 | }, 14 | "environment":{ 15 | "map_name": "Town03", 16 | "data_dir": "carla_data", 17 | "start_state": 546, 18 | "goal_states": [641, 642], 19 | "crosswalk_states": [585], 20 | "r_base": 1, 21 | "r_loopback": 0 22 | }, 23 | "policy": "MlpPolicy", 24 | "save_dir": "Baselines" 25 | } 26 | -------------------------------------------------------------------------------- /config/config_PPO_Town01_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "PPO", 9 | "discount": 0.99, 10 | "alpha": [2e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_epochs":1 14 | }, 15 | "environment":{ 16 | "map_name": "Town01", 17 | "data_dir": "carla_data", 18 | "start_state": 209, 19 | "goal_states": [112, 113], 20 | "crosswalk_states": [261], 21 | "r_base": 1, 22 | "r_loopback": 0 23 | }, 24 | "policy": "MlpPolicy", 25 | "save_dir": "Baselines" 26 | } 27 | -------------------------------------------------------------------------------- /config/config_PPO_Town02_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "PPO", 9 | "discount": 0.99, 10 | "alpha": [3e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_epochs":1 14 | }, 15 | "environment":{ 16 | "map_name": "Town02", 17 | "data_dir": "carla_data", 18 | "start_state": 16, 19 | "goal_states": [89, 90], 20 | "crosswalk_states": [96], 21 | "r_base": 1, 22 | "r_loopback": 0 23 | }, 24 | "policy": "MlpPolicy", 25 | "save_dir": "Baselines" 26 | } 27 | -------------------------------------------------------------------------------- /config/config_PPO_Town03_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 2000000 6 | }, 7 | "agent":{ 8 | "name": "PPO", 9 | "discount": 0.99, 10 | "alpha": [2e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_epochs":1 14 | }, 15 | "environment":{ 16 | "map_name": "Town03", 17 | 
"data_dir": "carla_data", 18 | "start_state": 546, 19 | "goal_states": [641, 642], 20 | "crosswalk_states": [585], 21 | "r_base": 1, 22 | "r_loopback": 0 23 | }, 24 | "policy": "MlpPolicy", 25 | "save_dir": "Baselines" 26 | } 27 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town01_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 2000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [6e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Greedy" 17 | }, 18 | "environment":{ 19 | "map_name": "Town01", 20 | "data_dir": "carla_data", 21 | "start_state": 209, 22 | "goal_states": [112, 113], 23 | "crosswalk_states": [261], 24 | "r_base": 1, 25 | "r_loopback": 0 26 | }, 27 | "policy": "MlpPolicy", 28 | "save_dir": "Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town01_robust_rl_greedy.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Greedy" 17 | }, 18 | "environment":{ 19 | "map_name": "Town01", 20 | "data_dir": "carla_data", 21 | "start_state": 209, 22 | "goal_states": [112, 113], 23 | "crosswalk_states": [261], 24 | "r_base": 3, 25 | "r_loopback": 18 26 | }, 27 | "policy": "CnnPolicy", 28 | "save_dir": "Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town01_robust_rl_ssd.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "SSD" 17 | }, 18 | "environment":{ 19 | "map_name": "Town01", 20 | "data_dir": "carla_data", 21 | "start_state": 209, 22 | "goal_states": [112, 113], 23 | "crosswalk_states": [261], 24 | "r_base": 3, 25 | "r_loopback": 18 26 | }, 27 | "policy": "CnnPolicy", 28 | "save_dir": "Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town01_robust_rl_thres_ssd.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Thresholded_SSD" 17 | }, 18 | "environment":{ 19 | "map_name": "Town01", 20 | "data_dir": "carla_data", 21 | "start_state": 209, 22 | "goal_states": [112, 113], 23 | "crosswalk_states": [261], 24 | "r_base": 3, 25 | "r_loopback": 18 26 | }, 27 | "policy": "CnnPolicy", 28 | "save_dir": "Baselines" 29 | 
} 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town02_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 2000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [6e-05], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Greedy" 17 | }, 18 | "environment":{ 19 | "map_name": "Town02", 20 | "data_dir": "carla_data", 21 | "start_state": 16, 22 | "goal_states": [89, 90], 23 | "crosswalk_states": [96], 24 | "r_base": 1, 25 | "r_loopback": 0 26 | }, 27 | "policy": "MlpPolicy", 28 | "save_dir": "Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town02_robust_rl_greedy.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Greedy" 17 | }, 18 | "environment":{ 19 | "map_name": "Town02", 20 | "data_dir": "/home/rfal/code/tro/data/carla_data", 21 | "start_state": 16, 22 | "goal_states": [89, 90], 23 | "crosswalk_states": [96], 24 | "r_base": 3, 25 | "r_loopback": 18 26 | }, 27 | "policy": "CnnPolicy", 28 | "save_dir": "/home/rfal/Stochastic_Road_Network/Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town02_robust_rl_ssd.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "SSD" 17 | }, 18 | "environment":{ 19 | "map_name": "Town02", 20 | "data_dir": "/home/rfal/code/tro/data/carla_data", 21 | "start_state": 16, 22 | "goal_states": [89, 90], 23 | "crosswalk_states": [96], 24 | "r_base": 3, 25 | "r_loopback": 18 26 | }, 27 | "policy": "CnnPolicy", 28 | "save_dir": "/home/rfal/Stochastic_Road_Network/Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /config/config_QRDQN_Town02_robust_rl_thres_ssd.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 1000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-04], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Thresholded_SSD", 17 | "ssd_thres": 15 18 | }, 19 | "environment":{ 20 | "map_name": "Town02", 21 | "data_dir": "/home/rfal/code/tro/data/carla_data", 22 | "start_state": 16, 23 | "goal_states": [89, 90], 24 | "crosswalk_states": [96], 25 | "r_base": 3, 26 | "r_loopback": 18 27 | }, 28 | "policy": "CnnPolicy", 29 | "save_dir": "/home/rfal/Stochastic_Road_Network/Baselines" 30 | } 31 | 32 | 
-------------------------------------------------------------------------------- /config/config_QRDQN_Town03_mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "base":{ 3 | "seed": [0], 4 | "eval_freq": 10000, 5 | "num_timesteps": 2000000 6 | }, 7 | "agent":{ 8 | "name": "QRDQN", 9 | "discount": 0.99, 10 | "alpha": [5e-05], 11 | "buffer_size": 2048, 12 | "batch_size": 64, 13 | "n_quantiles": 4, 14 | "epsilon": 0.1, 15 | "eps_fraction": 0.02, 16 | "eval_policy": "Greedy" 17 | }, 18 | "environment":{ 19 | "map_name": "Town03", 20 | "data_dir": "carla_data", 21 | "start_state": 546, 22 | "goal_states": [641, 642], 23 | "crosswalk_states": [585], 24 | "r_base": 1, 25 | "r_loopback": 0 26 | }, 27 | "policy": "MlpPolicy", 28 | "save_dir": "Baselines" 29 | } 30 | 31 | -------------------------------------------------------------------------------- /observation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/observation.png -------------------------------------------------------------------------------- /parameters.md: -------------------------------------------------------------------------------- 1 | # Experiment Config File Parameters 2 | 3 | - **`base`**: Trial parameters. 4 | - **`seed`**: List of RNG seeds. 5 | - **`eval_freq`**: Number of training timesteps between two evaluations (per trial). 6 | - **`num_timesteps`**: Number of total training timesteps (per trial). 7 | - **`agent`**: Agent parameters. 8 | - **`name`**: Method (A2C, PPO, DQN, or QRDQN). 9 | - **`discount`**: Discount factor (gamma). 10 | - **`alpha`**: Learning rate (alpha). 11 | - **`environment`**: Environment parameters. 12 | - **`map_name`**: Name of the target map. 13 | - **`data_dir`**: Directory where environment data is stored. 14 | - **`start_state`**: Start state. 15 | - **`goal_states`**: List of goal states. 16 | - **`crosswalk_states`**: List of stochastic crosswalk states. 17 | - **`policy`**: Agent network structure ("MlpPolicy" or "CnnPolicy"). 18 | - **`save_dir`**: Directory where experiment data is stored. 19 | -------------------------------------------------------------------------------- /scripts/carla_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### carla_docker.sh -- run Carla environment in Docker container 4 | ### 5 | ### Usage: 6 | ### ./carla_docker.sh [options] 7 | ### 8 | ### Options: 9 | ### -h Show this message 10 | ### -o Offscreen mode: disable CARLA server window 11 | ### -e Copy the CARLA PythonAPI egg file out of the Docker image 12 | 13 | usage() { 14 | # Use the file header as a usage guide 15 | # Reference: https://samizdat.dev/help-message-for-shell-scripts/ 16 | sed -rn 's/^### ?/ /;T;p' "$0" 17 | } 18 | 19 | version() { 20 | # Convert decimal version numbers into integers suitable for comparison 21 | echo "$@" | awk -F.
'{ printf("%03d%03d%03d\n", $1,$2,$3); }' 22 | } 23 | 24 | 25 | # Get directory of scripts 26 | script_dir=$(dirname "$(readlink -f "$0")") 27 | 28 | # Set name of CARLA parameters 29 | carla_image="carlasim/carla:0.9.6" 30 | carla_egg_dir="/home/carla/PythonAPI/carla/dist" 31 | 32 | # Set nvidia container runtime flag (dependent on docker version) 33 | if [ "$(version "$(docker version -f '{{.Server.Version}}')")" -gt "$(version 19.03.0)" ]; then 34 | nvidia_flag="--gpus all" 35 | else 36 | nvidia_flag="--runtime=nvidia" 37 | fi 38 | 39 | # Set SDL_VIDEODRIVER env var to make CARLA server visible 40 | sdl_driver=x11 41 | 42 | # Parse arguments with posix-compatible getopt 43 | cmdargs=$(GETOPT_COMPATIBLE=1 getopt hoe "$@") 44 | eval set -- "$cmdargs" 45 | 46 | while [ "$#" -ne 0 ] ; do 47 | case "$1" in 48 | -h|--help) 49 | usage 50 | exit 1 ;; 51 | -o|--offscreen) 52 | sdl_driver=offscreen 53 | shift 1 54 | ;; 55 | -e|--egg) 56 | copy_egg=yes 57 | shift 1 58 | ;; 59 | --) 60 | break ;; 61 | *) logerror "Argument parse error, could not parse $1" 62 | exit 1 ;; 63 | esac 64 | done 65 | 66 | # If the CARLA egg file does not exist, copy it from the docker image 67 | if [ -n "$copy_egg" ]; then 68 | echo "CARLA egg file requested, copying from \"$carla_image\" image." 69 | egg_file=$(docker run --rm "$carla_image" bash -c "ls ${carla_egg_dir}/*py3*.egg") 70 | echo "Egg file $egg_file found, copying..." 71 | docker cp "$(docker create --rm $carla_image)":"${egg_file}" "$script_dir" 72 | fi 73 | 74 | echo "Running $carla_image" 75 | docker run \ 76 | -p 2000-2002:2000-2002 \ 77 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 78 | -e DISPLAY="$DISPLAY" \ 79 | -e SDL_VIDEODRIVER="$sdl_driver"\ 80 | -it \ 81 | --rm \ 82 | $nvidia_flag \ 83 | $carla_image \ 84 | ./CarlaUE4.sh -opengl 85 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sb3_contrib.ars import ARS 4 | from sb3_contrib.ppo_mask import MaskablePPO 5 | from sb3_contrib.qrdqn import QRDQN 6 | from sb3_contrib.tqc import TQC 7 | from sb3_contrib.trpo import TRPO 8 | 9 | # Read version from file 10 | version_file = os.path.join(os.path.dirname(__file__), "version.txt") 11 | with open(version_file, "r") as file_handler: 12 | __version__ = file_handler.read().strip() 13 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/ars/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.ars.ars import ARS 2 | from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/ars/policies.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Type 2 | 3 | import gym 4 | import torch as th 5 | from stable_baselines3.common.policies import BasePolicy, register_policy 6 | from stable_baselines3.common.preprocessing import get_action_dim 7 | from stable_baselines3.common.torch_layers import create_mlp 8 | from torch import nn 9 | 10 | 11 | class ARSPolicy(BasePolicy): 12 | """ 13 | Policy network for ARS. 
14 | 15 | :param observation_space: The observation space of the environment 16 | :param action_space: The action space of the environment 17 | :param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes. 18 | :param activation_fn: Activation function 19 | :param squash_output: For continuous actions, whether the output is squashed 20 | or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | observation_space: gym.spaces.Space, 26 | action_space: gym.spaces.Space, 27 | net_arch: Optional[List[int]] = None, 28 | activation_fn: Type[nn.Module] = nn.ReLU, 29 | squash_output: bool = True, 30 | ): 31 | 32 | super().__init__( 33 | observation_space, 34 | action_space, 35 | squash_output=isinstance(action_space, gym.spaces.Box) and squash_output, 36 | ) 37 | 38 | if net_arch is None: 39 | net_arch = [64, 64] 40 | 41 | self.net_arch = net_arch 42 | self.features_extractor = self.make_features_extractor() 43 | self.features_dim = self.features_extractor.features_dim 44 | self.activation_fn = activation_fn 45 | 46 | if isinstance(action_space, gym.spaces.Box): 47 | action_dim = get_action_dim(action_space) 48 | actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True) 49 | elif isinstance(action_space, gym.spaces.Discrete): 50 | actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn) 51 | else: 52 | raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") 53 | 54 | self.action_net = nn.Sequential(*actor_net) 55 | 56 | def _get_constructor_parameters(self) -> Dict[str, Any]: 57 | # data = super()._get_constructor_parameters() this adds normalize_images, which we don't support... 58 | data = dict( 59 | observation_space=self.observation_space, 60 | action_space=self.action_space, 61 | net_arch=self.net_arch, 62 | activation_fn=self.activation_fn, 63 | ) 64 | return data 65 | 66 | def forward(self, obs: th.Tensor) -> th.Tensor: 67 | 68 | features = self.extract_features(obs) 69 | if isinstance(self.action_space, gym.spaces.Box): 70 | return self.action_net(features) 71 | elif isinstance(self.action_space, gym.spaces.Discrete): 72 | logits = self.action_net(features) 73 | return th.argmax(logits, dim=1) 74 | else: 75 | raise NotImplementedError() 76 | 77 | def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: 78 | # Non deterministic action does not really make sense for ARS, we ignore this parameter for now.. 79 | return self(observation) 80 | 81 | 82 | class ARSLinearPolicy(ARSPolicy): 83 | """ 84 | Linear policy network for ARS. 85 | 86 | :param observation_space: The observation space of the environment 87 | :param action_space: The action space of the environment 88 | :param with_bias: With or without bias on the output 89 | :param squash_output: For continuous actions, whether the output is squashed 90 | or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped. 
91 | """ 92 | 93 | def __init__( 94 | self, 95 | observation_space: gym.spaces.Space, 96 | action_space: gym.spaces.Space, 97 | with_bias: bool = False, 98 | squash_output: bool = False, 99 | ): 100 | 101 | super().__init__(observation_space, action_space, squash_output=squash_output) 102 | 103 | if isinstance(action_space, gym.spaces.Box): 104 | action_dim = get_action_dim(action_space) 105 | self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias) 106 | if squash_output: 107 | self.action_net = nn.Sequential(self.action_net, nn.Tanh()) 108 | elif isinstance(action_space, gym.spaces.Discrete): 109 | self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias) 110 | else: 111 | raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") 112 | 113 | 114 | MlpPolicy = ARSPolicy 115 | LinearPolicy = ARSLinearPolicy 116 | 117 | 118 | register_policy("LinearPolicy", LinearPolicy) 119 | register_policy("MlpPolicy", MlpPolicy) 120 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/sb3_contrib/common/__init__.py -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.common.envs.invalid_actions_env import ( 2 | InvalidActionEnvDiscrete, 3 | InvalidActionEnvMultiBinary, 4 | InvalidActionEnvMultiDiscrete, 5 | ) 6 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/envs/invalid_actions_env.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | from gym import spaces 5 | from stable_baselines3.common.envs import IdentityEnv 6 | 7 | 8 | class InvalidActionEnvDiscrete(IdentityEnv): 9 | """ 10 | Identity env with a discrete action space. Supports action masking. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | dim: Optional[int] = None, 16 | ep_length: int = 100, 17 | n_invalid_actions: int = 0, 18 | ): 19 | if dim is None: 20 | dim = 1 21 | assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} < {dim}" 22 | 23 | space = spaces.Discrete(dim) 24 | self.n_invalid_actions = n_invalid_actions 25 | self.possible_actions = np.arange(space.n) 26 | self.invalid_actions: List[int] = [] 27 | super().__init__(space=space, ep_length=ep_length) 28 | 29 | def _choose_next_state(self) -> None: 30 | self.state = self.action_space.sample() 31 | # Randomly choose invalid actions that are not the current state 32 | potential_invalid_actions = [i for i in self.possible_actions if i != self.state] 33 | self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) 34 | 35 | def action_masks(self) -> List[bool]: 36 | return [action not in self.invalid_actions for action in self.possible_actions] 37 | 38 | 39 | class InvalidActionEnvMultiDiscrete(IdentityEnv): 40 | """ 41 | Identity env with a multidiscrete action space. Supports action masking. 
42 | """ 43 | 44 | def __init__( 45 | self, 46 | dims: Optional[List[int]] = None, 47 | ep_length: int = 100, 48 | n_invalid_actions: int = 0, 49 | ): 50 | if dims is None: 51 | dims = [1, 1] 52 | 53 | if n_invalid_actions > sum(dims) - len(dims): 54 | raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}") 55 | 56 | space = spaces.MultiDiscrete(dims) 57 | self.n_invalid_actions = n_invalid_actions 58 | self.possible_actions = np.arange(sum(dims)) 59 | self.invalid_actions: List[int] = [] 60 | super().__init__(space=space, ep_length=ep_length) 61 | 62 | def _choose_next_state(self) -> None: 63 | self.state = self.action_space.sample() 64 | 65 | converted_state: List[int] = [] 66 | running_total = 0 67 | for i in range(len(self.action_space.nvec)): 68 | converted_state.append(running_total + self.state[i]) 69 | running_total += self.action_space.nvec[i] 70 | 71 | # Randomly choose invalid actions that are not the current state 72 | potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state] 73 | self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) 74 | 75 | def action_masks(self) -> List[bool]: 76 | return [action not in self.invalid_actions for action in self.possible_actions] 77 | 78 | 79 | class InvalidActionEnvMultiBinary(IdentityEnv): 80 | """ 81 | Identity env with a multibinary action space. Supports action masking. 82 | """ 83 | 84 | def __init__( 85 | self, 86 | dims: Optional[int] = None, 87 | ep_length: int = 100, 88 | n_invalid_actions: int = 0, 89 | ): 90 | if dims is None: 91 | dims = 1 92 | 93 | if n_invalid_actions > dims: 94 | raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}") 95 | 96 | space = spaces.MultiBinary(dims) 97 | self.n_invalid_actions = n_invalid_actions 98 | self.possible_actions = np.arange(2 * dims) 99 | self.invalid_actions: List[int] = [] 100 | super().__init__(space=space, ep_length=ep_length) 101 | 102 | def _choose_next_state(self) -> None: 103 | self.state = self.action_space.sample() 104 | 105 | converted_state: List[int] = [] 106 | running_total = 0 107 | for i in range(self.action_space.n): 108 | converted_state.append(running_total + self.state[i]) 109 | running_total += 2 110 | 111 | # Randomly choose invalid actions that are not the current state 112 | potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state] 113 | self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) 114 | 115 | def action_masks(self) -> List[bool]: 116 | return [action not in self.invalid_actions for action in self.possible_actions] 117 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/maskable/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/sb3_contrib/common/maskable/__init__.py -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/maskable/buffers.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, NamedTuple, Optional, Union 2 | 3 | import numpy as np 4 | import torch as th 5 | from gym import spaces 6 | from stable_baselines3.common.buffers import 
DictRolloutBuffer, RolloutBuffer 7 | from stable_baselines3.common.type_aliases import TensorDict 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | 10 | 11 | class MaskableRolloutBufferSamples(NamedTuple): 12 | observations: th.Tensor 13 | actions: th.Tensor 14 | old_values: th.Tensor 15 | old_log_prob: th.Tensor 16 | advantages: th.Tensor 17 | returns: th.Tensor 18 | action_masks: th.Tensor 19 | 20 | 21 | class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples): 22 | observations: TensorDict 23 | actions: th.Tensor 24 | old_values: th.Tensor 25 | old_log_prob: th.Tensor 26 | advantages: th.Tensor 27 | returns: th.Tensor 28 | action_masks: th.Tensor 29 | 30 | 31 | class MaskableRolloutBuffer(RolloutBuffer): 32 | """ 33 | Rollout buffer that also stores the invalid action masks associated with each observation. 34 | 35 | :param buffer_size: Max number of element in the buffer 36 | :param observation_space: Observation space 37 | :param action_space: Action space 38 | :param device: 39 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator 40 | Equivalent to classic advantage when set to 1. 41 | :param gamma: Discount factor 42 | :param n_envs: Number of parallel environments 43 | """ 44 | 45 | def __init__(self, *args, **kwargs): 46 | self.action_masks = None 47 | super().__init__(*args, **kwargs) 48 | 49 | def reset(self) -> None: 50 | if isinstance(self.action_space, spaces.Discrete): 51 | mask_dims = self.action_space.n 52 | elif isinstance(self.action_space, spaces.MultiDiscrete): 53 | mask_dims = sum(self.action_space.nvec) 54 | elif isinstance(self.action_space, spaces.MultiBinary): 55 | mask_dims = 2 * self.action_space.n # One mask per binary outcome 56 | else: 57 | raise ValueError(f"Unsupported action space {type(self.action_space)}") 58 | 59 | self.mask_dims = mask_dims 60 | self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) 61 | 62 | super().reset() 63 | 64 | def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: 65 | """ 66 | :param action_masks: Masks applied to constrain the choice of possible actions. 
67 | """ 68 | if action_masks is not None: 69 | self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims)) 70 | 71 | super().add(*args, **kwargs) 72 | 73 | def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: 74 | assert self.full, "" 75 | indices = np.random.permutation(self.buffer_size * self.n_envs) 76 | # Prepare the data 77 | if not self.generator_ready: 78 | for tensor in [ 79 | "observations", 80 | "actions", 81 | "values", 82 | "log_probs", 83 | "advantages", 84 | "returns", 85 | "action_masks", 86 | ]: 87 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) 88 | self.generator_ready = True 89 | 90 | # Return everything, don't create minibatches 91 | if batch_size is None: 92 | batch_size = self.buffer_size * self.n_envs 93 | 94 | start_idx = 0 95 | while start_idx < self.buffer_size * self.n_envs: 96 | yield self._get_samples(indices[start_idx : start_idx + batch_size]) 97 | start_idx += batch_size 98 | 99 | def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: 100 | data = ( 101 | self.observations[batch_inds], 102 | self.actions[batch_inds], 103 | self.values[batch_inds].flatten(), 104 | self.log_probs[batch_inds].flatten(), 105 | self.advantages[batch_inds].flatten(), 106 | self.returns[batch_inds].flatten(), 107 | self.action_masks[batch_inds].reshape(-1, self.mask_dims), 108 | ) 109 | return MaskableRolloutBufferSamples(*map(self.to_torch, data)) 110 | 111 | 112 | class MaskableDictRolloutBuffer(DictRolloutBuffer): 113 | """ 114 | Dict Rollout buffer used in on-policy algorithms like A2C/PPO. 115 | Extends the RolloutBuffer to use dictionary observations 116 | 117 | It corresponds to ``buffer_size`` transitions collected 118 | using the current policy. 119 | This experience will be discarded after the policy update. 120 | In order to use PPO objective, we also store the current value of each state 121 | and the log probability of each taken action. 122 | 123 | The term rollout here refers to the model-free notion and should not 124 | be used with the concept of rollout used in model-based RL or planning. 125 | Hence, it is only involved in policy and value function training but not action selection. 126 | 127 | :param buffer_size: Max number of element in the buffer 128 | :param observation_space: Observation space 129 | :param action_space: Action space 130 | :param device: 131 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator 132 | Equivalent to classic advantage when set to 1. 
133 | :param gamma: Discount factor 134 | :param n_envs: Number of parallel environments 135 | """ 136 | 137 | def __init__( 138 | self, 139 | buffer_size: int, 140 | observation_space: spaces.Space, 141 | action_space: spaces.Space, 142 | device: Union[th.device, str] = "cpu", 143 | gae_lambda: float = 1, 144 | gamma: float = 0.99, 145 | n_envs: int = 1, 146 | ): 147 | self.action_masks = None 148 | super(MaskableDictRolloutBuffer, self).__init__( 149 | buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs 150 | ) 151 | 152 | def reset(self) -> None: 153 | if isinstance(self.action_space, spaces.Discrete): 154 | mask_dims = self.action_space.n 155 | elif isinstance(self.action_space, spaces.MultiDiscrete): 156 | mask_dims = sum(self.action_space.nvec) 157 | elif isinstance(self.action_space, spaces.MultiBinary): 158 | mask_dims = 2 * self.action_space.n # One mask per binary outcome 159 | else: 160 | raise ValueError(f"Unsupported action space {type(self.action_space)}") 161 | 162 | self.mask_dims = mask_dims 163 | self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) 164 | 165 | super(MaskableDictRolloutBuffer, self).reset() 166 | 167 | def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: 168 | """ 169 | :param action_masks: Masks applied to constrain the choice of possible actions. 170 | """ 171 | if action_masks is not None: 172 | self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims)) 173 | 174 | super(MaskableDictRolloutBuffer, self).add(*args, **kwargs) 175 | 176 | def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: 177 | assert self.full, "" 178 | indices = np.random.permutation(self.buffer_size * self.n_envs) 179 | # Prepare the data 180 | if not self.generator_ready: 181 | 182 | for key, obs in self.observations.items(): 183 | self.observations[key] = self.swap_and_flatten(obs) 184 | 185 | _tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "action_masks"] 186 | 187 | for tensor in _tensor_names: 188 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) 189 | self.generator_ready = True 190 | 191 | # Return everything, don't create minibatches 192 | if batch_size is None: 193 | batch_size = self.buffer_size * self.n_envs 194 | 195 | start_idx = 0 196 | while start_idx < self.buffer_size * self.n_envs: 197 | yield self._get_samples(indices[start_idx : start_idx + batch_size]) 198 | start_idx += batch_size 199 | 200 | def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: 201 | 202 | return MaskableDictRolloutBufferSamples( 203 | observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, 204 | actions=self.to_torch(self.actions[batch_inds]), 205 | old_values=self.to_torch(self.values[batch_inds].flatten()), 206 | old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), 207 | advantages=self.to_torch(self.advantages[batch_inds].flatten()), 208 | returns=self.to_torch(self.returns[batch_inds].flatten()), 209 | action_masks=self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims)), 210 | ) 211 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/maskable/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import 
numpy as np 4 | from stable_baselines3.common.callbacks import EvalCallback 5 | from stable_baselines3.common.vec_env import sync_envs_normalization 6 | 7 | from sb3_contrib.common.maskable.evaluation import evaluate_policy 8 | 9 | 10 | class MaskableEvalCallback(EvalCallback): 11 | """ 12 | Callback for evaluating an agent. Supports invalid action masking. 13 | 14 | :param eval_env: The environment used for initialization 15 | :param callback_on_new_best: Callback to trigger 16 | when there is a new best model according to the ``mean_reward`` 17 | :param n_eval_episodes: The number of episodes to test the agent 18 | :param eval_freq: Evaluate the agent every eval_freq call of the callback. 19 | :param log_path: Path to a folder where the evaluations (``evaluations.npz``) 20 | will be saved. It will be updated at each evaluation. 21 | :param best_model_save_path: Path to a folder where the best model 22 | according to performance on the eval env will be saved. 23 | :param deterministic: Whether the evaluation should 24 | use a stochastic or deterministic actions. 25 | :param render: Whether to render or not the environment during evaluation 26 | :param verbose: 27 | :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been 28 | wrapped with a Monitor wrapper) 29 | :param use_masking: Whether or not to use invalid action masks during evaluation 30 | """ 31 | 32 | def __init__(self, *args, use_masking: bool = True, **kwargs): 33 | super().__init__(*args, **kwargs) 34 | self.use_masking = use_masking 35 | 36 | def _on_step(self) -> bool: 37 | if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: 38 | # Sync training and eval env if there is VecNormalize 39 | sync_envs_normalization(self.training_env, self.eval_env) 40 | 41 | # Reset success rate buffer 42 | self._is_success_buffer = [] 43 | 44 | # Note that evaluate_policy() has been patched to support masking 45 | episode_rewards, episode_lengths = evaluate_policy( 46 | self.model, 47 | self.eval_env, 48 | n_eval_episodes=self.n_eval_episodes, 49 | render=self.render, 50 | deterministic=self.deterministic, 51 | return_episode_rewards=True, 52 | warn=self.warn, 53 | callback=self._log_success_callback, 54 | use_masking=self.use_masking, 55 | ) 56 | 57 | if self.log_path is not None: 58 | self.evaluations_timesteps.append(self.num_timesteps) 59 | self.evaluations_results.append(episode_rewards) 60 | self.evaluations_length.append(episode_lengths) 61 | 62 | kwargs = {} 63 | # Save success log if present 64 | if len(self._is_success_buffer) > 0: 65 | self.evaluations_successes.append(self._is_success_buffer) 66 | kwargs = dict(successes=self.evaluations_successes) 67 | 68 | np.savez( 69 | self.log_path, 70 | timesteps=self.evaluations_timesteps, 71 | results=self.evaluations_results, 72 | ep_lengths=self.evaluations_length, 73 | **kwargs, 74 | ) 75 | 76 | mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) 77 | mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) 78 | self.last_mean_reward = mean_reward 79 | 80 | if self.verbose > 0: 81 | print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") 82 | print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") 83 | # Add to current Logger 84 | self.logger.record("eval/mean_reward", float(mean_reward)) 85 | self.logger.record("eval/mean_ep_length", mean_ep_length) 86 | 87 | if len(self._is_success_buffer) > 0: 88 | success_rate = 
np.mean(self._is_success_buffer) 89 | if self.verbose > 0: 90 | print(f"Success rate: {100 * success_rate:.2f}%") 91 | self.logger.record("eval/success_rate", success_rate) 92 | 93 | # Dump log so the evaluation results are printed with the correct timestep 94 | self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard") 95 | self.logger.dump(self.num_timesteps) 96 | 97 | if mean_reward > self.best_mean_reward: 98 | if self.verbose > 0: 99 | print("New best mean reward!") 100 | if self.best_model_save_path is not None: 101 | self.model.save(os.path.join(self.best_model_save_path, "best_model")) 102 | self.best_mean_reward = mean_reward 103 | # Trigger callback if needed 104 | if self.callback is not None: 105 | return self._on_event() 106 | 107 | return True 108 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/maskable/evaluation.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | from stable_baselines3.common.monitor import Monitor 7 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped 8 | 9 | from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported 10 | from sb3_contrib.ppo_mask import MaskablePPO 11 | 12 | 13 | def evaluate_policy( # noqa: C901 14 | model: MaskablePPO, 15 | env: Union[gym.Env, VecEnv], 16 | n_eval_episodes: int = 10, 17 | deterministic: bool = True, 18 | render: bool = False, 19 | callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, 20 | reward_threshold: Optional[float] = None, 21 | return_episode_rewards: bool = False, 22 | warn: bool = True, 23 | use_masking: bool = True, 24 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: 25 | """ 26 | Runs policy for ``n_eval_episodes`` episodes and returns average reward. 27 | If a vector env is passed in, this divides the episodes to evaluate onto the 28 | different elements of the vector env. This static division of work is done to 29 | remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more 30 | details and discussion. 31 | 32 | .. note:: 33 | If environment has not been wrapped with ``Monitor`` wrapper, reward and 34 | episode lengths are counted as it appears with ``env.step`` calls. If 35 | the environment contains wrappers that modify rewards or episode lengths 36 | (e.g. reward scaling, early episode reset), these will affect the evaluation 37 | results as well. You can avoid this by wrapping environment with ``Monitor`` 38 | wrapper before anything else. 39 | 40 | :param model: The RL agent you want to evaluate. 41 | :param env: The gym environment. In the case of a ``VecEnv`` 42 | this must contain only one environment. 43 | :param n_eval_episodes: Number of episode to evaluate the agent 44 | :param deterministic: Whether to use deterministic or stochastic actions 45 | :param render: Whether to render the environment or not 46 | :param callback: callback function to do additional checks, 47 | called after each step. Gets locals() and globals() passed as parameters. 48 | :param reward_threshold: Minimum expected reward per episode, 49 | this will raise an error if the performance is not met 50 | :param return_episode_rewards: If True, a list of rewards and episde lengths 51 | per episode will be returned instead of the mean. 
52 | :param warn: If True (default), warns user about lack of a Monitor wrapper in the 53 | evaluation environment. 54 | :param use_masking: Whether or not to use invalid action masks during evaluation 55 | :return: Mean reward per episode, std of reward per episode. 56 | Returns ([float], [int]) when ``return_episode_rewards`` is True, first 57 | list containing per-episode rewards and second containing per-episode lengths 58 | (in number of steps). 59 | """ 60 | 61 | if use_masking and not is_masking_supported(env): 62 | raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper") 63 | 64 | is_monitor_wrapped = False 65 | 66 | if not isinstance(env, VecEnv): 67 | env = DummyVecEnv([lambda: env]) 68 | 69 | is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] 70 | 71 | if not is_monitor_wrapped and warn: 72 | warnings.warn( 73 | "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " 74 | "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. " 75 | "Consider wrapping environment first with ``Monitor`` wrapper.", 76 | UserWarning, 77 | ) 78 | 79 | n_envs = env.num_envs 80 | episode_rewards = [] 81 | episode_lengths = [] 82 | 83 | episode_counts = np.zeros(n_envs, dtype="int") 84 | # Divides episodes among different sub environments in the vector as evenly as possible 85 | episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int") 86 | 87 | current_rewards = np.zeros(n_envs) 88 | current_lengths = np.zeros(n_envs, dtype="int") 89 | observations = env.reset() 90 | states = None 91 | 92 | while (episode_counts < episode_count_targets).any(): 93 | if use_masking: 94 | action_masks = get_action_masks(env) 95 | actions, state = model.predict( 96 | observations, 97 | state=states, 98 | deterministic=deterministic, 99 | action_masks=action_masks, 100 | ) 101 | else: 102 | actions, states = model.predict(observations, state=states, deterministic=deterministic) 103 | observations, rewards, dones, infos = env.step(actions) 104 | current_rewards += rewards 105 | current_lengths += 1 106 | for i in range(n_envs): 107 | if episode_counts[i] < episode_count_targets[i]: 108 | 109 | # unpack values so that the callback can access the local variables 110 | reward = rewards[i] 111 | done = dones[i] 112 | info = infos[i] 113 | 114 | if callback is not None: 115 | callback(locals(), globals()) 116 | 117 | if dones[i]: 118 | if is_monitor_wrapped: 119 | # Atari wrapper can send a "done" signal when 120 | # the agent loses a life, but it does not correspond 121 | # to the true end of episode 122 | if "episode" in info.keys(): 123 | # Do not trust "done" with episode endings. 124 | # Monitor wrapper includes "episode" key in info if environment 125 | # has been wrapped with it. Use those rewards instead. 
126 | episode_rewards.append(info["episode"]["r"]) 127 | episode_lengths.append(info["episode"]["l"]) 128 | # Only increment at the real end of an episode 129 | episode_counts[i] += 1 130 | else: 131 | episode_rewards.append(current_rewards[i]) 132 | episode_lengths.append(current_lengths[i]) 133 | episode_counts[i] += 1 134 | current_rewards[i] = 0 135 | current_lengths[i] = 0 136 | if states is not None: 137 | states[i] *= 0 138 | 139 | if render: 140 | env.render() 141 | 142 | mean_reward = np.mean(episode_rewards) 143 | std_reward = np.std(episode_rewards) 144 | if reward_threshold is not None: 145 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" 146 | if return_episode_rewards: 147 | return episode_rewards, episode_lengths 148 | return mean_reward, std_reward 149 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/maskable/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from stable_baselines3.common.type_aliases import GymEnv 3 | from stable_baselines3.common.vec_env import VecEnv 4 | 5 | EXPECTED_METHOD_NAME = "action_masks" 6 | 7 | 8 | def get_action_masks(env: GymEnv) -> np.ndarray: 9 | """ 10 | Checks whether gym env exposes a method returning invalid action masks 11 | 12 | :param env: the Gym environment to get masks from 13 | :return: A numpy array of the masks 14 | """ 15 | 16 | if isinstance(env, VecEnv): 17 | return np.stack(env.env_method(EXPECTED_METHOD_NAME)) 18 | else: 19 | return getattr(env, EXPECTED_METHOD_NAME)() 20 | 21 | 22 | def is_masking_supported(env: GymEnv) -> bool: 23 | """ 24 | Checks whether gym env exposes a method returning invalid action masks 25 | 26 | :param env: the Gym environment to check 27 | :return: True if the method is found, False otherwise 28 | """ 29 | 30 | if isinstance(env, VecEnv): 31 | try: 32 | # TODO: add VecEnv.has_attr() 33 | env.get_attr(EXPECTED_METHOD_NAME) 34 | return True 35 | except AttributeError: 36 | return False 37 | else: 38 | return hasattr(env, EXPECTED_METHOD_NAME) 39 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import torch as th 4 | from torch import nn 5 | 6 | 7 | def quantile_huber_loss( 8 | current_quantiles: th.Tensor, 9 | target_quantiles: th.Tensor, 10 | cum_prob: Optional[th.Tensor] = None, 11 | sum_over_quantiles: bool = True, 12 | ) -> th.Tensor: 13 | """ 14 | The quantile-regression loss, as described in the QR-DQN and TQC papers. 15 | Partially taken from https://github.com/bayesgroup/tqc_pytorch. 16 | 17 | :param current_quantiles: current estimate of quantiles, must be either 18 | (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles) 19 | :param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles), 20 | (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles) 21 | :param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper), 22 | must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles). 
23 | (if None, calculating unit quantiles) 24 | :param sum_over_quantiles: if summing over the quantile dimension or not 25 | :return: the loss 26 | """ 27 | if current_quantiles.ndim != target_quantiles.ndim: 28 | raise ValueError( 29 | f"Error: The dimension of curremt_quantile ({current_quantiles.ndim}) needs to match " 30 | f"the dimension of target_quantiles ({target_quantiles.ndim})." 31 | ) 32 | if current_quantiles.shape[0] != target_quantiles.shape[0]: 33 | raise ValueError( 34 | f"Error: The batch size of curremt_quantile ({current_quantiles.shape[0]}) needs to match " 35 | f"the batch size of target_quantiles ({target_quantiles.shape[0]})." 36 | ) 37 | if current_quantiles.ndim not in (2, 3): 38 | raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.") 39 | 40 | if cum_prob is None: 41 | n_quantiles = current_quantiles.shape[-1] 42 | # Cumulative probabilities to calculate quantiles. 43 | cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles 44 | if current_quantiles.ndim == 2: 45 | # For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob 46 | # broadcastable to (batch_size, n_quantiles, n_target_quantiles) 47 | cum_prob = cum_prob.view(1, -1, 1) 48 | elif current_quantiles.ndim == 3: 49 | # For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob 50 | # broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles) 51 | cum_prob = cum_prob.view(1, 1, -1, 1) 52 | 53 | # QR-DQN 54 | # target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles) 55 | # current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1) 56 | # pairwise_delta: (batch_size, n_target_quantiles, n_quantiles) 57 | # TQC 58 | # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles) 59 | # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1) 60 | # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles) 61 | # Note: in both cases, the loss has the same shape as pairwise_delta 62 | pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1) 63 | abs_pairwise_delta = th.abs(pairwise_delta) 64 | huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5) 65 | loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss 66 | if sum_over_quantiles: 67 | loss = loss.sum(dim=-2).mean() 68 | else: 69 | loss = loss.mean() 70 | return loss 71 | 72 | 73 | def conjugate_gradient_solver( 74 | matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor], 75 | b, 76 | max_iter=10, 77 | residual_tol=1e-10, 78 | ) -> th.Tensor: 79 | """ 80 | Finds an approximate solution to a set of linear equations Ax = b 81 | 82 | Sources: 83 | - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py 84 | - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122 85 | 86 | Reference: 87 | - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6 88 | 89 | :param matrix_vector_dot_fn: 90 | a function that right multiplies a matrix A by a vector v 91 | :param b: 92 | the right hand term in the set of linear equations Ax = b 93 | :param max_iter: 94 | the maximum number of iterations (default is 10) 95 | :param residual_tol: 96 | residual tolerance for early stopping of the solving (default is 1e-10) 
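# Editor's note: a minimal usage sketch for ``conjugate_gradient_solver`` defined in this
# module; it is not part of the original file. The small symmetric positive-definite
# matrix A and right-hand side b below are placeholder values.
import torch as th

A = th.tensor([[4.0, 1.0], [1.0, 3.0]])
b = th.tensor([1.0, 2.0])
x_approx = conjugate_gradient_solver(lambda v: A @ v, b, max_iter=10)
# x_approx should be close to the exact solution of A x = b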
97 | :return x: 98 | the approximate solution to the system of equations defined by `matrix_vector_dot_fn` 99 | and b 100 | """ 101 | 102 | # The vector is not initialized at 0 because of the instability issues when the gradient becomes small. 103 | # A small random gaussian noise is used for the initialization. 104 | x = 1e-4 * th.randn_like(b) 105 | residual = b - matrix_vector_dot_fn(x) 106 | # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared) 107 | residual_squared_norm = th.matmul(residual, residual) 108 | 109 | if residual_squared_norm < residual_tol: 110 | # If the gradient becomes extremely small 111 | # The denominator in alpha will become zero 112 | # Leading to a division by zero 113 | return x 114 | 115 | p = residual.clone() 116 | 117 | for i in range(max_iter): 118 | # A @ p (matrix vector multiplication) 119 | A_dot_p = matrix_vector_dot_fn(p) 120 | 121 | alpha = residual_squared_norm / p.dot(A_dot_p) 122 | x += alpha * p 123 | 124 | if i == max_iter - 1: 125 | return x 126 | 127 | residual -= alpha * A_dot_p 128 | new_residual_squared_norm = th.matmul(residual, residual) 129 | 130 | if new_residual_squared_norm < residual_tol: 131 | return x 132 | 133 | beta = new_residual_squared_norm / residual_squared_norm 134 | residual_squared_norm = new_residual_squared_norm 135 | p = residual + beta * p 136 | 137 | 138 | def flat_grad( 139 | output, 140 | parameters: Sequence[nn.parameter.Parameter], 141 | create_graph: bool = False, 142 | retain_graph: bool = False, 143 | ) -> th.Tensor: 144 | """ 145 | Returns the gradients of the passed sequence of parameters into a flat gradient. 146 | Order of parameters is preserved. 147 | 148 | :param output: functional output to compute the gradient for 149 | :param parameters: sequence of ``Parameter`` 150 | :param retain_graph: – If ``False``, the graph used to compute the grad will be freed. 151 | Defaults to the value of ``create_graph``. 152 | :param create_graph: – If ``True``, graph of the derivative will be constructed, 153 | allowing to compute higher order derivative products. Default: ``False``. 
154 | :return: Tensor containing the flattened gradients 155 | """ 156 | grads = th.autograd.grad( 157 | output, 158 | parameters, 159 | create_graph=create_graph, 160 | retain_graph=retain_graph, 161 | allow_unused=True, 162 | ) 163 | return th.cat([th.ravel(grad) for grad in grads if grad is not None]) 164 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.common.vec_env.async_eval import AsyncEval 2 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/vec_env/async_eval.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import multiprocessing as mp 3 | from collections import defaultdict 4 | from typing import Callable, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch as th 8 | from stable_baselines3.common.evaluation import evaluate_policy 9 | from stable_baselines3.common.policies import BasePolicy 10 | from stable_baselines3.common.running_mean_std import RunningMeanStd 11 | from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize 12 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper 13 | 14 | 15 | def _worker( 16 | remote: mp.connection.Connection, 17 | parent_remote: mp.connection.Connection, 18 | worker_env_wrapper: CloudpickleWrapper, 19 | train_policy_wrapper: CloudpickleWrapper, 20 | n_eval_episodes: int = 1, 21 | ) -> None: 22 | """ 23 | Function that will be run in each process. 24 | It is in charge of creating environments, evaluating candidates 25 | and communicating with the main process. 26 | 27 | :param remote: Pipe to communicate with the parent process. 28 | :param parent_remote: 29 | :param worker_env_wrapper: Callable used to create the environment inside the process. 30 | :param train_policy_wrapper: Callable used to create the policy inside the process. 31 | :param n_eval_episodes: Number of evaluation episodes per candidate. 32 | """ 33 | parent_remote.close() 34 | env = worker_env_wrapper.var() 35 | train_policy = train_policy_wrapper.var 36 | vec_normalize = unwrap_vec_normalize(env) 37 | if vec_normalize is not None: 38 | obs_rms = vec_normalize.obs_rms 39 | else: 40 | obs_rms = None 41 | while True: 42 | try: 43 | cmd, data = remote.recv() 44 | if cmd == "eval": 45 | results = [] 46 | # Evaluate each candidate and save results 47 | for weights_idx, candidate_weights in data: 48 | train_policy.load_from_vector(candidate_weights.cpu()) 49 | episode_rewards, episode_lengths = evaluate_policy( 50 | train_policy, 51 | env, 52 | n_eval_episodes=n_eval_episodes, 53 | return_episode_rewards=True, 54 | warn=False, 55 | ) 56 | results.append((weights_idx, (episode_rewards, episode_lengths))) 57 | remote.send(results) 58 | elif cmd == "seed": 59 | remote.send(env.seed(data)) 60 | elif cmd == "get_obs_rms": 61 | remote.send(obs_rms) 62 | elif cmd == "sync_obs_rms": 63 | vec_normalize.obs_rms = data 64 | obs_rms = data 65 | elif cmd == "close": 66 | env.close() 67 | remote.close() 68 | break 69 | else: 70 | raise NotImplementedError(f"`{cmd}` is not implemented in the worker") 71 | except EOFError: 72 | break 73 | 74 | 75 | class AsyncEval(object): 76 | """ 77 | Helper class to do asynchronous evaluation of different policies with multiple processes. 
78 | It is useful when implementing population based methods like Evolution Strategies (ES), 79 | Cross Entropy Method (CEM) or Augmented Random Search (ARS). 80 | 81 | .. warning:: 82 | 83 | Only 'forkserver' and 'spawn' start methods are thread-safe, 84 | which is important to avoid race conditions. 85 | However, compared to 86 | 'fork' they incur a small start-up cost and have restrictions on 87 | global variables. With those methods, users must wrap the code in an 88 | ``if __name__ == "__main__":`` block. 89 | For more information, see the multiprocessing documentation. 90 | 91 | :param envs_fn: Vectorized environments to run in subprocesses (callable) 92 | :param train_policy: The policy object that will load the different candidate 93 | weights. 94 | :param start_method: method used to start the subprocesses. 95 | Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``. 96 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 97 | :param n_eval_episodes: The number of episodes to test each agent 98 | """ 99 | 100 | def __init__( 101 | self, 102 | envs_fn: List[Callable[[], VecEnv]], 103 | train_policy: BasePolicy, 104 | start_method: Optional[str] = None, 105 | n_eval_episodes: int = 1, 106 | ): 107 | self.waiting = False 108 | self.closed = False 109 | n_envs = len(envs_fn) 110 | 111 | if start_method is None: 112 | # Fork is not a thread safe method (see issue #217) 113 | # but is more user friendly (does not require to wrap the code in 114 | # a `if __name__ == "__main__":`) 115 | forkserver_available = "forkserver" in mp.get_all_start_methods() 116 | start_method = "forkserver" if forkserver_available else "spawn" 117 | ctx = mp.get_context(start_method) 118 | 119 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) 120 | self.processes = [] 121 | for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn): 122 | args = ( 123 | work_remote, 124 | remote, 125 | CloudpickleWrapper(worker_env), 126 | CloudpickleWrapper(train_policy), 127 | n_eval_episodes, 128 | ) 129 | # daemon=True: if the main process crashes, we should not cause things to hang 130 | process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error 131 | process.start() 132 | self.processes.append(process) 133 | work_remote.close() 134 | 135 | def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None: 136 | """ 137 | Send jobs to the workers to evaluate new candidates. 138 | 139 | :param candidate_weights: The weights to be evaluated. 140 | :pop_size: The number of candidate (size of the population) 141 | """ 142 | jobs_per_worker = defaultdict(list) 143 | for weights_idx in range(pop_size): 144 | jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx])) 145 | 146 | for remote_idx, remote in enumerate(self.remotes): 147 | remote.send(("eval", jobs_per_worker[remote_idx])) 148 | self.waiting = True 149 | 150 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 151 | """ 152 | Seed the environments. 153 | 154 | :param seed: The seed for the pseudo-random generators. 
155 | :return: 156 | """ 157 | for idx, remote in enumerate(self.remotes): 158 | remote.send(("seed", seed + idx)) 159 | return [remote.recv() for remote in self.remotes] 160 | 161 | def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]: 162 | """ 163 | Retreive episode rewards and lengths from each worker 164 | for all candidates (there might be multiple candidates per worker) 165 | 166 | :return: A list of tuples containing each candidate index and its 167 | result (episodic reward and episode length) 168 | """ 169 | results = [remote.recv() for remote in self.remotes] 170 | flat_results = [result for worker_results in results for result in worker_results] 171 | self.waiting = False 172 | return flat_results 173 | 174 | def get_obs_rms(self) -> List[RunningMeanStd]: 175 | """ 176 | Retrieve the observation filters (observation running mean std) 177 | of each process, they will be combined in the main process. 178 | Synchronisation is done afterward using ``sync_obs_rms()``. 179 | :return: A list of ``RunningMeanStd`` objects (one per process) 180 | """ 181 | for remote in self.remotes: 182 | remote.send(("get_obs_rms", None)) 183 | return [remote.recv() for remote in self.remotes] 184 | 185 | def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None: 186 | """ 187 | Synchronise (and update) the observation filters 188 | (observation running mean std) 189 | :param obs_rms: The updated ``RunningMeanStd`` to be used 190 | by workers for normalizing observations. 191 | """ 192 | for remote in self.remotes: 193 | remote.send(("sync_obs_rms", obs_rms)) 194 | 195 | def close(self) -> None: 196 | """ 197 | Close the processes. 198 | """ 199 | if self.closed: 200 | return 201 | if self.waiting: 202 | for remote in self.remotes: 203 | remote.recv() 204 | for remote in self.remotes: 205 | remote.send(("close", None)) 206 | for process in self.processes: 207 | process.join() 208 | self.closed = True 209 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.common.wrappers.action_masker import ActionMasker 2 | from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/wrappers/action_masker.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class ActionMasker(gym.Wrapper): 8 | """ 9 | Env wrapper providing the method required to support masking. 10 | 11 | Exposes a method called action_masks(), which returns masks for the wrapped env. 12 | This wrapper is not needed if the env exposes the expected method itself. 13 | 14 | :param env: the Gym environment to wrap 15 | :param action_mask_fn: A function that takes a Gym environment and returns an action mask, 16 | or the name of such a method provided by the environment. 
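# Editor's note: an illustrative sketch, not part of the original file, showing how
# ``ActionMasker`` (defined in this file) is typically combined with ``MaskablePPO``.
# The mask function ``mask_fn`` and the CartPole environment are placeholders; a real
# mask function would return False for actions that are currently invalid.
import gym
import numpy as np
from sb3_contrib.ppo_mask import MaskablePPO

def mask_fn(env: gym.Env) -> np.ndarray:
    # one boolean per discrete action, True meaning the action is allowed
    return np.ones(env.action_space.n, dtype=bool)

env = ActionMasker(gym.make("CartPole-v1"), mask_fn)
model = MaskablePPO("MlpPolicy", env, verbose=0)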
17 | """ 18 | 19 | def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]): 20 | super().__init__(env) 21 | 22 | if isinstance(action_mask_fn, str): 23 | found_method = getattr(self.env, action_mask_fn) 24 | if not callable(found_method): 25 | raise ValueError(f"Environment attribute {action_mask_fn} is not a method") 26 | 27 | self._action_mask_fn = found_method 28 | else: 29 | self._action_mask_fn = action_mask_fn 30 | 31 | def action_masks(self) -> np.ndarray: 32 | return self._action_mask_fn(self.env) 33 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/common/wrappers/time_feature.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import gym 4 | import numpy as np 5 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn 6 | 7 | 8 | class TimeFeatureWrapper(gym.Wrapper): 9 | """ 10 | Add remaining, normalized time to observation space for fixed length episodes. 11 | See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13. 12 | 13 | .. note:: 14 | 15 | Only ``gym.spaces.Box`` and ``gym.spaces.Dict`` (``gym.GoalEnv``) 1D observation spaces 16 | are supported for now. 17 | 18 | :param env: Gym env to wrap. 19 | :param max_steps: Max number of steps of an episode 20 | if it is not wrapped in a ``TimeLimit`` object. 21 | :param test_mode: In test mode, the time feature is constant, 22 | equal to zero. This allow to check that the agent did not overfit this feature, 23 | learning a deterministic pre-defined sequence of actions. 24 | """ 25 | 26 | def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False): 27 | assert isinstance( 28 | env.observation_space, (gym.spaces.Box, gym.spaces.Dict) 29 | ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces." 30 | 31 | # Add a time feature to the observation 32 | if isinstance(env.observation_space, gym.spaces.Dict): 33 | assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space" 34 | obs_space = env.observation_space.spaces["observation"] 35 | assert isinstance( 36 | obs_space, gym.spaces.Box 37 | ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space." 
38 | obs_space = env.observation_space.spaces["observation"] 39 | else: 40 | obs_space = env.observation_space 41 | 42 | assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported" 43 | 44 | low, high = obs_space.low, obs_space.high 45 | low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) 46 | self.dtype = obs_space.dtype 47 | 48 | if isinstance(env.observation_space, gym.spaces.Dict): 49 | env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype) 50 | else: 51 | env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype) 52 | 53 | super(TimeFeatureWrapper, self).__init__(env) 54 | 55 | # Try to infer the max number of steps per episode 56 | try: 57 | self._max_steps = env.spec.max_episode_steps 58 | except AttributeError: 59 | self._max_steps = None 60 | 61 | # Fallback to provided value 62 | if self._max_steps is None: 63 | self._max_steps = max_steps 64 | 65 | self._current_step = 0 66 | self._test_mode = test_mode 67 | 68 | def reset(self) -> GymObs: 69 | self._current_step = 0 70 | return self._get_obs(self.env.reset()) 71 | 72 | def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: 73 | self._current_step += 1 74 | obs, reward, done, info = self.env.step(action) 75 | return self._get_obs(obs), reward, done, info 76 | 77 | def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: 78 | """ 79 | Concatenate the time feature to the current observation. 80 | 81 | :param obs: 82 | :return: 83 | """ 84 | # Remaining time is more general 85 | time_feature = 1 - (self._current_step / self._max_steps) 86 | if self._test_mode: 87 | time_feature = 1.0 88 | time_feature = np.array(time_feature, dtype=self.dtype) 89 | 90 | if isinstance(obs, dict): 91 | obs["observation"] = np.append(obs["observation"], time_feature) 92 | return obs 93 | return np.append(obs, time_feature) 94 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/local_modifications.txt: -------------------------------------------------------------------------------- 1 | modified: sb3_contrib/qrdqn/policies.py 2 | sb3_contrib/qrdqn/qrdqn.py 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/ppo_mask/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/ppo_mask/policies.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.policies import register_policy 2 | 3 | from sb3_contrib.common.maskable.policies import ( 4 | MaskableActorCriticCnnPolicy, 5 | MaskableActorCriticPolicy, 6 | MaskableMultiInputActorCriticPolicy, 7 | ) 8 | 9 | MlpPolicy = MaskableActorCriticPolicy 10 | CnnPolicy = MaskableActorCriticCnnPolicy 11 | MultiInputPolicy = MaskableMultiInputActorCriticPolicy 12 | 13 | register_policy("MlpPolicy", MaskableActorCriticPolicy) 14 | register_policy("CnnPolicy", MaskableActorCriticCnnPolicy) 15 | register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy) 16 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/py.typed: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/sb3_contrib/py.typed -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/qrdqn/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from sb3_contrib.qrdqn.qrdqn import QRDQN 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/tqc/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from sb3_contrib.tqc.tqc import TQC 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/trpo/__init__.py: -------------------------------------------------------------------------------- 1 | from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from sb3_contrib.trpo.trpo import TRPO 3 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/trpo/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for TRPO 3 | from stable_baselines3.common.policies import ( 4 | ActorCriticCnnPolicy, 5 | ActorCriticPolicy, 6 | MultiInputActorCriticPolicy, 7 | register_policy, 8 | ) 9 | 10 | MlpPolicy = ActorCriticPolicy 11 | CnnPolicy = ActorCriticCnnPolicy 12 | MultiInputPolicy = MultiInputActorCriticPolicy 13 | 14 | register_policy("MlpPolicy", ActorCriticPolicy) 15 | register_policy("CnnPolicy", ActorCriticCnnPolicy) 16 | register_policy("MultiInputPolicy", MultiInputPolicy) 17 | -------------------------------------------------------------------------------- /thirdparty/sb3_contrib/version.txt: -------------------------------------------------------------------------------- 1 | 1.5.0 2 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from stable_baselines3.a2c import A2C 4 | from stable_baselines3.common.utils import get_system_info 5 | from stable_baselines3.ddpg import DDPG 6 | from stable_baselines3.dqn import DQN 7 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer 8 | from stable_baselines3.ppo import PPO 9 | from stable_baselines3.sac import SAC 10 | from stable_baselines3.td3 import TD3 11 | 12 | # Read version from file 13 | version_file = os.path.join(os.path.dirname(__file__), "version.txt") 14 | with open(version_file, "r") as file_handler: 15 | __version__ = file_handler.read().strip() 16 | 17 | 18 | def HER(*args, **kwargs): 19 | raise ImportError( 20 | "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n " 21 | "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/" 22 | ) 23 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.a2c.a2c import A2C 2 | 
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/a2c/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for A2C 3 | from stable_baselines3.common.policies import ( 4 | ActorCriticCnnPolicy, 5 | ActorCriticPolicy, 6 | MultiInputActorCriticPolicy, 7 | register_policy, 8 | ) 9 | 10 | MlpPolicy = ActorCriticPolicy 11 | CnnPolicy = ActorCriticCnnPolicy 12 | MultiInputPolicy = MultiInputActorCriticPolicy 13 | 14 | register_policy("MlpPolicy", ActorCriticPolicy) 15 | register_policy("CnnPolicy", ActorCriticCnnPolicy) 16 | register_policy("MultiInputPolicy", MultiInputPolicy) 17 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/stable_baselines3/common/__init__.py -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from gym import spaces 4 | 5 | try: 6 | import cv2 # pytype:disable=import-error 7 | 8 | cv2.ocl.setUseOpenCL(False) 9 | except ImportError: 10 | cv2 = None 11 | 12 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn 13 | 14 | 15 | class NoopResetEnv(gym.Wrapper): 16 | """ 17 | Sample initial states by taking random number of no-ops on reset. 18 | No-op is assumed to be action 0. 19 | 20 | :param env: the environment to wrap 21 | :param noop_max: the maximum value of no-ops to run 22 | """ 23 | 24 | def __init__(self, env: gym.Env, noop_max: int = 30): 25 | gym.Wrapper.__init__(self, env) 26 | self.noop_max = noop_max 27 | self.override_num_noops = None 28 | self.noop_action = 0 29 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" 30 | 31 | def reset(self, **kwargs) -> np.ndarray: 32 | self.env.reset(**kwargs) 33 | if self.override_num_noops is not None: 34 | noops = self.override_num_noops 35 | else: 36 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) 37 | assert noops > 0 38 | obs = np.zeros(0) 39 | for _ in range(noops): 40 | obs, _, done, _ = self.env.step(self.noop_action) 41 | if done: 42 | obs = self.env.reset(**kwargs) 43 | return obs 44 | 45 | 46 | class FireResetEnv(gym.Wrapper): 47 | """ 48 | Take action on reset for environments that are fixed until firing. 49 | 50 | :param env: the environment to wrap 51 | """ 52 | 53 | def __init__(self, env: gym.Env): 54 | gym.Wrapper.__init__(self, env) 55 | assert env.unwrapped.get_action_meanings()[1] == "FIRE" 56 | assert len(env.unwrapped.get_action_meanings()) >= 3 57 | 58 | def reset(self, **kwargs) -> np.ndarray: 59 | self.env.reset(**kwargs) 60 | obs, _, done, _ = self.env.step(1) 61 | if done: 62 | self.env.reset(**kwargs) 63 | obs, _, done, _ = self.env.step(2) 64 | if done: 65 | self.env.reset(**kwargs) 66 | return obs 67 | 68 | 69 | class EpisodicLifeEnv(gym.Wrapper): 70 | """ 71 | Make end-of-life == end-of-episode, but only reset on true game over. 72 | Done by DeepMind for the DQN and co. since it helps value estimation. 
73 | 74 | :param env: the environment to wrap 75 | """ 76 | 77 | def __init__(self, env: gym.Env): 78 | gym.Wrapper.__init__(self, env) 79 | self.lives = 0 80 | self.was_real_done = True 81 | 82 | def step(self, action: int) -> GymStepReturn: 83 | obs, reward, done, info = self.env.step(action) 84 | self.was_real_done = done 85 | # check current lives, make loss of life terminal, 86 | # then update lives to handle bonus lives 87 | lives = self.env.unwrapped.ale.lives() 88 | if 0 < lives < self.lives: 89 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames 90 | # so its important to keep lives > 0, so that we only reset once 91 | # the environment advertises done. 92 | done = True 93 | self.lives = lives 94 | return obs, reward, done, info 95 | 96 | def reset(self, **kwargs) -> np.ndarray: 97 | """ 98 | Calls the Gym environment reset, only when lives are exhausted. 99 | This way all states are still reachable even though lives are episodic, 100 | and the learner need not know about any of this behind-the-scenes. 101 | 102 | :param kwargs: Extra keywords passed to env.reset() call 103 | :return: the first observation of the environment 104 | """ 105 | if self.was_real_done: 106 | obs = self.env.reset(**kwargs) 107 | else: 108 | # no-op step to advance from terminal/lost life state 109 | obs, _, _, _ = self.env.step(0) 110 | self.lives = self.env.unwrapped.ale.lives() 111 | return obs 112 | 113 | 114 | class MaxAndSkipEnv(gym.Wrapper): 115 | """ 116 | Return only every ``skip``-th frame (frameskipping) 117 | 118 | :param env: the environment 119 | :param skip: number of ``skip``-th frame 120 | """ 121 | 122 | def __init__(self, env: gym.Env, skip: int = 4): 123 | gym.Wrapper.__init__(self, env) 124 | # most recent raw observations (for max pooling across time steps) 125 | self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype) 126 | self._skip = skip 127 | 128 | def step(self, action: int) -> GymStepReturn: 129 | """ 130 | Step the environment with the given action 131 | Repeat action, sum reward, and max over last observations. 132 | 133 | :param action: the action 134 | :return: observation, reward, done, information 135 | """ 136 | total_reward = 0.0 137 | done = None 138 | for i in range(self._skip): 139 | obs, reward, done, info = self.env.step(action) 140 | if i == self._skip - 2: 141 | self._obs_buffer[0] = obs 142 | if i == self._skip - 1: 143 | self._obs_buffer[1] = obs 144 | total_reward += reward 145 | if done: 146 | break 147 | # Note that the observation on the done=True frame 148 | # doesn't matter 149 | max_frame = self._obs_buffer.max(axis=0) 150 | 151 | return max_frame, total_reward, done, info 152 | 153 | def reset(self, **kwargs) -> GymObs: 154 | return self.env.reset(**kwargs) 155 | 156 | 157 | class ClipRewardEnv(gym.RewardWrapper): 158 | """ 159 | Clips the reward to {+1, 0, -1} by its sign. 160 | 161 | :param env: the environment 162 | """ 163 | 164 | def __init__(self, env: gym.Env): 165 | gym.RewardWrapper.__init__(self, env) 166 | 167 | def reward(self, reward: float) -> float: 168 | """ 169 | Bin reward to {+1, 0, -1} by its sign. 170 | 171 | :param reward: 172 | :return: 173 | """ 174 | return np.sign(reward) 175 | 176 | 177 | class WarpFrame(gym.ObservationWrapper): 178 | """ 179 | Convert to grayscale and warp frames to 84x84 (default) 180 | as done in the Nature paper and later work. 
181 | 182 | :param env: the environment 183 | :param width: 184 | :param height: 185 | """ 186 | 187 | def __init__(self, env: gym.Env, width: int = 84, height: int = 84): 188 | gym.ObservationWrapper.__init__(self, env) 189 | self.width = width 190 | self.height = height 191 | self.observation_space = spaces.Box( 192 | low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype 193 | ) 194 | 195 | def observation(self, frame: np.ndarray) -> np.ndarray: 196 | """ 197 | returns the current observation from a frame 198 | 199 | :param frame: environment frame 200 | :return: the observation 201 | """ 202 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 203 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 204 | return frame[:, :, None] 205 | 206 | 207 | class AtariWrapper(gym.Wrapper): 208 | """ 209 | Atari 2600 preprocessings 210 | 211 | Specifically: 212 | 213 | * NoopReset: obtain initial state by taking random number of no-ops on reset. 214 | * Frame skipping: 4 by default 215 | * Max-pooling: most recent two observations 216 | * Termination signal when a life is lost. 217 | * Resize to a square image: 84x84 by default 218 | * Grayscale observation 219 | * Clip reward to {-1, 0, 1} 220 | 221 | :param env: gym environment 222 | :param noop_max: max number of no-ops 223 | :param frame_skip: the frequency at which the agent experiences the game. 224 | :param screen_size: resize Atari frame 225 | :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost. 226 | :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign. 227 | """ 228 | 229 | def __init__( 230 | self, 231 | env: gym.Env, 232 | noop_max: int = 30, 233 | frame_skip: int = 4, 234 | screen_size: int = 84, 235 | terminal_on_life_loss: bool = True, 236 | clip_reward: bool = True, 237 | ): 238 | env = NoopResetEnv(env, noop_max=noop_max) 239 | env = MaxAndSkipEnv(env, skip=frame_skip) 240 | if terminal_on_life_loss: 241 | env = EpisodicLifeEnv(env) 242 | if "FIRE" in env.unwrapped.get_action_meanings(): 243 | env = FireResetEnv(env) 244 | env = WarpFrame(env, width=screen_size, height=screen_size) 245 | if clip_reward: 246 | env = ClipRewardEnv(env) 247 | 248 | super(AtariWrapper, self).__init__(env) 249 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/env_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Dict, Optional, Type, Union 3 | 4 | import gym 5 | 6 | from stable_baselines3.common.atari_wrappers import AtariWrapper 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv 9 | 10 | 11 | def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]: 12 | """ 13 | Retrieve a ``VecEnvWrapper`` object by recursively searching. 
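# Editor's note: a small illustrative sketch, not part of the original file; CartPole and
# the ``Monitor`` wrapper are placeholder choices. ``unwrap_wrapper`` and ``is_wrapped``
# are the helpers defined in this module.
import gym
from stable_baselines3.common.monitor import Monitor

monitored_env = Monitor(gym.make("CartPole-v1"))
monitor = unwrap_wrapper(monitored_env, Monitor)  # returns the Monitor instance, or None
assert is_wrapped(monitored_env, Monitor)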
14 | 15 | :param env: Environment to unwrap 16 | :param wrapper_class: Wrapper to look for 17 | :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it 18 | """ 19 | env_tmp = env 20 | while isinstance(env_tmp, gym.Wrapper): 21 | if isinstance(env_tmp, wrapper_class): 22 | return env_tmp 23 | env_tmp = env_tmp.env 24 | return None 25 | 26 | 27 | def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: 28 | """ 29 | Check if a given environment has been wrapped with a given wrapper. 30 | 31 | :param env: Environment to check 32 | :param wrapper_class: Wrapper class to look for 33 | :return: True if environment has been wrapped with ``wrapper_class``. 34 | """ 35 | return unwrap_wrapper(env, wrapper_class) is not None 36 | 37 | 38 | def make_vec_env( 39 | env_id: Union[str, Type[gym.Env]], 40 | n_envs: int = 1, 41 | seed: Optional[int] = None, 42 | start_index: int = 0, 43 | monitor_dir: Optional[str] = None, 44 | wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None, 45 | env_kwargs: Optional[Dict[str, Any]] = None, 46 | vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, 47 | vec_env_kwargs: Optional[Dict[str, Any]] = None, 48 | monitor_kwargs: Optional[Dict[str, Any]] = None, 49 | wrapper_kwargs: Optional[Dict[str, Any]] = None, 50 | ) -> VecEnv: 51 | """ 52 | Create a wrapped, monitored ``VecEnv``. 53 | By default it uses a ``DummyVecEnv`` which is usually faster 54 | than a ``SubprocVecEnv``. 55 | 56 | :param env_id: the environment ID or the environment class 57 | :param n_envs: the number of environments you wish to have in parallel 58 | :param seed: the initial seed for the random number generator 59 | :param start_index: start rank index 60 | :param monitor_dir: Path to a folder where the monitor files will be saved. 61 | If None, no file will be written, however, the env will still be wrapped 62 | in a Monitor wrapper to provide additional information about training. 63 | :param wrapper_class: Additional wrapper to use on the environment. 64 | This can also be a function with single argument that wraps the environment in many things. 65 | :param env_kwargs: Optional keyword argument to pass to the env constructor 66 | :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. 67 | :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. 68 | :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. 69 | :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. 
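# Editor's note: a minimal usage sketch for ``make_vec_env``, not part of the original file;
# the environment id, number of workers and seed are placeholder values.
vec_env = make_vec_env("CartPole-v1", n_envs=4, seed=0)
obs = vec_env.reset()  # batched observations, one row per parallel environment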
70 | :return: The wrapped environment 71 | """ 72 | env_kwargs = {} if env_kwargs is None else env_kwargs 73 | vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs 74 | monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs 75 | wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs 76 | 77 | def make_env(rank): 78 | def _init(): 79 | if isinstance(env_id, str): 80 | env = gym.make(env_id, **env_kwargs) 81 | else: 82 | env = env_id(**env_kwargs) 83 | if seed is not None: 84 | env.seed(seed + rank) 85 | env.action_space.seed(seed + rank) 86 | # Wrap the env in a Monitor wrapper 87 | # to have additional training information 88 | monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None 89 | # Create the monitor folder if needed 90 | if monitor_path is not None: 91 | os.makedirs(monitor_dir, exist_ok=True) 92 | env = Monitor(env, filename=monitor_path, **monitor_kwargs) 93 | # Optionally, wrap the environment with the provided wrapper 94 | if wrapper_class is not None: 95 | env = wrapper_class(env, **wrapper_kwargs) 96 | return env 97 | 98 | return _init 99 | 100 | # No custom VecEnv is passed 101 | if vec_env_cls is None: 102 | # Default: use a DummyVecEnv 103 | vec_env_cls = DummyVecEnv 104 | 105 | return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) 106 | 107 | 108 | def make_atari_env( 109 | env_id: Union[str, Type[gym.Env]], 110 | n_envs: int = 1, 111 | seed: Optional[int] = None, 112 | start_index: int = 0, 113 | monitor_dir: Optional[str] = None, 114 | wrapper_kwargs: Optional[Dict[str, Any]] = None, 115 | env_kwargs: Optional[Dict[str, Any]] = None, 116 | vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, 117 | vec_env_kwargs: Optional[Dict[str, Any]] = None, 118 | monitor_kwargs: Optional[Dict[str, Any]] = None, 119 | ) -> VecEnv: 120 | """ 121 | Create a wrapped, monitored VecEnv for Atari. 122 | It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. 123 | 124 | :param env_id: the environment ID or the environment class 125 | :param n_envs: the number of environments you wish to have in parallel 126 | :param seed: the initial seed for the random number generator 127 | :param start_index: start rank index 128 | :param monitor_dir: Path to a folder where the monitor files will be saved. 129 | If None, no file will be written, however, the env will still be wrapped 130 | in a Monitor wrapper to provide additional information about training. 131 | :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper`` 132 | :param env_kwargs: Optional keyword argument to pass to the env constructor 133 | :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. 134 | :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. 135 | :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. 
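# Editor's note: an illustrative sketch, not part of the original file; "PongNoFrameskip-v4"
# is a placeholder id that requires the Atari ROMs to be installed. Stacking frames with
# ``VecFrameStack`` is the usual follow-up to the ``AtariWrapper`` preprocessing applied here.
from stable_baselines3.common.vec_env import VecFrameStack

env = make_atari_env("PongNoFrameskip-v4", n_envs=2, seed=0)
env = VecFrameStack(env, n_stack=4)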
136 | :return: The wrapped environment 137 | """ 138 | if wrapper_kwargs is None: 139 | wrapper_kwargs = {} 140 | 141 | def atari_wrapper(env: gym.Env) -> gym.Env: 142 | env = AtariWrapper(env, **wrapper_kwargs) 143 | return env 144 | 145 | return make_vec_env( 146 | env_id, 147 | n_envs=n_envs, 148 | seed=seed, 149 | start_index=start_index, 150 | monitor_dir=monitor_dir, 151 | wrapper_class=atari_wrapper, 152 | env_kwargs=env_kwargs, 153 | vec_env_cls=vec_env_cls, 154 | vec_env_kwargs=vec_env_kwargs, 155 | monitor_kwargs=monitor_kwargs, 156 | ) 157 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv 2 | from stable_baselines3.common.envs.identity_env import ( 3 | FakeImageEnv, 4 | IdentityEnv, 5 | IdentityEnvBox, 6 | IdentityEnvMultiBinary, 7 | IdentityEnvMultiDiscrete, 8 | ) 9 | from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv 10 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/envs/bit_flipping_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Any, Dict, Optional, Union 3 | 4 | import numpy as np 5 | from gym import GoalEnv, spaces 6 | from gym.envs.registration import EnvSpec 7 | 8 | from stable_baselines3.common.type_aliases import GymStepReturn 9 | 10 | 11 | class BitFlippingEnv(GoalEnv): 12 | """ 13 | Simple bit flipping env, useful to test HER. 14 | The goal is to flip all the bits to get a vector of ones. 15 | In the continuous variant, if the ith action component has a value > 0, 16 | then the ith bit will be flipped. 17 | 18 | :param n_bits: Number of bits to flip 19 | :param continuous: Whether to use the continuous actions version or not, 20 | by default, it uses the discrete one 21 | :param max_steps: Max number of steps, by default, equal to n_bits 22 | :param discrete_obs_space: Whether to use the discrete observation 23 | version or not, by default, it uses the ``MultiBinary`` one 24 | :param image_obs_space: Use image as input instead of the ``MultiBinary`` one. 25 | :param channel_first: Whether to use channel-first or last image. 
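# Editor's note: a short interaction sketch, not part of the original file. It shows the
# dict observation layout ("observation", "achieved_goal", "desired_goal") that makes this
# environment convenient for testing HER-style replay buffers.
env = BitFlippingEnv(n_bits=4)
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs["achieved_goal"], obs["desired_goal"], reward, done)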
26 | """ 27 | 28 | spec = EnvSpec("BitFlippingEnv-v0") 29 | 30 | def __init__( 31 | self, 32 | n_bits: int = 10, 33 | continuous: bool = False, 34 | max_steps: Optional[int] = None, 35 | discrete_obs_space: bool = False, 36 | image_obs_space: bool = False, 37 | channel_first: bool = True, 38 | ): 39 | super(BitFlippingEnv, self).__init__() 40 | # Shape of the observation when using image space 41 | self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) 42 | # The achieved goal is determined by the current state 43 | # here, it is a special where they are equal 44 | if discrete_obs_space: 45 | # In the discrete case, the agent act on the binary 46 | # representation of the observation 47 | self.observation_space = spaces.Dict( 48 | { 49 | "observation": spaces.Discrete(2 ** n_bits), 50 | "achieved_goal": spaces.Discrete(2 ** n_bits), 51 | "desired_goal": spaces.Discrete(2 ** n_bits), 52 | } 53 | ) 54 | elif image_obs_space: 55 | # When using image as input, 56 | # one image contains the bits 0 -> 0, 1 -> 255 57 | # and the rest is filled with zeros 58 | self.observation_space = spaces.Dict( 59 | { 60 | "observation": spaces.Box( 61 | low=0, 62 | high=255, 63 | shape=self.image_shape, 64 | dtype=np.uint8, 65 | ), 66 | "achieved_goal": spaces.Box( 67 | low=0, 68 | high=255, 69 | shape=self.image_shape, 70 | dtype=np.uint8, 71 | ), 72 | "desired_goal": spaces.Box( 73 | low=0, 74 | high=255, 75 | shape=self.image_shape, 76 | dtype=np.uint8, 77 | ), 78 | } 79 | ) 80 | else: 81 | self.observation_space = spaces.Dict( 82 | { 83 | "observation": spaces.MultiBinary(n_bits), 84 | "achieved_goal": spaces.MultiBinary(n_bits), 85 | "desired_goal": spaces.MultiBinary(n_bits), 86 | } 87 | ) 88 | 89 | self.obs_space = spaces.MultiBinary(n_bits) 90 | 91 | if continuous: 92 | self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) 93 | else: 94 | self.action_space = spaces.Discrete(n_bits) 95 | self.continuous = continuous 96 | self.discrete_obs_space = discrete_obs_space 97 | self.image_obs_space = image_obs_space 98 | self.state = None 99 | self.desired_goal = np.ones((n_bits,)) 100 | if max_steps is None: 101 | max_steps = n_bits 102 | self.max_steps = max_steps 103 | self.current_step = 0 104 | 105 | def seed(self, seed: int) -> None: 106 | self.obs_space.seed(seed) 107 | 108 | def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: 109 | """ 110 | Convert to discrete space if needed. 111 | 112 | :param state: 113 | :return: 114 | """ 115 | if self.discrete_obs_space: 116 | # The internal state is the binary representation of the 117 | # observed one 118 | return int(sum([state[i] * 2 ** i for i in range(len(state))])) 119 | 120 | if self.image_obs_space: 121 | size = np.prod(self.image_shape) 122 | image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8))) 123 | return image.reshape(self.image_shape).astype(np.uint8) 124 | return state 125 | 126 | def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int) -> np.ndarray: 127 | """ 128 | Convert to bit vector if needed. 
129 | 130 | :param state: 131 | :param batch_size: 132 | :return: 133 | """ 134 | # Convert back to bit vector 135 | if isinstance(state, int): 136 | state = np.array(state).reshape(batch_size, -1) 137 | # Convert to binary representation 138 | state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int) 139 | elif self.image_obs_space: 140 | state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 141 | else: 142 | state = np.array(state).reshape(batch_size, -1) 143 | 144 | return state 145 | 146 | def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: 147 | """ 148 | Helper to create the observation. 149 | 150 | :return: The current observation. 151 | """ 152 | return OrderedDict( 153 | [ 154 | ("observation", self.convert_if_needed(self.state.copy())), 155 | ("achieved_goal", self.convert_if_needed(self.state.copy())), 156 | ("desired_goal", self.convert_if_needed(self.desired_goal.copy())), 157 | ] 158 | ) 159 | 160 | def reset(self) -> Dict[str, Union[int, np.ndarray]]: 161 | self.current_step = 0 162 | self.state = self.obs_space.sample() 163 | return self._get_obs() 164 | 165 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: 166 | if self.continuous: 167 | self.state[action > 0] = 1 - self.state[action > 0] 168 | else: 169 | self.state[action] = 1 - self.state[action] 170 | obs = self._get_obs() 171 | reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)) 172 | done = reward == 0 173 | self.current_step += 1 174 | # Episode terminate when we reached the goal or the max number of steps 175 | info = {"is_success": done} 176 | done = done or self.current_step >= self.max_steps 177 | return obs, reward, done, info 178 | 179 | def compute_reward( 180 | self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] 181 | ) -> np.float32: 182 | # As we are using a vectorized version, we need to keep track of the `batch_size` 183 | if isinstance(achieved_goal, int): 184 | batch_size = 1 185 | elif self.image_obs_space: 186 | batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 3 else 1 187 | else: 188 | batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 1 else 1 189 | 190 | desired_goal = self.convert_to_bit_vector(desired_goal, batch_size) 191 | achieved_goal = self.convert_to_bit_vector(achieved_goal, batch_size) 192 | 193 | # Deceptive reward: it is positive only when the goal is achieved 194 | # Here we are using a vectorized version 195 | distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) 196 | return -(distance > 0).astype(np.float32) 197 | 198 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 199 | if mode == "rgb_array": 200 | return self.state.copy() 201 | print(self.state) 202 | 203 | def close(self) -> None: 204 | pass 205 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/envs/identity_env.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | from gym import Env, Space 5 | from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete 6 | 7 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn 8 | 9 | 10 | class IdentityEnv(Env): 11 | def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100): 12 | """ 13 | Identity environment for testing purposes 14 | 15 | :param dim: 
the size of the action and observation dimension you want 16 | to learn. Provide at most one of ``dim`` and ``space``. If both are 17 | None, then initialization proceeds with ``dim=1`` and ``space=None``. 18 | :param space: the action and observation space. Provide at most one of 19 | ``dim`` and ``space``. 20 | :param ep_length: the length of each episode in timesteps 21 | """ 22 | if space is None: 23 | if dim is None: 24 | dim = 1 25 | space = Discrete(dim) 26 | else: 27 | assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed" 28 | 29 | self.action_space = self.observation_space = space 30 | self.ep_length = ep_length 31 | self.current_step = 0 32 | self.num_resets = -1 # Becomes 0 after __init__ exits. 33 | self.reset() 34 | 35 | def reset(self) -> GymObs: 36 | self.current_step = 0 37 | self.num_resets += 1 38 | self._choose_next_state() 39 | return self.state 40 | 41 | def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: 42 | reward = self._get_reward(action) 43 | self._choose_next_state() 44 | self.current_step += 1 45 | done = self.current_step >= self.ep_length 46 | return self.state, reward, done, {} 47 | 48 | def _choose_next_state(self) -> None: 49 | self.state = self.action_space.sample() 50 | 51 | def _get_reward(self, action: Union[int, np.ndarray]) -> float: 52 | return 1.0 if np.all(self.state == action) else 0.0 53 | 54 | def render(self, mode: str = "human") -> None: 55 | pass 56 | 57 | 58 | class IdentityEnvBox(IdentityEnv): 59 | def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100): 60 | """ 61 | Identity environment for testing purposes 62 | 63 | :param low: the lower bound of the box dim 64 | :param high: the upper bound of the box dim 65 | :param eps: the epsilon bound for correct value 66 | :param ep_length: the length of each episode in timesteps 67 | """ 68 | space = Box(low=low, high=high, shape=(1,), dtype=np.float32) 69 | super().__init__(ep_length=ep_length, space=space) 70 | self.eps = eps 71 | 72 | def step(self, action: np.ndarray) -> GymStepReturn: 73 | reward = self._get_reward(action) 74 | self._choose_next_state() 75 | self.current_step += 1 76 | done = self.current_step >= self.ep_length 77 | return self.state, reward, done, {} 78 | 79 | def _get_reward(self, action: np.ndarray) -> float: 80 | return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 81 | 82 | 83 | class IdentityEnvMultiDiscrete(IdentityEnv): 84 | def __init__(self, dim: int = 1, ep_length: int = 100): 85 | """ 86 | Identity environment for testing purposes 87 | 88 | :param dim: the size of the dimensions you want to learn 89 | :param ep_length: the length of each episode in timesteps 90 | """ 91 | space = MultiDiscrete([dim, dim]) 92 | super().__init__(ep_length=ep_length, space=space) 93 | 94 | 95 | class IdentityEnvMultiBinary(IdentityEnv): 96 | def __init__(self, dim: int = 1, ep_length: int = 100): 97 | """ 98 | Identity environment for testing purposes 99 | 100 | :param dim: the size of the dimensions you want to learn 101 | :param ep_length: the length of each episode in timesteps 102 | """ 103 | space = MultiBinary(dim) 104 | super().__init__(ep_length=ep_length, space=space) 105 | 106 | 107 | class FakeImageEnv(Env): 108 | """ 109 | Fake image environment for testing purposes, it mimics Atari games. 
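`IdentityEnv` and its variants are sanity-check environments: the observation is the sampled state and the reward is 1.0 only when the action reproduces it. A minimal usage sketch under the old Gym step API used by this vendored copy:

from stable_baselines3.common.envs.identity_env import IdentityEnv

env = IdentityEnv(dim=3, ep_length=5)
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)   # reward is 1.0 iff action equals the previous obs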
110 | 111 | :param action_dim: Number of discrete actions 112 | :param screen_height: Height of the image 113 | :param screen_width: Width of the image 114 | :param n_channels: Number of color channels 115 | :param discrete: Create discrete action space instead of continuous 116 | :param channel_first: Put channels on first axis instead of last 117 | """ 118 | 119 | def __init__( 120 | self, 121 | action_dim: int = 6, 122 | screen_height: int = 84, 123 | screen_width: int = 84, 124 | n_channels: int = 1, 125 | discrete: bool = True, 126 | channel_first: bool = False, 127 | ): 128 | self.observation_shape = (screen_height, screen_width, n_channels) 129 | if channel_first: 130 | self.observation_shape = (n_channels, screen_height, screen_width) 131 | self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8) 132 | if discrete: 133 | self.action_space = Discrete(action_dim) 134 | else: 135 | self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32) 136 | self.ep_length = 10 137 | self.current_step = 0 138 | 139 | def reset(self) -> np.ndarray: 140 | self.current_step = 0 141 | return self.observation_space.sample() 142 | 143 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: 144 | reward = 0.0 145 | self.current_step += 1 146 | done = self.current_step >= self.ep_length 147 | return self.observation_space.sample(), reward, done, {} 148 | 149 | def render(self, mode: str = "human") -> None: 150 | pass 151 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/envs/multi_input_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from stable_baselines3.common.type_aliases import GymStepReturn 7 | 8 | 9 | class SimpleMultiObsEnv(gym.Env): 10 | """ 11 | Base class for GridWorld-based MultiObs Environments 4x4 grid world. 12 | 13 | .. code-block:: text 14 | 15 | ____________ 16 | | 0 1 2 3| 17 | | 4|¯5¯¯6¯| 7| 18 | | 8|_9_10_|11| 19 | |12 13 14 15| 20 | ¯¯¯¯¯¯¯¯¯¯¯¯¯¯ 21 | 22 | start is 0 23 | states 5, 6, 9, and 10 are blocked 24 | goal is 15 25 | actions are = [left, down, right, up] 26 | 27 | simple linear state env of 15 states but encoded with a vector and an image observation: 28 | each column is represented by a random vector and each row is 29 | represented by a random image, both sampled once at creation time. 
30 | 31 | :param num_col: Number of columns in the grid 32 | :param num_row: Number of rows in the grid 33 | :param random_start: If true, agent starts in random position 34 | :param channel_last: If true, the image will be channel last, else it will be channel first 35 | """ 36 | 37 | def __init__( 38 | self, 39 | num_col: int = 4, 40 | num_row: int = 4, 41 | random_start: bool = True, 42 | discrete_actions: bool = True, 43 | channel_last: bool = True, 44 | ): 45 | super(SimpleMultiObsEnv, self).__init__() 46 | 47 | self.vector_size = 5 48 | if channel_last: 49 | self.img_size = [64, 64, 1] 50 | else: 51 | self.img_size = [1, 64, 64] 52 | 53 | self.random_start = random_start 54 | self.discrete_actions = discrete_actions 55 | if discrete_actions: 56 | self.action_space = gym.spaces.Discrete(4) 57 | else: 58 | self.action_space = gym.spaces.Box(0, 1, (4,)) 59 | 60 | self.observation_space = gym.spaces.Dict( 61 | spaces={ 62 | "vec": gym.spaces.Box(0, 1, (self.vector_size,), dtype=np.float64), 63 | "img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8), 64 | } 65 | ) 66 | self.count = 0 67 | # Timeout 68 | self.max_count = 100 69 | self.log = "" 70 | self.state = 0 71 | self.action2str = ["left", "down", "right", "up"] 72 | self.init_possible_transitions() 73 | 74 | self.num_col = num_col 75 | self.state_mapping = [] 76 | self.init_state_mapping(num_col, num_row) 77 | 78 | self.max_state = len(self.state_mapping) - 1 79 | 80 | def init_state_mapping(self, num_col: int, num_row: int) -> None: 81 | """ 82 | Initializes the state_mapping array which holds the observation values for each state 83 | 84 | :param num_col: Number of columns. 85 | :param num_row: Number of rows. 86 | """ 87 | # Each column is represented by a random vector 88 | col_vecs = np.random.random((num_col, self.vector_size)) 89 | # Each row is represented by a random image 90 | row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8) 91 | 92 | for i in range(num_col): 93 | for j in range(num_row): 94 | self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)}) 95 | 96 | def get_state_mapping(self) -> Dict[str, np.ndarray]: 97 | """ 98 | Uses the state to get the observation mapping. 99 | 100 | :return: observation dict {'vec': ..., 'img': ...} 101 | """ 102 | return self.state_mapping[self.state] 103 | 104 | def init_possible_transitions(self) -> None: 105 | """ 106 | Initializes the transitions of the environment 107 | The environment exploits the cardinal directions of the grid by noting that 108 | they correspond to simple addition and subtraction from the cell id within the grid 109 | 110 | - up => means moving up a row => means subtracting the length of a column 111 | - down => means moving down a row => means adding the length of a column 112 | - left => means moving left by one => means subtracting 1 113 | - right => means moving right by one => means adding 1 114 | 115 | Thus one only needs to specify in which states each action is possible 116 | in order to define the transitions of the environment 117 | """ 118 | self.left_possible = [1, 2, 3, 13, 14, 15] 119 | self.down_possible = [0, 4, 8, 3, 7, 11] 120 | self.right_possible = [0, 1, 2, 12, 13, 14] 121 | self.up_possible = [4, 8, 12, 7, 11, 15] 122 | 123 | def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn: 124 | """ 125 | Run one timestep of the environment's dynamics. When end of 126 | episode is reached, you are responsible for calling `reset()` 127 | to reset this environment's state. 
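A short sketch of interacting with `SimpleMultiObsEnv`: each observation is a dict with a 5-dimensional `vec` entry (identifying the column) and a 64x64 `img` entry (identifying the row), and the discrete actions follow `action2str = ["left", "down", "right", "up"]`:

from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv

env = SimpleMultiObsEnv(random_start=False)   # always start in cell 0
obs = env.reset()                             # {'vec': (5,) floats, 'img': (64, 64, 1) uint8}
obs, reward, done, info = env.step(2)         # "right": cell 0 -> 1, step penalty -0.1
obs, reward, done, info = env.step(1)         # "down": cell 5 below is blocked, so the state is unchanged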
128 | Accepts an action and returns a tuple (observation, reward, done, info). 129 | 130 | :param action: 131 | :return: tuple (observation, reward, done, info). 132 | """ 133 | if not self.discrete_actions: 134 | action = np.argmax(action) 135 | else: 136 | action = int(action) 137 | 138 | self.count += 1 139 | 140 | prev_state = self.state 141 | 142 | reward = -0.1 143 | # define state transition 144 | if self.state in self.left_possible and action == 0: # left 145 | self.state -= 1 146 | elif self.state in self.down_possible and action == 1: # down 147 | self.state += self.num_col 148 | elif self.state in self.right_possible and action == 2: # right 149 | self.state += 1 150 | elif self.state in self.up_possible and action == 3: # up 151 | self.state -= self.num_col 152 | 153 | got_to_end = self.state == self.max_state 154 | reward = 1 if got_to_end else reward 155 | done = self.count > self.max_count or got_to_end 156 | 157 | self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" 158 | 159 | return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} 160 | 161 | def render(self, mode: str = "human") -> None: 162 | """ 163 | Prints the log of the environment. 164 | 165 | :param mode: 166 | """ 167 | print(self.log) 168 | 169 | def reset(self) -> Dict[str, np.ndarray]: 170 | """ 171 | Resets the environment state and step count and returns reset observation. 172 | 173 | :return: observation dict {'vec': ..., 'img': ...} 174 | """ 175 | self.count = 0 176 | if not self.random_start: 177 | self.state = 0 178 | else: 179 | self.state = np.random.randint(0, self.max_state) 180 | return self.state_mapping[self.state] 181 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/evaluation.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from stable_baselines3.common import base_class 8 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped 9 | 10 | ##### local modification ##### 11 | def ssd_policy(quantiles:np.ndarray, use_threshold:bool=False, mean_threshold:float=1e-03): 12 | means = np.mean(quantiles,axis=0) 13 | sort_idx = np.argsort(-1*means) 14 | best_1 = sort_idx[0] 15 | best_2 = sort_idx[1] 16 | if means[best_1] - means[best_2] > mean_threshold: 17 | return best_1 18 | else: 19 | if use_threshold: 20 | signed_second_moment = -1 * np.var(quantiles,axis=0) 21 | else: 22 | signed_second_moment = -1 * np.mean(quantiles**2,axis=0) 23 | action = best_1 24 | if signed_second_moment[best_2] > signed_second_moment[best_1]: 25 | action = best_2 26 | return action 27 | 28 | 29 | def evaluate_policy( 30 | model: "base_class.BaseAlgorithm", 31 | env: Union[gym.Env, VecEnv], 32 | n_eval_episodes: int = 10, 33 | deterministic: bool = True, 34 | render: bool = False, 35 | callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, 36 | reward_threshold: Optional[float] = None, 37 | return_episode_rewards: bool = False, 38 | warn: bool = True, 39 | ##### local modification ##### 40 | eval_policy: str = "Greedy", 41 | ssd_thres: float = 1e-03 42 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: 43 | """ 44 | Runs policy for ``n_eval_episodes`` episodes and returns average reward. 
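The `ssd_policy` helper implements the (thresholded) second-order stochastic dominance action selection used at evaluation time: when the top two actions have practically equal mean returns, the tie is broken by the negated second moment, or by the variance when `use_threshold=True`. A small numerical sketch of that behaviour (the quantile values are made up for illustration):

import numpy as np
from stable_baselines3.common.evaluation import ssd_policy

# shape (n_quantiles, n_actions): column 0 is deterministic, column 1 has a
# marginally higher mean but a much wider spread
quantiles = np.array([
    [1.0, 0.0],
    [1.0, 0.5],
    [1.0, 1.0],
    [1.0, 1.5],
    [1.0, 2.001],
])

np.argmax(quantiles.mean(axis=0))            # greedy choice   -> 1 (mean 1.0002 vs 1.0)
ssd_policy(quantiles)                        # SSD choice      -> 0 (means within 1e-3, smaller E[X^2] wins)
ssd_policy(quantiles, use_threshold=True)    # thresholded SSD -> 0 (tie broken by variance instead)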
45 | If a vector env is passed in, this divides the episodes to evaluate onto the 46 | different elements of the vector env. This static division of work is done to 47 | remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more 48 | details and discussion. 49 | 50 | .. note:: 51 | If environment has not been wrapped with ``Monitor`` wrapper, reward and 52 | episode lengths are counted as it appears with ``env.step`` calls. If 53 | the environment contains wrappers that modify rewards or episode lengths 54 | (e.g. reward scaling, early episode reset), these will affect the evaluation 55 | results as well. You can avoid this by wrapping environment with ``Monitor`` 56 | wrapper before anything else. 57 | 58 | :param model: The RL agent you want to evaluate. 59 | :param env: The gym environment or ``VecEnv`` environment. 60 | :param n_eval_episodes: Number of episode to evaluate the agent 61 | :param deterministic: Whether to use deterministic or stochastic actions 62 | :param render: Whether to render the environment or not 63 | :param callback: callback function to do additional checks, 64 | called after each step. Gets locals() and globals() passed as parameters. 65 | :param reward_threshold: Minimum expected reward per episode, 66 | this will raise an error if the performance is not met 67 | :param return_episode_rewards: If True, a list of rewards and episode lengths 68 | per episode will be returned instead of the mean. 69 | :param warn: If True (default), warns user about lack of a Monitor wrapper in the 70 | evaluation environment. 71 | :return: Mean reward per episode, std of reward per episode. 72 | Returns ([float], [int]) when ``return_episode_rewards`` is True, first 73 | list containing per-episode rewards and second containing per-episode lengths 74 | (in number of steps). 75 | """ 76 | is_monitor_wrapped = False 77 | # Avoid circular import 78 | from stable_baselines3.common.monitor import Monitor 79 | 80 | if not isinstance(env, VecEnv): 81 | env = DummyVecEnv([lambda: env]) 82 | 83 | is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] 84 | 85 | if not is_monitor_wrapped and warn: 86 | warnings.warn( 87 | "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " 88 | "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. 
" 89 | "Consider wrapping environment first with ``Monitor`` wrapper.", 90 | UserWarning, 91 | ) 92 | 93 | ##### local modification ##### 94 | # store quantiles prediction for all state action pair if the agent is QR-DQN 95 | if env.save_q_vals: 96 | print("saving quantiles (QR-DQN)") 97 | all_quantiles = [] 98 | for i in range(env.num_states): 99 | obs = env.get_obs_at_state(i) 100 | q_vals = model.predict_quantiles(obs) 101 | all_quantiles.append(q_vals.cpu().data.numpy()[0]) 102 | 103 | env.save_quantiles(np.array(all_quantiles)) 104 | 105 | n_envs = env.num_envs 106 | episode_rewards = [] 107 | episode_lengths = [] 108 | 109 | episode_counts = np.zeros(n_envs, dtype="int") 110 | # Divides episodes among different sub environments in the vector as evenly as possible 111 | episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int") 112 | 113 | current_rewards = np.zeros(n_envs) 114 | current_lengths = np.zeros(n_envs, dtype="int") 115 | observations = env.reset() 116 | states = None 117 | while (episode_counts < episode_count_targets).any(): 118 | ##### local modification ##### 119 | if eval_policy == "Greedy": 120 | actions, states = model.predict(observations, state=states, deterministic=deterministic) 121 | # TODO: consider multi environments case 122 | elif eval_policy == "SSD": 123 | q_vals = model.predict_quantiles(observations) 124 | actions = np.array([ssd_policy(q_vals.cpu().data.numpy()[0])]) 125 | states = None 126 | elif eval_policy == "Thresholded_SSD": 127 | q_vals = model.predict_quantiles(observations) 128 | actions = np.array([ssd_policy(q_vals.cpu().data.numpy()[0],use_threshold=True,mean_threshold=ssd_thres)]) 129 | states = None 130 | else: 131 | raise RuntimeError("The evaluation policy is not available.") 132 | 133 | observations, rewards, dones, infos = env.step(actions) 134 | ##### local modification ##### 135 | #current_rewards += rewards 136 | current_rewards += env.discount ** current_lengths[0] * rewards 137 | current_lengths += 1 138 | for i in range(n_envs): 139 | if episode_counts[i] < episode_count_targets[i]: 140 | 141 | # unpack values so that the callback can access the local variables 142 | reward = rewards[i] 143 | done = dones[i] 144 | info = infos[i] 145 | 146 | if callback is not None: 147 | callback(locals(), globals()) 148 | 149 | ##### local modification ##### 150 | # if dones[i]: 151 | if dones[i] or current_lengths[i] >= 1000: 152 | print("Eval_steps: ",current_lengths[i]," Eval_return: ",current_rewards[i]) 153 | # if is_monitor_wrapped: 154 | # # Atari wrapper can send a "done" signal when 155 | # # the agent loses a life, but it does not correspond 156 | # # to the true end of episode 157 | # if "episode" in info.keys(): 158 | # # Do not trust "done" with episode endings. 159 | # # Monitor wrapper includes "episode" key in info if environment 160 | # # has been wrapped with it. Use those rewards instead. 
161 | # episode_rewards.append(info["episode"]["r"]) 162 | # episode_lengths.append(info["episode"]["l"]) 163 | # # Only increment at the real end of an episode 164 | # episode_counts[i] += 1 165 | # else: 166 | # episode_rewards.append(current_rewards[i]) 167 | # episode_lengths.append(current_lengths[i]) 168 | # episode_counts[i] += 1 169 | 170 | episode_rewards.append(current_rewards[i]) 171 | episode_lengths.append(current_lengths[i]) 172 | episode_counts[i] += 1 173 | 174 | current_rewards[i] = 0 175 | current_lengths[i] = 0 176 | if states is not None: 177 | states[i] *= 0 178 | 179 | if render: 180 | env.render() 181 | 182 | mean_reward = np.mean(episode_rewards) 183 | std_reward = np.std(episode_rewards) 184 | if reward_threshold is not None: 185 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" 186 | if return_episode_rewards: 187 | return episode_rewards, episode_lengths 188 | return mean_reward, std_reward 189 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/monitor.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"] 2 | 3 | import csv 4 | import json 5 | import os 6 | import time 7 | from glob import glob 8 | from typing import Dict, List, Optional, Tuple, Union 9 | 10 | import gym 11 | import numpy as np 12 | import pandas 13 | 14 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn 15 | 16 | 17 | class Monitor(gym.Wrapper): 18 | """ 19 | A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. 20 | 21 | :param env: The environment 22 | :param filename: the location to save a log file, can be None for no log 23 | :param allow_early_resets: allows the reset of the environment before it is done 24 | :param reset_keywords: extra keywords for the reset call, 25 | if extra parameters are needed at reset 26 | :param info_keywords: extra information to log, from the information return of env.step() 27 | """ 28 | 29 | EXT = "monitor.csv" 30 | 31 | def __init__( 32 | self, 33 | env: gym.Env, 34 | filename: Optional[str] = None, 35 | allow_early_resets: bool = True, 36 | reset_keywords: Tuple[str, ...] = (), 37 | info_keywords: Tuple[str, ...] = (), 38 | ): 39 | super(Monitor, self).__init__(env=env) 40 | self.t_start = time.time() 41 | if filename is not None: 42 | self.results_writer = ResultsWriter( 43 | filename, 44 | header={"t_start": self.t_start, "env_id": env.spec and env.spec.id}, 45 | extra_keys=reset_keywords + info_keywords, 46 | ) 47 | else: 48 | self.results_writer = None 49 | self.reset_keywords = reset_keywords 50 | self.info_keywords = info_keywords 51 | self.allow_early_resets = allow_early_resets 52 | self.rewards = None 53 | self.needs_reset = True 54 | self.episode_returns = [] 55 | self.episode_lengths = [] 56 | self.episode_times = [] 57 | self.total_steps = 0 58 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() 59 | 60 | def reset(self, **kwargs) -> GymObs: 61 | """ 62 | Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True 63 | 64 | :param kwargs: Extra keywords saved for the next episode. 
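With the local modifications in `evaluate_policy`, the function accepts an `eval_policy` argument ("Greedy", "SSD" or "Thresholded_SSD") and accumulates discounted returns using the wrapped environment's `discount`. A minimal sketch, assuming `model` is a trained QR-DQN agent (so that `predict_quantiles` exists) and `env` is the road-network environment that exposes the extra attributes these modifications rely on; neither is defined here:

from stable_baselines3.common.evaluation import evaluate_policy

mean_return, std_return = evaluate_policy(
    model,                            # trained QR-DQN agent (assumed to exist)
    env,                              # road-network environment (assumed to exist)
    n_eval_episodes=10,
    eval_policy="Thresholded_SSD",    # or "Greedy" / "SSD"
    ssd_thres=1e-3,
)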
only if defined by reset_keywords 65 | :return: the first observation of the environment 66 | """ 67 | if not self.allow_early_resets and not self.needs_reset: 68 | raise RuntimeError( 69 | "Tried to reset an environment before done. If you want to allow early resets, " 70 | "wrap your env with Monitor(env, path, allow_early_resets=True)" 71 | ) 72 | self.rewards = [] 73 | self.needs_reset = False 74 | for key in self.reset_keywords: 75 | value = kwargs.get(key) 76 | if value is None: 77 | raise ValueError(f"Expected you to pass keyword argument {key} into reset") 78 | self.current_reset_info[key] = value 79 | return self.env.reset(**kwargs) 80 | 81 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: 82 | """ 83 | Step the environment with the given action 84 | 85 | :param action: the action 86 | :return: observation, reward, done, information 87 | """ 88 | if self.needs_reset: 89 | raise RuntimeError("Tried to step environment that needs reset") 90 | observation, reward, done, info = self.env.step(action) 91 | self.rewards.append(reward) 92 | if done: 93 | self.needs_reset = True 94 | ep_rew = sum(self.rewards) 95 | ep_len = len(self.rewards) 96 | ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)} 97 | for key in self.info_keywords: 98 | ep_info[key] = info[key] 99 | self.episode_returns.append(ep_rew) 100 | self.episode_lengths.append(ep_len) 101 | self.episode_times.append(time.time() - self.t_start) 102 | ep_info.update(self.current_reset_info) 103 | if self.results_writer: 104 | self.results_writer.write_row(ep_info) 105 | info["episode"] = ep_info 106 | self.total_steps += 1 107 | return observation, reward, done, info 108 | 109 | def close(self) -> None: 110 | """ 111 | Closes the environment 112 | """ 113 | super(Monitor, self).close() 114 | if self.results_writer is not None: 115 | self.results_writer.close() 116 | 117 | def get_total_steps(self) -> int: 118 | """ 119 | Returns the total number of timesteps 120 | 121 | :return: 122 | """ 123 | return self.total_steps 124 | 125 | def get_episode_rewards(self) -> List[float]: 126 | """ 127 | Returns the rewards of all the episodes 128 | 129 | :return: 130 | """ 131 | return self.episode_returns 132 | 133 | def get_episode_lengths(self) -> List[int]: 134 | """ 135 | Returns the number of timesteps of all the episodes 136 | 137 | :return: 138 | """ 139 | return self.episode_lengths 140 | 141 | def get_episode_times(self) -> List[float]: 142 | """ 143 | Returns the runtime in seconds of all the episodes 144 | 145 | :return: 146 | """ 147 | return self.episode_times 148 | 149 | 150 | class LoadMonitorResultsError(Exception): 151 | """ 152 | Raised when loading the monitor log fails. 153 | """ 154 | 155 | pass 156 | 157 | 158 | class ResultsWriter: 159 | """ 160 | A result writer that saves the data from the `Monitor` class 161 | 162 | :param filename: the location to save a log file, can be None for no log 163 | :param header: the header dictionary object of the saved csv 164 | :param reset_keywords: the extra information to log, typically is composed of 165 | ``reset_keywords`` and ``info_keywords`` 166 | """ 167 | 168 | def __init__( 169 | self, 170 | filename: str = "", 171 | header: Optional[Dict[str, Union[float, str]]] = None, 172 | extra_keys: Tuple[str, ...] 
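`Monitor` records per-episode return, length and wall-clock time and, when given a filename, writes them to a `*.monitor.csv` file that `load_results` (defined later in this module) can read back. A minimal sketch with an illustrative Gym environment and log path:

import os
import gym
from stable_baselines3.common.monitor import Monitor, load_results

os.makedirs("./logs", exist_ok=True)
env = Monitor(gym.make("CartPole-v1"), filename="./logs/eval")   # writes ./logs/eval.monitor.csv
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
# once at least one episode has finished:
# print(load_results("./logs")[["r", "l", "t"]].head())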
= (), 173 | ): 174 | if header is None: 175 | header = {} 176 | if not filename.endswith(Monitor.EXT): 177 | if os.path.isdir(filename): 178 | filename = os.path.join(filename, Monitor.EXT) 179 | else: 180 | filename = filename + "." + Monitor.EXT 181 | self.file_handler = open(filename, "wt") 182 | self.file_handler.write("#%s\n" % json.dumps(header)) 183 | self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys) 184 | self.logger.writeheader() 185 | self.file_handler.flush() 186 | 187 | def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None: 188 | """ 189 | Close the file handler 190 | 191 | :param epinfo: the information on episodic return, length, and time 192 | """ 193 | if self.logger: 194 | self.logger.writerow(epinfo) 195 | self.file_handler.flush() 196 | 197 | def close(self) -> None: 198 | """ 199 | Close the file handler 200 | """ 201 | self.file_handler.close() 202 | 203 | 204 | def get_monitor_files(path: str) -> List[str]: 205 | """ 206 | get all the monitor files in the given path 207 | 208 | :param path: the logging folder 209 | :return: the log files 210 | """ 211 | return glob(os.path.join(path, "*" + Monitor.EXT)) 212 | 213 | 214 | def load_results(path: str) -> pandas.DataFrame: 215 | """ 216 | Load all Monitor logs from a given directory path matching ``*monitor.csv`` 217 | 218 | :param path: the directory path containing the log file(s) 219 | :return: the logged data 220 | """ 221 | monitor_files = get_monitor_files(path) 222 | if len(monitor_files) == 0: 223 | raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}") 224 | data_frames, headers = [], [] 225 | for file_name in monitor_files: 226 | with open(file_name, "rt") as file_handler: 227 | first_line = file_handler.readline() 228 | assert first_line[0] == "#" 229 | header = json.loads(first_line[1:]) 230 | data_frame = pandas.read_csv(file_handler, index_col=None) 231 | headers.append(header) 232 | data_frame["t"] += header["t_start"] 233 | data_frames.append(data_frame) 234 | data_frame = pandas.concat(data_frames) 235 | data_frame.sort_values("t", inplace=True) 236 | data_frame.reset_index(inplace=True) 237 | data_frame["t"] -= min(header["t_start"] for header in headers) 238 | return data_frame 239 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/noise.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from abc import ABC, abstractmethod 3 | from typing import Iterable, List, Optional 4 | 5 | import numpy as np 6 | 7 | 8 | class ActionNoise(ABC): 9 | """ 10 | The action noise base class 11 | """ 12 | 13 | def __init__(self): 14 | super(ActionNoise, self).__init__() 15 | 16 | def reset(self) -> None: 17 | """ 18 | call end of episode reset for the noise 19 | """ 20 | pass 21 | 22 | @abstractmethod 23 | def __call__(self) -> np.ndarray: 24 | raise NotImplementedError() 25 | 26 | 27 | class NormalActionNoise(ActionNoise): 28 | """ 29 | A Gaussian action noise 30 | 31 | :param mean: the mean value of the noise 32 | :param sigma: the scale of the noise (std here) 33 | """ 34 | 35 | def __init__(self, mean: np.ndarray, sigma: np.ndarray): 36 | self._mu = mean 37 | self._sigma = sigma 38 | super(NormalActionNoise, self).__init__() 39 | 40 | def __call__(self) -> np.ndarray: 41 | return np.random.normal(self._mu, self._sigma) 42 | 43 | def __repr__(self) -> str: 44 | return f"NormalActionNoise(mu={self._mu}, 
sigma={self._sigma})" 45 | 46 | 47 | class OrnsteinUhlenbeckActionNoise(ActionNoise): 48 | """ 49 | An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction. 50 | 51 | Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 52 | 53 | :param mean: the mean of the noise 54 | :param sigma: the scale of the noise 55 | :param theta: the rate of mean reversion 56 | :param dt: the timestep for the noise 57 | :param initial_noise: the initial value for the noise output, (if None: 0) 58 | """ 59 | 60 | def __init__( 61 | self, 62 | mean: np.ndarray, 63 | sigma: np.ndarray, 64 | theta: float = 0.15, 65 | dt: float = 1e-2, 66 | initial_noise: Optional[np.ndarray] = None, 67 | ): 68 | self._theta = theta 69 | self._mu = mean 70 | self._sigma = sigma 71 | self._dt = dt 72 | self.initial_noise = initial_noise 73 | self.noise_prev = np.zeros_like(self._mu) 74 | self.reset() 75 | super(OrnsteinUhlenbeckActionNoise, self).__init__() 76 | 77 | def __call__(self) -> np.ndarray: 78 | noise = ( 79 | self.noise_prev 80 | + self._theta * (self._mu - self.noise_prev) * self._dt 81 | + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape) 82 | ) 83 | self.noise_prev = noise 84 | return noise 85 | 86 | def reset(self) -> None: 87 | """ 88 | reset the Ornstein Uhlenbeck noise, to the initial position 89 | """ 90 | self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu) 91 | 92 | def __repr__(self) -> str: 93 | return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})" 94 | 95 | 96 | class VectorizedActionNoise(ActionNoise): 97 | """ 98 | A Vectorized action noise for parallel environments. 99 | 100 | :param base_noise: ActionNoise The noise generator to use 101 | :param n_envs: The number of parallel environments 102 | """ 103 | 104 | def __init__(self, base_noise: ActionNoise, n_envs: int): 105 | try: 106 | self.n_envs = int(n_envs) 107 | assert self.n_envs > 0 108 | except (TypeError, AssertionError): 109 | raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") 110 | 111 | self.base_noise = base_noise 112 | self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)] 113 | 114 | def reset(self, indices: Optional[Iterable[int]] = None) -> None: 115 | """ 116 | Reset all the noise processes, or those listed in indices 117 | 118 | :param indices: Optional[Iterable[int]] The indices to reset. Default: None. 119 | If the parameter is None, then all processes are reset to their initial position. 
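A short sketch of constructing the two noise processes defined in this module: the Gaussian version draws independent samples on every call, while the Ornstein-Uhlenbeck version is temporally correlated (through `theta` and `dt`) and should be `reset()` at the end of each episode:

import numpy as np
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

n_actions = 2
gaussian = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
ou = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.2 * np.ones(n_actions))

sample = gaussian()   # independent draw every call
correlated = ou()     # depends on the previous draw
ou.reset()            # restart the process, e.g. at episode end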
120 | """ 121 | if indices is None: 122 | indices = range(len(self.noises)) 123 | 124 | for index in indices: 125 | self.noises[index].reset() 126 | 127 | def __repr__(self) -> str: 128 | return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})" 129 | 130 | def __call__(self) -> np.ndarray: 131 | """ 132 | Generate and stack the action noise from each noise object 133 | """ 134 | noise = np.stack([noise() for noise in self.noises]) 135 | return noise 136 | 137 | @property 138 | def base_noise(self) -> ActionNoise: 139 | return self._base_noise 140 | 141 | @base_noise.setter 142 | def base_noise(self, base_noise: ActionNoise) -> None: 143 | if base_noise is None: 144 | raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise) 145 | if not isinstance(base_noise, ActionNoise): 146 | raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise) 147 | self._base_noise = base_noise 148 | 149 | @property 150 | def noises(self) -> List[ActionNoise]: 151 | return self._noises 152 | 153 | @noises.setter 154 | def noises(self, noises: List[ActionNoise]) -> None: 155 | noises = list(noises) # raises TypeError if not iterable 156 | assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}." 157 | 158 | different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))] 159 | 160 | if len(different_types): 161 | raise ValueError( 162 | f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise) 163 | ) 164 | 165 | self._noises = noises 166 | for noise in noises: 167 | noise.reset() 168 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Tuple, Union 3 | 4 | import numpy as np 5 | import torch as th 6 | from gym import spaces 7 | from torch.nn import functional as F 8 | 9 | 10 | def is_image_space_channels_first(observation_space: spaces.Box) -> bool: 11 | """ 12 | Check if an image observation space (see ``is_image_space``) 13 | is channels-first (CxHxW, True) or channels-last (HxWxC, False). 14 | 15 | Use a heuristic that channel dimension is the smallest of the three. 16 | If second dimension is smallest, raise an exception (no support). 17 | 18 | :param observation_space: 19 | :return: True if observation space is channels-first image, False if channels-last. 20 | """ 21 | smallest_dimension = np.argmin(observation_space.shape).item() 22 | if smallest_dimension == 1: 23 | warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") 24 | return smallest_dimension == 0 25 | 26 | 27 | def is_image_space( 28 | observation_space: spaces.Space, 29 | check_channels: bool = False, 30 | ) -> bool: 31 | """ 32 | Check if a observation space has the shape, limits and dtype 33 | of a valid image. 34 | The check is conservative, so that it returns False if there is a doubt. 35 | 36 | Valid images: RGB, RGBD, GrayScale with values in [0, 255] 37 | 38 | :param observation_space: 39 | :param check_channels: Whether to do or not the check for the number of channels. 40 | e.g., with frame-stacking, the observation space may have more channels than expected. 
41 | :return: 42 | """ 43 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: 44 | # Check the type 45 | if observation_space.dtype != np.uint8: 46 | return False 47 | 48 | # Check the value range 49 | if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): 50 | return False 51 | 52 | # Skip channels check 53 | if not check_channels: 54 | return True 55 | # Check the number of channels 56 | if is_image_space_channels_first(observation_space): 57 | n_channels = observation_space.shape[0] 58 | else: 59 | n_channels = observation_space.shape[-1] 60 | # RGB, RGBD, GrayScale 61 | return n_channels in [1, 3, 4] 62 | return False 63 | 64 | 65 | def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray: 66 | """ 67 | Handle the different cases for images as PyTorch use channel first format. 68 | 69 | :param observation: 70 | :param observation_space: 71 | :return: channel first observation if observation is an image 72 | """ 73 | # Avoid circular import 74 | from stable_baselines3.common.vec_env import VecTransposeImage 75 | 76 | if is_image_space(observation_space): 77 | if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape): 78 | # Try to re-order the channels 79 | transpose_obs = VecTransposeImage.transpose_image(observation) 80 | if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape: 81 | observation = transpose_obs 82 | return observation 83 | 84 | 85 | def preprocess_obs( 86 | obs: th.Tensor, 87 | observation_space: spaces.Space, 88 | normalize_images: bool = True, 89 | ) -> Union[th.Tensor, Dict[str, th.Tensor]]: 90 | """ 91 | Preprocess observation to be to a neural network. 92 | For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) 93 | For discrete observations, it create a one hot vector. 
94 | 95 | :param obs: Observation 96 | :param observation_space: 97 | :param normalize_images: Whether to normalize images or not 98 | (True by default) 99 | :return: 100 | """ 101 | if isinstance(observation_space, spaces.Box): 102 | if is_image_space(observation_space) and normalize_images: 103 | return obs.float() / 255.0 104 | return obs.float() 105 | 106 | elif isinstance(observation_space, spaces.Discrete): 107 | # One hot encoding and convert to float to avoid errors 108 | return F.one_hot(obs.long(), num_classes=observation_space.n).float() 109 | 110 | elif isinstance(observation_space, spaces.MultiDiscrete): 111 | # Tensor concatenation of one hot encodings of each Categorical sub-space 112 | return th.cat( 113 | [ 114 | F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() 115 | for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1)) 116 | ], 117 | dim=-1, 118 | ).view(obs.shape[0], sum(observation_space.nvec)) 119 | 120 | elif isinstance(observation_space, spaces.MultiBinary): 121 | return obs.float() 122 | 123 | elif isinstance(observation_space, spaces.Dict): 124 | # Do not modify by reference the original observation 125 | preprocessed_obs = {} 126 | for key, _obs in obs.items(): 127 | preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) 128 | return preprocessed_obs 129 | 130 | else: 131 | raise NotImplementedError(f"Preprocessing not implemented for {observation_space}") 132 | 133 | 134 | def get_obs_shape( 135 | observation_space: spaces.Space, 136 | ) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: 137 | """ 138 | Get the shape of the observation (useful for the buffers). 139 | 140 | :param observation_space: 141 | :return: 142 | """ 143 | if isinstance(observation_space, spaces.Box): 144 | return observation_space.shape 145 | elif isinstance(observation_space, spaces.Discrete): 146 | # Observation is an int 147 | return (1,) 148 | elif isinstance(observation_space, spaces.MultiDiscrete): 149 | # Number of discrete features 150 | return (int(len(observation_space.nvec)),) 151 | elif isinstance(observation_space, spaces.MultiBinary): 152 | # Number of binary features 153 | return (int(observation_space.n),) 154 | elif isinstance(observation_space, spaces.Dict): 155 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} 156 | 157 | else: 158 | raise NotImplementedError(f"{observation_space} observation space is not supported") 159 | 160 | 161 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int: 162 | """ 163 | Get the dimension of the observation space when flattened. 164 | It does not apply to image observation space. 165 | 166 | Used by the ``FlattenExtractor`` to compute the input shape. 167 | 168 | :param observation_space: 169 | :return: 170 | """ 171 | # See issue https://github.com/openai/gym/issues/1915 172 | # it may be a problem for Dict/Tuple spaces too... 173 | if isinstance(observation_space, spaces.MultiDiscrete): 174 | return sum(observation_space.nvec) 175 | else: 176 | # Use Gym internal method 177 | return spaces.utils.flatdim(observation_space) 178 | 179 | 180 | def get_action_dim(action_space: spaces.Space) -> int: 181 | """ 182 | Get the dimension of the action space. 
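A few self-contained calls illustrating the preprocessing helpers in this module (the spaces and values are chosen only for illustration):

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs

img_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
is_image_space(img_space)                        # True: uint8, bounds [0, 255], 3 dimensions

preprocess_obs(th.tensor([2]), spaces.Discrete(4))
# -> tensor([[0., 0., 1., 0.]]), the one-hot encoding of observation 2

get_action_dim(spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32))   # -> 2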
183 | 184 | :param action_space: 185 | :return: 186 | """ 187 | if isinstance(action_space, spaces.Box): 188 | return int(np.prod(action_space.shape)) 189 | elif isinstance(action_space, spaces.Discrete): 190 | # Action is an int 191 | return 1 192 | elif isinstance(action_space, spaces.MultiDiscrete): 193 | # Number of discrete actions 194 | return int(len(action_space.nvec)) 195 | elif isinstance(action_space, spaces.MultiBinary): 196 | # Number of binary actions 197 | return int(action_space.n) 198 | else: 199 | raise NotImplementedError(f"{action_space} action space is not supported") 200 | 201 | 202 | def check_for_nested_spaces(obs_space: spaces.Space): 203 | """ 204 | Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). 205 | If so, raise an Exception informing that there is no support for this. 206 | 207 | :param obs_space: an observation space 208 | :return: 209 | """ 210 | if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): 211 | sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces 212 | for sub_space in sub_spaces: 213 | if isinstance(sub_space, (spaces.Dict, spaces.Tuple)): 214 | raise NotImplementedError( 215 | "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)." 216 | ) 217 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/results_plotter.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | # import matplotlib 7 | # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode 8 | from matplotlib import pyplot as plt 9 | 10 | from stable_baselines3.common.monitor import load_results 11 | 12 | X_TIMESTEPS = "timesteps" 13 | X_EPISODES = "episodes" 14 | X_WALLTIME = "walltime_hrs" 15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME] 16 | EPISODES_WINDOW = 100 17 | 18 | 19 | def rolling_window(array: np.ndarray, window: int) -> np.ndarray: 20 | """ 21 | Apply a rolling window to a np.ndarray 22 | 23 | :param array: the input Array 24 | :param window: length of the rolling window 25 | :return: rolling window on the input array 26 | """ 27 | shape = array.shape[:-1] + (array.shape[-1] - window + 1, window) 28 | strides = array.strides + (array.strides[-1],) 29 | return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides) 30 | 31 | 32 | def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]: 33 | """ 34 | Apply a function to the rolling window of 2 arrays 35 | 36 | :param var_1: variable 1 37 | :param var_2: variable 2 38 | :param window: length of the rolling window 39 | :param func: function to apply on the rolling window on variable 2 (such as np.mean) 40 | :return: the rolling output with applied function 41 | """ 42 | var_2_window = rolling_window(var_2, window) 43 | function_on_var2 = func(var_2_window, axis=-1) 44 | return var_1[window - 1 :], function_on_var2 45 | 46 | 47 | def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: 48 | """ 49 | Decompose a data frame variable to x ans ys 50 | 51 | :param data_frame: the input data 52 | :param x_axis: the axis for the x and y output 53 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 54 | :return: the x and y output 55 | """ 56 | if 
x_axis == X_TIMESTEPS: 57 | x_var = np.cumsum(data_frame.l.values) 58 | y_var = data_frame.r.values 59 | elif x_axis == X_EPISODES: 60 | x_var = np.arange(len(data_frame)) 61 | y_var = data_frame.r.values 62 | elif x_axis == X_WALLTIME: 63 | # Convert to hours 64 | x_var = data_frame.t.values / 3600.0 65 | y_var = data_frame.r.values 66 | else: 67 | raise NotImplementedError 68 | return x_var, y_var 69 | 70 | 71 | def plot_curves( 72 | xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2) 73 | ) -> None: 74 | """ 75 | plot the curves 76 | 77 | :param xy_list: the x and y coordinates to plot 78 | :param x_axis: the axis for the x and y output 79 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 80 | :param title: the title of the plot 81 | :param figsize: Size of the figure (width, height) 82 | """ 83 | 84 | plt.figure(title, figsize=figsize) 85 | max_x = max(xy[0][-1] for xy in xy_list) 86 | min_x = 0 87 | for (_, (x, y)) in enumerate(xy_list): 88 | plt.scatter(x, y, s=2) 89 | # Do not plot the smoothed curve at all if the timeseries is shorter than window size. 90 | if x.shape[0] >= EPISODES_WINDOW: 91 | # Compute and plot rolling mean with window of size EPISODE_WINDOW 92 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) 93 | plt.plot(x, y_mean) 94 | plt.xlim(min_x, max_x) 95 | plt.title(title) 96 | plt.xlabel(x_axis) 97 | plt.ylabel("Episode Rewards") 98 | plt.tight_layout() 99 | 100 | 101 | def plot_results( 102 | dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2) 103 | ) -> None: 104 | """ 105 | Plot the results using csv files from ``Monitor`` wrapper. 106 | 107 | :param dirs: the save location of the results to plot 108 | :param num_timesteps: only plot the points below this value 109 | :param x_axis: the axis for the x and y output 110 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 111 | :param task_name: the title of the task to plot 112 | :param figsize: Size of the figure (width, height) 113 | """ 114 | 115 | data_frames = [] 116 | for folder in dirs: 117 | data_frame = load_results(folder) 118 | if num_timesteps is not None: 119 | data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps] 120 | data_frames.append(data_frame) 121 | xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames] 122 | plot_curves(xy_list, x_axis, task_name, figsize) 123 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | class RunningMeanStd(object): 7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
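A sketch of plotting Monitor logs with `plot_results`; here `./logs` is an illustrative directory assumed to contain one or more `*.monitor.csv` files:

from matplotlib import pyplot as plt
from stable_baselines3.common.results_plotter import X_TIMESTEPS, plot_results

plot_results(["./logs"], num_timesteps=None, x_axis=X_TIMESTEPS, task_name="Evaluation returns")
plt.show()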
= ()): 8 | """ 9 | Calulates the running mean and std of a data stream 10 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 11 | 12 | :param epsilon: helps with arithmetic issues 13 | :param shape: the shape of the data stream's output 14 | """ 15 | self.mean = np.zeros(shape, np.float64) 16 | self.var = np.ones(shape, np.float64) 17 | self.count = epsilon 18 | 19 | def update(self, arr: np.ndarray) -> None: 20 | batch_mean = np.mean(arr, axis=0) 21 | batch_var = np.var(arr, axis=0) 22 | batch_count = arr.shape[0] 23 | self.update_from_moments(batch_mean, batch_var, batch_count) 24 | 25 | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None: 26 | delta = batch_mean - self.mean 27 | tot_count = self.count + batch_count 28 | 29 | new_mean = self.mean + delta * batch_count / tot_count 30 | m_a = self.var * self.count 31 | m_b = batch_var * batch_count 32 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 33 | new_var = m_2 / (self.count + batch_count) 34 | 35 | new_count = batch_count + self.count 36 | 37 | self.mean = new_mean 38 | self.var = new_var 39 | self.count = new_count 40 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/sb2_compat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/stable_baselines3/common/sb2_compat/__init__.py -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Iterable, Optional 2 | 3 | import torch 4 | from torch.optim import Optimizer 5 | 6 | 7 | class RMSpropTFLike(Optimizer): 8 | r"""Implements RMSprop algorithm with closer match to Tensorflow version. 9 | 10 | For reproducibility with original stable-baselines. Use this 11 | version with e.g. A2C for stabler learning than with the PyTorch 12 | RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop. 13 | 14 | See a more throughout conversion in pytorch-image-models repository: 15 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py 16 | 17 | Changes to the original RMSprop: 18 | - Move epsilon inside square root 19 | - Initialize squared gradient to ones rather than zeros 20 | 21 | Proposed by G. Hinton in his 22 | `course `_. 23 | 24 | The centered version first appears in `Generating Sequences 25 | With Recurrent Neural Networks `_. 26 | 27 | The implementation here takes the square root of the gradient average before 28 | adding epsilon (note that TensorFlow interchanges these two operations). The effective 29 | learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha` 30 | is the scheduled learning rate and :math:`v` is the weighted moving average 31 | of the squared gradient. 
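`RunningMeanStd` keeps streaming estimates of the mean and variance via the parallel update linked in its docstring, and is used e.g. by `VecNormalize` for observation and return normalization. A quick numerical sketch:

import numpy as np
from stable_baselines3.common.running_mean_std import RunningMeanStd

rms = RunningMeanStd(shape=(1,))
rms.update(np.array([[1.0], [2.0], [3.0]]))
rms.update(np.array([[4.0], [5.0]]))
# rms.mean is now close to [3.0] and rms.var close to [2.0]
# (the epsilon pseudo-count keeps the very first update numerically stable)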
32 | 33 | :params: iterable of parameters to optimize or dicts defining 34 | parameter groups 35 | :param lr: learning rate (default: 1e-2) 36 | :param momentum: momentum factor (default: 0) 37 | :param alpha: smoothing constant (default: 0.99) 38 | :param eps: term added to the denominator to improve 39 | numerical stability (default: 1e-8) 40 | :param centered: if ``True``, compute the centered RMSProp, 41 | the gradient is normalized by an estimation of its variance 42 | :param weight_decay: weight decay (L2 penalty) (default: 0) 43 | 44 | """ 45 | 46 | def __init__( 47 | self, 48 | params: Iterable[torch.nn.Parameter], 49 | lr: float = 1e-2, 50 | alpha: float = 0.99, 51 | eps: float = 1e-8, 52 | weight_decay: float = 0, 53 | momentum: float = 0, 54 | centered: bool = False, 55 | ): 56 | if not 0.0 <= lr: 57 | raise ValueError("Invalid learning rate: {}".format(lr)) 58 | if not 0.0 <= eps: 59 | raise ValueError("Invalid epsilon value: {}".format(eps)) 60 | if not 0.0 <= momentum: 61 | raise ValueError("Invalid momentum value: {}".format(momentum)) 62 | if not 0.0 <= weight_decay: 63 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 64 | if not 0.0 <= alpha: 65 | raise ValueError("Invalid alpha value: {}".format(alpha)) 66 | 67 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) 68 | super(RMSpropTFLike, self).__init__(params, defaults) 69 | 70 | def __setstate__(self, state: Dict[str, Any]) -> None: 71 | super(RMSpropTFLike, self).__setstate__(state) 72 | for group in self.param_groups: 73 | group.setdefault("momentum", 0) 74 | group.setdefault("centered", False) 75 | 76 | @torch.no_grad() 77 | def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]: 78 | """Performs a single optimization step. 79 | 80 | :param closure: A closure that reevaluates the model 81 | and returns the loss. 
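Rather than being instantiated directly, `RMSpropTFLike` can be injected through `policy_kwargs` when building an agent. A minimal sketch (the environment id and hyperparameters are illustrative only):

from stable_baselines3 import A2C
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)),
    verbose=1,
)
# model.learn(total_timesteps=10_000)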
82 | :return: loss 83 | """ 84 | loss = None 85 | if closure is not None: 86 | with torch.enable_grad(): 87 | loss = closure() 88 | 89 | for group in self.param_groups: 90 | for p in group["params"]: 91 | if p.grad is None: 92 | continue 93 | grad = p.grad 94 | if grad.is_sparse: 95 | raise RuntimeError("RMSpropTF does not support sparse gradients") 96 | state = self.state[p] 97 | 98 | # State initialization 99 | if len(state) == 0: 100 | state["step"] = 0 101 | # PyTorch initialized to zeros here 102 | state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format) 103 | if group["momentum"] > 0: 104 | state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) 105 | if group["centered"]: 106 | state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) 107 | 108 | square_avg = state["square_avg"] 109 | alpha = group["alpha"] 110 | 111 | state["step"] += 1 112 | 113 | if group["weight_decay"] != 0: 114 | grad = grad.add(p, alpha=group["weight_decay"]) 115 | 116 | square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) 117 | 118 | if group["centered"]: 119 | grad_avg = state["grad_avg"] 120 | grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) 121 | # PyTorch added epsilon after square root 122 | # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps']) 123 | avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_() 124 | else: 125 | # PyTorch added epsilon after square root 126 | # avg = square_avg.sqrt().add_(group['eps']) 127 | avg = square_avg.add(group["eps"]).sqrt_() 128 | 129 | if group["momentum"] > 0: 130 | buf = state["momentum_buffer"] 131 | buf.mul_(group["momentum"]).addcdiv_(grad, avg) 132 | p.add_(buf, alpha=-group["lr"]) 133 | else: 134 | p.addcdiv_(grad, avg, value=-group["lr"]) 135 | 136 | return loss 137 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/type_aliases.py: -------------------------------------------------------------------------------- 1 | """Common aliases for type hints""" 2 | 3 | from enum import Enum 4 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union 5 | 6 | import gym 7 | import numpy as np 8 | import torch as th 9 | 10 | from stable_baselines3.common import callbacks, vec_env 11 | 12 | GymEnv = Union[gym.Env, vec_env.VecEnv] 13 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] 14 | GymStepReturn = Tuple[GymObs, float, bool, Dict] 15 | TensorDict = Dict[Union[str, int], th.Tensor] 16 | OptimizerStateDict = Dict[str, Any] 17 | MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] 18 | 19 | # A schedule takes the remaining progress as input 20 | # and ouputs a scalar (e.g. learning rate, clip range, ...) 
21 | Schedule = Callable[[float], float] 22 | 23 | 24 | class RolloutBufferSamples(NamedTuple): 25 | observations: th.Tensor 26 | actions: th.Tensor 27 | old_values: th.Tensor 28 | old_log_prob: th.Tensor 29 | advantages: th.Tensor 30 | returns: th.Tensor 31 | 32 | 33 | class DictRolloutBufferSamples(RolloutBufferSamples): 34 | observations: TensorDict 35 | actions: th.Tensor 36 | old_values: th.Tensor 37 | old_log_prob: th.Tensor 38 | advantages: th.Tensor 39 | returns: th.Tensor 40 | 41 | 42 | class ReplayBufferSamples(NamedTuple): 43 | observations: th.Tensor 44 | actions: th.Tensor 45 | next_observations: th.Tensor 46 | dones: th.Tensor 47 | rewards: th.Tensor 48 | 49 | 50 | class DictReplayBufferSamples(ReplayBufferSamples): 51 | observations: TensorDict 52 | actions: th.Tensor 53 | next_observations: th.Tensor 54 | dones: th.Tensor 55 | rewards: th.Tensor 56 | 57 | 58 | class RolloutReturn(NamedTuple): 59 | episode_reward: float 60 | episode_timesteps: int 61 | n_episodes: int 62 | continue_training: bool 63 | 64 | 65 | class TrainFrequencyUnit(Enum): 66 | STEP = "step" 67 | EPISODE = "episode" 68 | 69 | 70 | class TrainFreq(NamedTuple): 71 | frequency: int 72 | unit: TrainFrequencyUnit # either "step" or "episode" 73 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F401 2 | import typing 3 | from copy import deepcopy 4 | from typing import Optional, Type, Union 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan 11 | from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs 12 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack 13 | from stable_baselines3.common.vec_env.vec_monitor import VecMonitor 14 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize 15 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage 16 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder 17 | 18 | # Avoid circular import 19 | if typing.TYPE_CHECKING: 20 | from stable_baselines3.common.type_aliases import GymEnv 21 | 22 | 23 | def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: 24 | """ 25 | Retrieve a ``VecEnvWrapper`` object by recursively searching. 26 | 27 | :param env: 28 | :param vec_wrapper_class: 29 | :return: 30 | """ 31 | env_tmp = env 32 | while isinstance(env_tmp, VecEnvWrapper): 33 | if isinstance(env_tmp, vec_wrapper_class): 34 | return env_tmp 35 | env_tmp = env_tmp.venv 36 | return None 37 | 38 | 39 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: 40 | """ 41 | :param env: 42 | :return: 43 | """ 44 | return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type 45 | 46 | 47 | def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: 48 | """ 49 | Check if an environment is already wrapped by a given ``VecEnvWrapper``. 
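`Schedule` is simply a callable mapping the remaining training progress (1.0 at the start, 0.0 at the end) to a scalar. A common pattern built on it, for example a linearly decaying learning rate:

from stable_baselines3.common.type_aliases import Schedule


def linear_schedule(initial_value: float) -> Schedule:
    def func(progress_remaining: float) -> float:
        # progress_remaining decreases from 1.0 to 0.0 over training
        return progress_remaining * initial_value

    return func


lr_schedule = linear_schedule(3e-4)
lr_schedule(1.0)   # 3e-4 at the very start
lr_schedule(0.5)   # 1.5e-4 halfway through training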
50 | 51 | :param env: 52 | :param vec_wrapper_class: 53 | :return: 54 | """ 55 | return unwrap_vec_wrapper(env, vec_wrapper_class) is not None 56 | 57 | 58 | # Define here to avoid circular import 59 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: 60 | """ 61 | Sync eval env and train env when using VecNormalize 62 | 63 | :param env: 64 | :param eval_env: 65 | """ 66 | env_tmp, eval_env_tmp = env, eval_env 67 | while isinstance(env_tmp, VecEnvWrapper): 68 | if isinstance(env_tmp, VecNormalize): 69 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) 70 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) 71 | env_tmp = env_tmp.venv 72 | eval_env_tmp = eval_env_tmp.venv 73 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import Any, Callable, List, Optional, Sequence, Type, Union 4 | 5 | import gym 6 | import numpy as np 7 | 8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn 9 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info 10 | 11 | 12 | class DummyVecEnv(VecEnv): 13 | """ 14 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 15 | Python process. This is useful for computationally simple environment such as ``cartpole-v1``, 16 | as the overhead of multiprocess or multithread outweighs the environment computation time. 17 | This can also be used for RL methods that 18 | require a vectorized environment, but that you want a single environments to train with. 
19 | 20 | :param env_fns: a list of functions 21 | that return environments to vectorize 22 | """ 23 | 24 | def __init__(self, env_fns: List[Callable[[], gym.Env]]): 25 | self.envs = [fn() for fn in env_fns] 26 | env = self.envs[0] 27 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 28 | obs_space = env.observation_space 29 | self.keys, shapes, dtypes = obs_space_info(obs_space) 30 | 31 | self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]) 32 | self.buf_dones = np.zeros((self.num_envs,), dtype=bool) 33 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 34 | self.buf_infos = [{} for _ in range(self.num_envs)] 35 | self.actions = None 36 | self.metadata = env.metadata 37 | 38 | ##### local modification ##### 39 | self.discount = env.discount 40 | self.num_states = env.get_num_of_states() 41 | self.save_q_vals = True if env.agent == "QRDQN" else False 42 | 43 | ##### local modification ##### 44 | def get_obs_at_state(self, state:int) -> np.ndarray: 45 | return self.envs[0].get_obs_at_state(state) 46 | 47 | ##### local modification ##### 48 | def save_quantiles(self, quantiles:np.ndarray) -> None: 49 | self.envs[0].save_quantiles(quantiles) 50 | 51 | def step_async(self, actions: np.ndarray) -> None: 52 | self.actions = actions 53 | 54 | def step_wait(self) -> VecEnvStepReturn: 55 | for env_idx in range(self.num_envs): 56 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( 57 | self.actions[env_idx] 58 | ) 59 | if self.buf_dones[env_idx]: 60 | # save final observation where user can get it, then reset 61 | self.buf_infos[env_idx]["terminal_observation"] = obs 62 | obs = self.envs[env_idx].reset() 63 | self._save_obs(env_idx, obs) 64 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 65 | 66 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 67 | seeds = list() 68 | for idx, env in enumerate(self.envs): 69 | seeds.append(env.seed(seed + idx)) 70 | return seeds 71 | 72 | def reset(self) -> VecEnvObs: 73 | for env_idx in range(self.num_envs): 74 | obs = self.envs[env_idx].reset() 75 | self._save_obs(env_idx, obs) 76 | return self._obs_from_buf() 77 | 78 | def close(self) -> None: 79 | for env in self.envs: 80 | env.close() 81 | 82 | def get_images(self) -> Sequence[np.ndarray]: 83 | return [env.render(mode="rgb_array") for env in self.envs] 84 | 85 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 86 | """ 87 | Gym environment rendering. If there are multiple environments then 88 | they are tiled together in one image via ``BaseVecEnv.render()``. 89 | Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the 90 | underlying environment. 91 | 92 | Therefore, some arguments such as ``mode`` will have values that are valid 93 | only when ``num_envs == 1``. 94 | 95 | :param mode: The rendering type. 
96 | """ 97 | if self.num_envs == 1: 98 | return self.envs[0].render(mode=mode) 99 | else: 100 | return super().render(mode=mode) 101 | 102 | def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: 103 | for key in self.keys: 104 | if key is None: 105 | self.buf_obs[key][env_idx] = obs 106 | else: 107 | self.buf_obs[key][env_idx] = obs[key] 108 | 109 | def _obs_from_buf(self) -> VecEnvObs: 110 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 111 | 112 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 113 | """Return attribute from vectorized environment (see base class).""" 114 | target_envs = self._get_target_envs(indices) 115 | return [getattr(env_i, attr_name) for env_i in target_envs] 116 | 117 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 118 | """Set attribute inside vectorized environments (see base class).""" 119 | target_envs = self._get_target_envs(indices) 120 | for env_i in target_envs: 121 | setattr(env_i, attr_name, value) 122 | 123 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 124 | """Call instance methods of vectorized environments.""" 125 | target_envs = self._get_target_envs(indices) 126 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 127 | 128 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 129 | """Check if worker environments are wrapped with a given wrapper""" 130 | target_envs = self._get_target_envs(indices) 131 | # Import here to avoid a circular import 132 | from stable_baselines3.common import env_util 133 | 134 | return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] 135 | 136 | def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: 137 | indices = self._get_indices(indices) 138 | return [self.envs[i] for i in indices] 139 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | from collections import OrderedDict 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import gym 8 | import numpy as np 9 | 10 | from stable_baselines3.common.preprocessing import check_for_nested_spaces 11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs 12 | 13 | 14 | def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 15 | """ 16 | Deep-copy a dict of numpy arrays. 17 | 18 | :param obs: a dict of numpy arrays. 19 | :return: a dict of copied numpy arrays. 20 | """ 21 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" 22 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 23 | 24 | 25 | def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: 26 | """ 27 | Convert an internal representation raw_obs into the appropriate type 28 | specified by space. 29 | 30 | :param obs_space: an observation space. 31 | :param obs_dict: a dict of numpy arrays. 32 | :return: returns an observation of the same type as space. 33 | If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; 34 | otherwise, space is unstructured and returns the value raw_obs[None]. 
35 | """ 36 | if isinstance(obs_space, gym.spaces.Dict): 37 | return obs_dict 38 | elif isinstance(obs_space, gym.spaces.Tuple): 39 | assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" 40 | return tuple((obs_dict[i] for i in range(len(obs_space.spaces)))) 41 | else: 42 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 43 | return obs_dict[None] 44 | 45 | 46 | def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: 47 | """ 48 | Get dict-structured information about a gym.Space. 49 | 50 | Dict spaces are represented directly by their dict of subspaces. 51 | Tuple spaces are converted into a dict with keys indexing into the tuple. 52 | Unstructured spaces are represented by {None: obs_space}. 53 | 54 | :param obs_space: an observation space 55 | :return: A tuple (keys, shapes, dtypes): 56 | keys: a list of dict keys. 57 | shapes: a dict mapping keys to shapes. 58 | dtypes: a dict mapping keys to dtypes. 59 | """ 60 | check_for_nested_spaces(obs_space) 61 | if isinstance(obs_space, gym.spaces.Dict): 62 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 63 | subspaces = obs_space.spaces 64 | elif isinstance(obs_space, gym.spaces.Tuple): 65 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 66 | else: 67 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" 68 | subspaces = {None: obs_space} 69 | keys = [] 70 | shapes = {} 71 | dtypes = {} 72 | for key, box in subspaces.items(): 73 | keys.append(key) 74 | shapes[key] = box.shape 75 | dtypes[key] = box.dtype 76 | return keys, shapes, dtypes 77 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: the vectorized environment to wrap 14 | :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: Whether or not to only warn once. 
16 | :param check_inf: Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions: np.ndarray) -> None: 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self) -> VecEnvStepReturn: 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self) -> VecEnvObs: 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step: bool, **kwargs) -> None: 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += f"found {type_val} in {name}" 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += f"environment, Last given value was: \r\n\taction={self._actions}" 80 | else: 81 | msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}" 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/vec_extract_dict_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 4 | 5 | 6 | class VecExtractDictObs(VecEnvWrapper): 7 | """ 8 | A vectorized wrapper for extracting dictionary observations. 
9 | 10 | :param venv: The vectorized environment 11 | :param key: The key of the dictionary observation 12 | """ 13 | 14 | def __init__(self, venv: VecEnv, key: str): 15 | self.key = key 16 | super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) 17 | 18 | def reset(self) -> np.ndarray: 19 | obs = self.venv.reset() 20 | return obs[self.key] 21 | 22 | def step_wait(self) -> VecEnvStepReturn: 23 | obs, reward, done, info = self.venv.step_wait() 24 | return obs[self.key], reward, done, info 25 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 8 | 9 | 10 | class VecFrameStack(VecEnvWrapper): 11 | """ 12 | Frame stacking wrapper for vectorized environment. Designed for image observations. 13 | 14 | Uses the StackedObservations class, or StackedDictObservations depending on the observations space 15 | 16 | :param venv: the vectorized environment to wrap 17 | :param n_stack: Number of frames to stack 18 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. 19 | If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 20 | Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces 21 | """ 22 | 23 | def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): 24 | self.venv = venv 25 | self.n_stack = n_stack 26 | 27 | wrapped_obs_space = venv.observation_space 28 | 29 | if isinstance(wrapped_obs_space, spaces.Box): 30 | assert not isinstance( 31 | channels_order, dict 32 | ), f"Expected None or string for channels_order but received {channels_order}" 33 | self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 34 | 35 | elif isinstance(wrapped_obs_space, spaces.Dict): 36 | self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 37 | 38 | else: 39 | raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") 40 | 41 | observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) 42 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 43 | 44 | def step_wait( 45 | self, 46 | ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: 47 | 48 | observations, rewards, dones, infos = self.venv.step_wait() 49 | 50 | observations, infos = self.stackedobs.update(observations, dones, infos) 51 | 52 | return observations, rewards, dones, infos 53 | 54 | def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 55 | """ 56 | Reset all environments 57 | """ 58 | observation = self.venv.reset() # pytype:disable=annotation-type-mismatch 59 | 60 | observation = self.stackedobs.reset(observation) 61 | return observation 62 | 63 | def close(self) -> None: 64 | self.venv.close() 65 | -------------------------------------------------------------------------------- 
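``VecFrameStack`` above is normally composed with a base ``VecEnv``. Below is a minimal sketch, assuming the *upstream* stable_baselines3 and gym packages rather than the vendored copy in this repository (whose locally modified ``DummyVecEnv`` also reads the CARLA environment's ``discount``, ``get_num_of_states()`` and ``agent`` attributes, so it cannot wrap an arbitrary Gym environment as-is); ``CartPole-v1`` is used purely for illustration.

```python
# Minimal sketch, assuming the upstream stable_baselines3 and gym packages
# (the vendored DummyVecEnv above also expects CARLA-specific attributes on
# the wrapped env, so it will not accept an arbitrary Gym environment).
import gym

from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
stacked = VecFrameStack(venv, n_stack=4)  # keep the 4 most recent observations

obs = stacked.reset()
# CartPole observations are 1-D Boxes of shape (4,); with channels_order=None
# the wrapper stacks on the last axis, so the batched observation is (1, 16).
print(obs.shape)
```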
/thirdparty/stable_baselines3/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from typing import Optional, Tuple 4 | 5 | import numpy as np 6 | 7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 8 | 9 | 10 | class VecMonitor(VecEnvWrapper): 11 | """ 12 | A vectorized monitor wrapper for *vectorized* Gym environments, 13 | it is used to record the episode reward, length, time and other data. 14 | 15 | Some environments like `openai/procgen `_ 16 | or `gym3 `_ directly initialize the 17 | vectorized environments, without giving us a chance to use the ``Monitor`` 18 | wrapper. So this class simply does the job of the ``Monitor`` wrapper on 19 | a vectorized level. 20 | 21 | :param venv: The vectorized environment 22 | :param filename: the location to save a log file, can be None for no log 23 | :param info_keywords: extra information to log, from the information return of env.step() 24 | """ 25 | 26 | def __init__( 27 | self, 28 | venv: VecEnv, 29 | filename: Optional[str] = None, 30 | info_keywords: Tuple[str, ...] = (), 31 | ): 32 | # Avoid circular import 33 | from stable_baselines3.common.monitor import Monitor, ResultsWriter 34 | 35 | # This check is not valid for special `VecEnv` 36 | # like the ones created by Procgen, that does follow completely 37 | # the `VecEnv` interface 38 | try: 39 | is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] 40 | except AttributeError: 41 | is_wrapped_with_monitor = False 42 | 43 | if is_wrapped_with_monitor: 44 | warnings.warn( 45 | "The environment is already wrapped with a `Monitor` wrapper" 46 | "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be" 47 | "overwritten by the `VecMonitor` ones.", 48 | UserWarning, 49 | ) 50 | 51 | VecEnvWrapper.__init__(self, venv) 52 | self.episode_returns = None 53 | self.episode_lengths = None 54 | self.episode_count = 0 55 | self.t_start = time.time() 56 | 57 | env_id = None 58 | if hasattr(venv, "spec") and venv.spec is not None: 59 | env_id = venv.spec.id 60 | 61 | if filename: 62 | self.results_writer = ResultsWriter( 63 | filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords 64 | ) 65 | else: 66 | self.results_writer = None 67 | self.info_keywords = info_keywords 68 | 69 | def reset(self) -> VecEnvObs: 70 | obs = self.venv.reset() 71 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 72 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 73 | return obs 74 | 75 | def step_wait(self) -> VecEnvStepReturn: 76 | obs, rewards, dones, infos = self.venv.step_wait() 77 | self.episode_returns += rewards 78 | self.episode_lengths += 1 79 | new_infos = list(infos[:]) 80 | for i in range(len(dones)): 81 | if dones[i]: 82 | info = infos[i].copy() 83 | episode_return = self.episode_returns[i] 84 | episode_length = self.episode_lengths[i] 85 | episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} 86 | info["episode"] = episode_info 87 | self.episode_count += 1 88 | self.episode_returns[i] = 0 89 | self.episode_lengths[i] = 0 90 | if self.results_writer: 91 | self.results_writer.write_row(episode_info) 92 | new_infos[i] = info 93 | return obs, rewards, dones, new_infos 94 | 95 | def close(self) -> None: 96 | if self.results_writer: 97 | self.results_writer.close() 98 | return self.venv.close() 99 | 
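As in the previous sketch, the example below assumes the upstream API. ``VecMonitor`` records per-episode return, length and elapsed time and injects them into ``info["episode"]`` when an episode terminates, which is where the statistics read during training come from.

```python
# Hedged sketch with the upstream stable_baselines3 API: read the per-episode
# statistics that VecMonitor injects into info["episode"] on termination.
import gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
monitored = VecMonitor(venv)  # pass filename="..." to also write a CSV log

obs = monitored.reset()
for _ in range(500):
    actions = np.array([monitored.action_space.sample() for _ in range(monitored.num_envs)])
    obs, rewards, dones, infos = monitored.step(actions)
    if dones[0]:
        ep = infos[0]["episode"]  # {"r": return, "l": length, "t": elapsed seconds}
        print(f"episode return={ep['r']:.1f}, length={ep['l']}")
```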
-------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/vec_transpose.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | from gym import spaces 6 | 7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first 8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 9 | 10 | 11 | class VecTransposeImage(VecEnvWrapper): 12 | """ 13 | Re-order channels, from HxWxC to CxHxW. 14 | It is required for PyTorch convolution layers. 15 | 16 | :param venv: 17 | """ 18 | 19 | def __init__(self, venv: VecEnv): 20 | assert is_image_space(venv.observation_space) or isinstance( 21 | venv.observation_space, spaces.dict.Dict 22 | ), "The observation space must be an image or dictionary observation space" 23 | 24 | if isinstance(venv.observation_space, spaces.dict.Dict): 25 | self.image_space_keys = [] 26 | observation_space = deepcopy(venv.observation_space) 27 | for key, space in observation_space.spaces.items(): 28 | if is_image_space(space): 29 | # Keep track of which keys should be transposed later 30 | self.image_space_keys.append(key) 31 | observation_space.spaces[key] = self.transpose_space(space, key) 32 | else: 33 | observation_space = self.transpose_space(venv.observation_space) 34 | super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) 35 | 36 | @staticmethod 37 | def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: 38 | """ 39 | Transpose an observation space (re-order channels). 40 | 41 | :param observation_space: 42 | :param key: In case of dictionary space, the key of the observation space. 43 | :return: 44 | """ 45 | # Sanity checks 46 | assert is_image_space(observation_space), "The observation space must be an image" 47 | assert not is_image_space_channels_first( 48 | observation_space 49 | ), f"The observation space {key} must follow the channel last convention" 50 | height, width, channels = observation_space.shape 51 | new_shape = (channels, height, width) 52 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) 53 | 54 | @staticmethod 55 | def transpose_image(image: np.ndarray) -> np.ndarray: 56 | """ 57 | Transpose an image or batch of images (re-order channels). 58 | 59 | :param image: 60 | :return: 61 | """ 62 | if len(image.shape) == 3: 63 | return np.transpose(image, (2, 0, 1)) 64 | return np.transpose(image, (0, 3, 1, 2)) 65 | 66 | def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: 67 | """ 68 | Transpose (if needed) and return new observations. 
69 | 70 | :param observations: 71 | :return: Transposed observations 72 | """ 73 | if isinstance(observations, dict): 74 | # Avoid modifying the original object in place 75 | observations = deepcopy(observations) 76 | for k in self.image_space_keys: 77 | observations[k] = self.transpose_image(observations[k]) 78 | else: 79 | observations = self.transpose_image(observations) 80 | return observations 81 | 82 | def step_wait(self) -> VecEnvStepReturn: 83 | observations, rewards, dones, infos = self.venv.step_wait() 84 | 85 | # Transpose the terminal observations 86 | for idx, done in enumerate(dones): 87 | if not done: 88 | continue 89 | if "terminal_observation" in infos[idx]: 90 | infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) 91 | 92 | return self.transpose_observations(observations), rewards, dones, infos 93 | 94 | def reset(self) -> Union[np.ndarray, Dict]: 95 | """ 96 | Reset all environments 97 | """ 98 | return self.transpose_observations(self.venv.reset()) 99 | 100 | def close(self) -> None: 101 | self.venv.close() 102 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | 4 | from gym.wrappers.monitoring import video_recorder 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 9 | 10 | 11 | class VecVideoRecorder(VecEnvWrapper): 12 | """ 13 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 14 | It requires ffmpeg or avconv to be installed on the machine. 15 | 16 | :param venv: 17 | :param video_folder: Where to save videos 18 | :param record_video_trigger: Function that defines when to start recording. 19 | The function takes the current number of step, 20 | and returns whether we should start recording or not. 
21 | :param video_length: Length of recorded videos 22 | :param name_prefix: Prefix to the video name 23 | """ 24 | 25 | def __init__( 26 | self, 27 | venv: VecEnv, 28 | video_folder: str, 29 | record_video_trigger: Callable[[int], bool], 30 | video_length: int = 200, 31 | name_prefix: str = "rl-video", 32 | ): 33 | 34 | VecEnvWrapper.__init__(self, venv) 35 | 36 | self.env = venv 37 | # Temp variable to retrieve metadata 38 | temp_env = venv 39 | 40 | # Unwrap to retrieve metadata dict 41 | # that will be used by gym recorder 42 | while isinstance(temp_env, VecEnvWrapper): 43 | temp_env = temp_env.venv 44 | 45 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 46 | metadata = temp_env.get_attr("metadata")[0] 47 | else: 48 | metadata = temp_env.metadata 49 | 50 | self.env.metadata = metadata 51 | 52 | self.record_video_trigger = record_video_trigger 53 | self.video_recorder = None 54 | 55 | self.video_folder = os.path.abspath(video_folder) 56 | # Create output folder if needed 57 | os.makedirs(self.video_folder, exist_ok=True) 58 | 59 | self.name_prefix = name_prefix 60 | self.step_id = 0 61 | self.video_length = video_length 62 | 63 | self.recording = False 64 | self.recorded_frames = 0 65 | 66 | def reset(self) -> VecEnvObs: 67 | obs = self.venv.reset() 68 | self.start_video_recorder() 69 | return obs 70 | 71 | def start_video_recorder(self) -> None: 72 | self.close_video_recorder() 73 | 74 | video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" 75 | base_path = os.path.join(self.video_folder, video_name) 76 | self.video_recorder = video_recorder.VideoRecorder( 77 | env=self.env, base_path=base_path, metadata={"step_id": self.step_id} 78 | ) 79 | 80 | self.video_recorder.capture_frame() 81 | self.recorded_frames = 1 82 | self.recording = True 83 | 84 | def _video_enabled(self) -> bool: 85 | return self.record_video_trigger(self.step_id) 86 | 87 | def step_wait(self) -> VecEnvStepReturn: 88 | obs, rews, dones, infos = self.venv.step_wait() 89 | 90 | self.step_id += 1 91 | if self.recording: 92 | self.video_recorder.capture_frame() 93 | self.recorded_frames += 1 94 | if self.recorded_frames > self.video_length: 95 | print(f"Saving video to {self.video_recorder.path}") 96 | self.close_video_recorder() 97 | elif self._video_enabled(): 98 | self.start_video_recorder() 99 | 100 | return obs, rews, dones, infos 101 | 102 | def close_video_recorder(self) -> None: 103 | if self.recording: 104 | self.video_recorder.close() 105 | self.recording = False 106 | self.recorded_frames = 1 107 | 108 | def close(self) -> None: 109 | VecEnvWrapper.close(self) 110 | self.close_video_recorder() 111 | 112 | def __del__(self): 113 | self.close() 114 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ddpg.ddpg import DDPG 2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/ddpg/ddpg.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Type, Union 2 | 3 | import torch as th 4 | 5 | from stable_baselines3.common.buffers import ReplayBuffer 6 | from stable_baselines3.common.noise import ActionNoise 7 | from 
stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm 8 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule 9 | from stable_baselines3.td3.policies import TD3Policy 10 | from stable_baselines3.td3.td3 import TD3 11 | 12 | 13 | class DDPG(TD3): 14 | """ 15 | Deep Deterministic Policy Gradient (DDPG). 16 | 17 | Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf 18 | DDPG Paper: https://arxiv.org/abs/1509.02971 19 | Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html 20 | 21 | Note: we treat DDPG as a special case of its successor TD3. 22 | 23 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 24 | :param env: The environment to learn from (if registered in Gym, can be str) 25 | :param learning_rate: learning rate for adam optimizer, 26 | the same learning rate will be used for all networks (Q-Values, Actor and Value function) 27 | it can be a function of the current progress remaining (from 1 to 0) 28 | :param buffer_size: size of the replay buffer 29 | :param learning_starts: how many steps of the model to collect transitions for before learning starts 30 | :param batch_size: Minibatch size for each gradient update 31 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1) 32 | :param gamma: the discount factor 33 | :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit 34 | like ``(5, "step")`` or ``(2, "episode")``. 35 | :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) 36 | Set to ``-1`` means to do as many gradient steps as steps done in the environment 37 | during the rollout. 38 | :param action_noise: the action noise type (None by default), this can help 39 | for hard exploration problem. Cf common.noise for the different action noise type. 40 | :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). 41 | If ``None``, it will be automatically selected. 42 | :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. 43 | :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer 44 | at a cost of more complexity. 45 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 46 | :param create_eval_env: Whether to create a second environment that will be 47 | used for evaluating the agent periodically. (Only available when passing string for the environment) 48 | :param policy_kwargs: additional arguments to be passed to the policy on creation 49 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug 50 | :param seed: Seed for the pseudo random generators 51 | :param device: Device (cpu, cuda, ...) on which the code should be run. 52 | Setting it to auto, the code will be run on the GPU if possible. 
53 | :param _init_setup_model: Whether or not to build the network at the creation of the instance 54 | """ 55 | 56 | def __init__( 57 | self, 58 | policy: Union[str, Type[TD3Policy]], 59 | env: Union[GymEnv, str], 60 | learning_rate: Union[float, Schedule] = 1e-3, 61 | buffer_size: int = 1_000_000, # 1e6 62 | learning_starts: int = 100, 63 | batch_size: int = 100, 64 | tau: float = 0.005, 65 | gamma: float = 0.99, 66 | train_freq: Union[int, Tuple[int, str]] = (1, "episode"), 67 | gradient_steps: int = -1, 68 | action_noise: Optional[ActionNoise] = None, 69 | replay_buffer_class: Optional[ReplayBuffer] = None, 70 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 71 | optimize_memory_usage: bool = False, 72 | tensorboard_log: Optional[str] = None, 73 | create_eval_env: bool = False, 74 | policy_kwargs: Optional[Dict[str, Any]] = None, 75 | verbose: int = 0, 76 | seed: Optional[int] = None, 77 | device: Union[th.device, str] = "auto", 78 | _init_setup_model: bool = True, 79 | ): 80 | 81 | super(DDPG, self).__init__( 82 | policy=policy, 83 | env=env, 84 | learning_rate=learning_rate, 85 | buffer_size=buffer_size, 86 | learning_starts=learning_starts, 87 | batch_size=batch_size, 88 | tau=tau, 89 | gamma=gamma, 90 | train_freq=train_freq, 91 | gradient_steps=gradient_steps, 92 | action_noise=action_noise, 93 | replay_buffer_class=replay_buffer_class, 94 | replay_buffer_kwargs=replay_buffer_kwargs, 95 | policy_kwargs=policy_kwargs, 96 | tensorboard_log=tensorboard_log, 97 | verbose=verbose, 98 | device=device, 99 | create_eval_env=create_eval_env, 100 | seed=seed, 101 | optimize_memory_usage=optimize_memory_usage, 102 | # Remove all tricks from TD3 to obtain DDPG: 103 | # we still need to specify target_policy_noise > 0 to avoid errors 104 | policy_delay=1, 105 | target_noise_clip=0.0, 106 | target_policy_noise=0.1, 107 | _init_setup_model=False, 108 | ) 109 | 110 | # Use only one critic 111 | if "n_critics" not in self.policy_kwargs: 112 | self.policy_kwargs["n_critics"] = 1 113 | 114 | if _init_setup_model: 115 | self._setup_model() 116 | 117 | def learn( 118 | self, 119 | total_timesteps: int, 120 | callback: MaybeCallback = None, 121 | log_interval: int = 4, 122 | eval_env: Optional[GymEnv] = None, 123 | eval_freq: int = -1, 124 | n_eval_episodes: int = 5, 125 | tb_log_name: str = "DDPG", 126 | eval_log_path: Optional[str] = None, 127 | reset_num_timesteps: bool = True, 128 | ) -> OffPolicyAlgorithm: 129 | 130 | return super(DDPG, self).learn( 131 | total_timesteps=total_timesteps, 132 | callback=callback, 133 | log_interval=log_interval, 134 | eval_env=eval_env, 135 | eval_freq=eval_freq, 136 | n_eval_episodes=n_eval_episodes, 137 | tb_log_name=tb_log_name, 138 | eval_log_path=eval_log_path, 139 | reset_num_timesteps=reset_num_timesteps, 140 | ) 141 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/ddpg/policies.py: -------------------------------------------------------------------------------- 1 | # DDPG can be view as a special case of TD3 2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.dqn.dqn import DQN 2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | 
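DDPG above is implemented as TD3 with the stabilisation tricks disabled (a single critic, ``policy_delay=1``, ``target_noise_clip=0.0``). It is not one of the planners trained in this repository (the configs cover A2C, DQN, PPO and QRDQN), so the sketch below is purely illustrative: it assumes the upstream stable_baselines3 package and a standard continuous-control Gym task, not the CARLA route-planning environment.

```python
# Illustrative sketch with the upstream stable_baselines3 package; this is not
# how the repository's planners are trained (those use DQN/QRDQN/A2C/PPO), and
# the vendored DummyVecEnv expects CARLA-specific attributes on the env.
import gym
import numpy as np

from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make("Pendulum-v1")  # env id depends on the installed gym version
n_actions = env.action_space.shape[0]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10_000)
```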
-------------------------------------------------------------------------------- /thirdparty/stable_baselines3/her/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy 2 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/her/goal_selection_strategy.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GoalSelectionStrategy(Enum): 5 | """ 6 | The strategies for selecting new goals when 7 | creating artificial transitions. 8 | """ 9 | 10 | # Select a goal that was achieved 11 | # after the current step, in the same episode 12 | FUTURE = 0 13 | # Select the goal that was achieved 14 | # at the end of the episode 15 | FINAL = 1 16 | # Select a goal that was achieved in the episode 17 | EPISODE = 2 18 | 19 | 20 | # For convenience 21 | # that way, we can use string to select a strategy 22 | KEY_TO_GOAL_STRATEGY = { 23 | "future": GoalSelectionStrategy.FUTURE, 24 | "final": GoalSelectionStrategy.FINAL, 25 | "episode": GoalSelectionStrategy.EPISODE, 26 | } 27 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/local_modifications.txt: -------------------------------------------------------------------------------- 1 | modified: stable_baselines3/common/evaluation.py 2 | stable_baselines3/common/vec_env/dummy_vec_env.py 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from stable_baselines3.ppo.ppo import PPO 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/ppo/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for PPO 3 | from stable_baselines3.common.policies import ( 4 | ActorCriticCnnPolicy, 5 | ActorCriticPolicy, 6 | MultiInputActorCriticPolicy, 7 | register_policy, 8 | ) 9 | 10 | MlpPolicy = ActorCriticPolicy 11 | CnnPolicy = ActorCriticCnnPolicy 12 | MultiInputPolicy = MultiInputActorCriticPolicy 13 | 14 | register_policy("MlpPolicy", ActorCriticPolicy) 15 | register_policy("CnnPolicy", ActorCriticCnnPolicy) 16 | register_policy("MultiInputPolicy", MultiInputPolicy) 17 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/stable_baselines3/py.typed -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from stable_baselines3.sac.sac import SAC 3 | -------------------------------------------------------------------------------- 
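The ``KEY_TO_GOAL_STRATEGY`` mapping in ``goal_selection_strategy.py`` above is what lets HER-related configuration accept plain strings such as ``"future"`` in place of enum members. A minimal sketch of resolving a strategy by hand:

```python
# Sketch grounded in goal_selection_strategy.py above: strategies can be given
# either as enum members or as the lowercase strings in KEY_TO_GOAL_STRATEGY.
from stable_baselines3.her.goal_selection_strategy import (
    KEY_TO_GOAL_STRATEGY,
    GoalSelectionStrategy,
)

strategy = KEY_TO_GOAL_STRATEGY["future"]
assert strategy is GoalSelectionStrategy.FUTURE
print(strategy)  # GoalSelectionStrategy.FUTURE
```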
/thirdparty/stable_baselines3/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from stable_baselines3.td3.td3 import TD3 3 | -------------------------------------------------------------------------------- /thirdparty/stable_baselines3/version.txt: -------------------------------------------------------------------------------- 1 | 1.3.1a1 2 | --------------------------------------------------------------------------------
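``local_modifications.txt`` lists the files that diverge from upstream, and ``version.txt`` pins the vendored snapshot at ``1.3.1a1``. One way to make Python import these vendored copies instead of a pip-installed release is sketched below; the path manipulation is an assumption for illustration, not necessarily what ``run_stable_baselines3.py`` does.

```python
# Assumed approach (not necessarily what run_stable_baselines3.py does): prepend
# the repository's thirdparty/ directory to sys.path so that the vendored,
# locally modified stable_baselines3 (pinned at 1.3.1a1) and sb3_contrib are
# imported instead of any pip-installed release.
import os
import sys

REPO_ROOT = os.path.dirname(os.path.abspath(__file__))  # assumes this file sits at the repo root
sys.path.insert(0, os.path.join(REPO_ROOT, "thirdparty"))

import stable_baselines3
from sb3_contrib import QRDQN  # the distributional agent used by the robust-RL configs

print(stable_baselines3.__version__)  # expected: 1.3.1a1
print(QRDQN)
```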