├── LICENSE
├── README.md
├── assets
│   └── trained_icct.gif
├── icct
│   ├── core
│   │   ├── icct.py
│   │   └── icct_helpers.py
│   ├── dagger
│   │   ├── dagger.py
│   │   ├── dt_policy.py
│   │   ├── oracle_models
│   │   │   ├── fig8_mlp_max.zip
│   │   │   ├── ip_mlp_max.zip
│   │   │   ├── lk_mlp_max.zip
│   │   │   ├── ll_mlp_max.zip
│   │   │   ├── mlr_mlp_max.zip
│   │   │   └── slr_mlp_max.zip
│   │   ├── train_dagger.py
│   │   ├── train_dagger.sh
│   │   └── train_q_dagger.sh
│   ├── plot
│   │   ├── learning_curve_plot.py
│   │   └── learning_curve_plotter.py
│   ├── rl_helpers
│   │   ├── ddt_sac_policy.py
│   │   ├── ddt_td3_policy.py
│   │   ├── sac.py
│   │   ├── sac_policies.py
│   │   ├── save_after_ep_callback.py
│   │   ├── td3.py
│   │   └── td3_policies.py
│   ├── runfiles
│   │   ├── test.py
│   │   ├── test.sh
│   │   ├── train.py
│   │   ├── train_fig8.sh
│   │   ├── train_ip.sh
│   │   ├── train_lk.sh
│   │   ├── train_ll.sh
│   │   ├── train_ring_accel.sh
│   │   └── train_ring_accel_lc.sh
│   └── sumo_envs
│       ├── accel_figure8.py
│       ├── accel_ring.py
│       └── accel_ring_multilane.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 CORE-Robotics-Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ICCT Implementation [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
2 | [**Talk (RSS 2022)**](https://www.youtube.com/watch?v=V17vQSZP5Zs&t=6798s) | [**Robot Demo**](https://sites.google.com/view/icctree) | [**Paper (RSS 2022)**](https://www.roboticsproceedings.org/rss18/p068.pdf)
3 |
4 | This is the codebase for "[Learning Interpretable, High-Performing Policies for Autonomous Driving](http://www.roboticsproceedings.org/rss18/p068.pdf)", which is published in [Robotics: Science and Systems (RSS), 2022](http://www.roboticsproceedings.org/rss18/index.html).
5 |
6 | Authors: [Rohan Paleja*](https://rohanpaleja.com/), [Yaru Niu*](https://www.yaruniu.com/), [Andrew Silva](https://www.andrew-silva.com/), Chace Ritchie, Sugju Choi, [Matthew Gombolay](https://core-robotics.gatech.edu/people/matthew-gombolay/)
7 |
8 | \* indicates co-first authors.
9 |
10 |
11 | 
12 | Trained High-Performance ICCT Policies in Six Tested Domains.
13 |
14 |
15 | ## Dependencies
16 | * [PyTorch](https://pytorch.org/) 1.5.0 (GPU)
17 | * [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) 1.1.0a11 (verified but not required)
18 | * [OpenAI Gym](https://github.com/openai/gym)
19 | * [box2d-py](https://github.com/openai/box2d-py)
20 | * [MuJoCo](https://mujoco.org/), [mujoco-py](https://github.com/openai/mujoco-py)
21 | * [highway-env](https://github.com/eleurent/highway-env)
22 | * [Flow](https://github.com/flow-project/flow) and [SUMO](https://github.com/eclipse/sumo) ([Installing Flow and SUMO](https://flow.readthedocs.io/en/latest/flow_setup.html#installing-flow-and-sumo))
23 | * [scikit-learn](https://scikit-learn.org/stable/install.html)
24 |
25 | ## Installation of ICCT
26 | ```
27 | git clone https://github.com/CORE-Robotics-Lab/ICCT.git
28 | cd ICCT
29 | pip install -e .
30 | ```
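
A quick sanity check that the editable install is visible to your Python environment (a minimal sketch, assuming the dependencies listed above are already installed):
```
# Minimal sanity check: importing the core ICCT module confirms the install.
from icct.core.icct import ICCT
print('ICCT import OK')
```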
31 |
32 | ## Training
33 | In this codebase, we provide all the methods presented in the paper, including CDDT (M1), CDDT-controllers (M2), ICCT-static (M3), ICCT-complete (M4), ICCT-L1-sparse (M5-a), ICCT-n-feature (M5-b), MLP-Max (large), MLP-Upper (medium), and MLP-Lower (small). Run `python icct/runfiles/train.py --help` to check all the training options. Examples of training ICCT-2-feature can be found in `icct/runfiles/`. All the methods in our paper are trained with [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290); we also provide an implementation for [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477). The method-specific arguments are listed below, followed by a constructor-level sketch.
34 | * Arguments for all methods
35 | * `--env_name`: environment to run on
36 | * `--alg_type`: use SAC or TD3
37 | * `--policy_type`: use DDT or MLP as the policy network
38 | * `--seed`: set the seed number
39 | * `--gpu`: add to use GPU
40 | * `--lr`: the learning rate
41 | * `--buffer_size`: the buffer size
42 | * `--batch_size`: the batch size
43 | * `--gamma`: the discount factor
44 | * `--tau`: the soft update coefficient (between 0 and 1) in SAC
45 | * `--learning_starts`: how many steps of the model to collect transitions for before learning starts
46 | * `--training_steps`: total steps for training the model
47 | * `--min_reward`: the minimum reward to save the model
48 | * `--save_path`: the path to save the models and logged files
49 | * `--n_eval_episodes`: the number of episodes for each evaluation during training
50 | * `--eval_freq`: evaluation frequency (the model is evaluated every fixed number of steps) during training
51 | * `--log_interval`: the number of episodes before logging
52 | * Arguments for MLP:
53 | * `--mlp_size`: choose the size of MLP to use (large: MLP-Max; medium: MLP-Upper; small: MLP-Lower)
54 | * Arguments for DDT (including ICCT):
55 | * `--num_leaves`: the number of leaves used in the DDT (2^n)
56 | * `--ddt_lr`: a specific learning rate for the DDT (the policy network); the learning rate for the critic network is set by `--lr`
57 | * `--use_individual_alpha`: whether to use different alphas for different nodes (sometimes this helps boost performance)
58 | * To activate CDDT (M1), only set `--policy_type` to `ddt`, and do not use `--submodels` or `--hard_node`
59 | * To activate CDDT-controllers (M2), use `--submodels` and set `--sparse_submodel_type` to 0
60 | * To activate ICCT-static (M3), use `--hard_node`
61 | * To activate ICCT-complete (M4), use `--hard_node`, `--submodels`, and set `--sparse_submodel_type` to 0
62 | * To activate ICCT-L1-sparse (M5-a), use `--hard_node`, `--submodels`, set `--sparse_submodel_type` to 1, and use the following arguments:
63 | * `--l1_reg_coeff`: the coefficient of the L1 regularization
64 | * `--l1_reg_bias`: whether to include biases in the L1 loss (not recommended)
65 | * `--l1_hard_attn`: whether to sample only one leaf node's linear controller for L1 regularization at each update; this can help enforce sparsity on each linear controller
66 | * We choose L1 regularization over L2 because L1 is more likely to push coefficients to zero
67 | * To activate ICCT-n-feature (M5-b, where "n" is the number of features selected by each leaf's linear sub-controller), use `--hard_node`, `--submodels`, set `--sparse_submodel_type` to 2, and use the following arguments:
68 | * `--num_sub_features`: the number of chosen features for submodels
69 | * `--argmax_tau`: the temperature of the diff_argmax function
70 | * `--use_gumbel_softmax`: include this flag to replace the argmax operation in the paper with Gumbel-Softmax
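
To make the flag-to-model mapping concrete, the sketch below shows how an ICCT-complete (M4) policy network corresponds to the `ICCT` constructor in `icct/core/icct.py`. This is only an illustration: the dimensions are placeholder assumptions, and in practice the tree is built for you by the DDT policy classes in `icct/rl_helpers/` when you run `train.py`.
```
# A minimal sketch (not a training script): how the flags above map onto the
# ICCT constructor in icct/core/icct.py. Dimensions are illustrative assumptions.
from icct.core.icct import ICCT

icct_complete = ICCT(
    input_dim=8,                # observation dimensionality (assumed)
    output_dim=2,               # action dimensionality (assumed)
    weights=None,               # None -> random node initialization
    comparators=None,
    alpha=None,
    leaves=16,                  # --num_leaves
    use_individual_alpha=True,  # --use_individual_alpha
    hard_node=True,             # --hard_node (differentiable crispification)
    use_submodels=True,         # --submodels (linear leaf sub-controllers)
    sparse_submodel_type=0,     # 0 = dense sub-controllers, i.e. ICCT-complete (M4)
    device='cpu',
    alg_type='sac',             # forward() returns (mus, stds) when alg_type='sac'
)
```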
71 |
72 | ## Loading and Testing
73 | All the MLP and DDT-based methods are evaluated in real time throughout the training process. Here we provide modules to load and test trained models. Please set up arguments and run `sh test.sh` in `icct/runfiles/`. For each DDT-based method, two types of performance can be output:
74 | * Fuzzy performance: the performance is evaluated by directly loading the trained model
75 | * Crisp performance: the performance is evaluated with a discretized (crisp) version of the trained model. The discretization process is proposed in https://arxiv.org/pdf/1903.09338.pdf
76 |
77 | For any ICCT method, the fuzzy and crisp performance are identical, while the crisp performance of CDDT (CDDT-Crisp) or CDDT-controllers (CDDT-controllers-Crisp) will differ and usually drops drastically.
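
A minimal sketch of how the crisp model is obtained, using the helper in `icct/core/icct_helpers.py`. Here `model` is assumed to be a trained SAC agent whose actor wraps an ICCT (the attribute path `model.actor.ddt` follows `DDTActor` in `icct/rl_helpers/ddt_sac_policy.py`); see `icct/runfiles/test.py` for the full evaluation pipeline.
```
# Sketch only: convert a trained (fuzzy) ICCT into its crisp counterpart.
from icct.core.icct_helpers import convert_to_crisp

fuzzy_tree = model.actor.ddt                     # trained ICCT inside the SAC actor (assumed path)
crisp_tree = convert_to_crisp(fuzzy_tree, None)  # the helper ignores its training_data argument
model.actor.ddt = crisp_tree                     # evaluate the agent with the crisp tree
```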
78 |
79 | ## Visualization of Learning Curves
80 | The training process can be monitored with TensorBoard: run `tensorboard --logdir TARGET_PATH`, where `TARGET_PATH` is the path to your saved log files. We also provide visualization of the mean rollout rewards and mean evaluation rewards throughout training across multiple runs (seeds). The csv files for these two kinds of rewards are saved in the same folder as the trained models. Copy the csv files from different runs (seeds) and different methods of the same tested domain into one folder, then run `learning_curve_plot.py` in `icct/plot/` with the following arguments (a usage sketch follows the list):
81 | * `--log_dir`: the path to the data
82 | * `--eval_freq`: evaluation frequency (must be the same as the one used in training)
83 | * `--n_eval_episodes`: the number of episodes for each evaluation (must be the same as the one used in training)
84 | * `--eval_smooth_window_size`: the sliding window size to smooth the evaluation rewards
85 | * `--non_eval_sample_freq`: the sampling frequency of the rollout rewards for plotting
86 | * `--non_eval_smooth_window_size`: the sliding window size to smooth the sampled rollout rewards
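
Equivalently, the plotter class behind `learning_curve_plot.py` can be driven directly from Python. A short sketch (run from within `icct/plot/`; the values are examples and should mirror your training settings):
```
# Sketch of programmatic use of the plotter (mirrors learning_curve_plot.py).
from learning_curve_plotter import Learning_Curve_Plotter

plotter = Learning_Curve_Plotter(
    log_dir='results',              # folder holding the copied monitor csv files
    eval_freq=1500,                 # must match --eval_freq used in training
    n_eval_episodes=5,              # must match --n_eval_episodes used in training
    eval_smooth_window_size=1,
    non_eval_sample_freq=2000,
    non_eval_smooth_window_size=1,
    env_name='lunar',               # only used to name the output .png files
    show_legend=True,
)
plotter.process_data()
plotter.plot()
```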
87 |
88 | ## Imitation Learning - DAgger
89 | We provide an implementation of imitation learning with decision trees using [Dataset Aggregation (DAgger)](http://proceedings.mlr.press/v15/ross11a/ross11a.pdf). Please set up arguments and run `sh train_dagger.sh` (or `sh train_q_dagger.sh`) in `icct/dagger/`. The oracle models, found in `icct/dagger/oracle_models/`, are the best MLP-Max models out of five seeds trained with SAC. We have improved the DAgger implementation since the paper submission and report the updated results, averaged over five seeds, in the table below; a condensed usage sketch follows the table.
90 | | Environment | Inverted Pendulum | Lunar Lander | Lane Keeping | Single-Lane Ring | Multi-Lane Ring | Figure-8 |
91 | | :----------: | :----------: | :----------: | :----------: | :----------: | :----------: | :----------: |
92 | | Number of Leaves | 32 | 32 | 16 | 16 | 32 | 16 |
93 | | Best Rollout Performance | $853.1\pm38.2$ | $245.7\pm8.4$ | $393.1\pm14.2$ | $121.9\pm0.03$ | $1260.4\pm4.6$ | $1116.4\pm8.3$ |
94 | | Evaluation Performance | $776.6\pm54.2$ | $184.7\pm17.3$ | $395.2\pm13.8$ | $121.5\pm0.01$ | $1249.4\pm3.4$ | $1113.8\pm9.5$ |
95 | | Oracle Performance | $1000.0$ | $301.2$ | $494.1$ | $122.29$ | $1194.5$ | $1126.3$ |
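
For reference, a condensed sketch of what `train_dagger.py` wires together (Lunar Lander shown; the oracle path assumes you run from `icct/dagger/`, and the `--q_dagger` flag corresponds to `is_reweight` below):
```
# Condensed sketch of the DAgger setup in icct/dagger/train_dagger.py.
import os
import gym
from icct.rl_helpers.sac import SAC
from icct.dagger.dt_policy import DTPolicy
from icct.dagger.dagger import DAgger

env = gym.make('LunarLanderContinuous-v2')
env.seed(5)
oracle = SAC.load('oracle_models/ll_mlp_max', device='cpu')  # MLP-Max oracle for Lunar Lander
dt = DTPolicy(env.action_space, max_depth=5)

dagger = DAgger(env, oracle, dt,
                n_rollouts=10, iterations=100, max_samples=1000000,
                is_reweight=False,    # True reproduces the --q_dagger variant
                n_q_samples=30)
os.makedirs('saved_dt_models', exist_ok=True)
dagger.train('saved_dt_models/best_dt.pkl')
dagger.evaluate(n_episodes=50)
```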
96 |
97 | ## Citation
98 | If you find our paper or repo helpful to your research, please consider citing the paper:
99 | ```
100 | @inproceedings{icct-rss-22,
101 | title={Learning Interpretable, High-Performing Policies for Autonomous Driving},
102 | author={Paleja, Rohan and Niu, Yaru and Silva, Andrew and Ritchie, Chace and Choi, Sugju and Gombolay, Matthew},
103 | booktitle={Robotics: Science and Systems (RSS)},
104 | year={2022}
105 | }
106 | ```
107 |
108 | ## Acknowledgments
109 | Some parts of this codebase are inspired by or based on several public repos:
110 | * [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)
111 | * [DDTs](https://github.com/CORE-Robotics-Lab/Interpretable_DDTS_AISTATS2020)
112 | * [VIPER](https://github.com/obastani/viper/)
113 |
--------------------------------------------------------------------------------
/assets/trained_icct.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/assets/trained_icct.gif
--------------------------------------------------------------------------------
/icct/core/icct.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 | # Reference file by Andrew Silva: https://github.com/CORE-Robotics-Lab/Interpretable_DDTS_AISTATS2020/blob/master/interpretable_ddts/agents/ddt.py
3 |
4 | import torch.nn as nn
5 | import torch
6 | import numpy as np
7 | import typing as t
8 | import torch.nn.functional as F
9 | import time
10 |
11 |
12 | class ICCT(nn.Module):
13 | def __init__(self,
14 | input_dim: int,
15 | weights: t.Union[t.List[np.array], np.array, None],
16 | comparators: t.Union[t.List[np.array], np.array, None],
17 | alpha: t.Union[t.List[np.array], np.array, None],
18 | leaves: t.Union[None, int, t.List],
19 | output_dim: t.Optional[int] = None,
20 | use_individual_alpha=False,
21 | device: str = 'cpu',
22 | use_submodels: bool = False,
23 | hard_node: bool = False,
24 | argmax_tau: float = 1.0,
25 | sparse_submodel_type = 0,
26 | fs_submodel_version = 0,
27 | l1_hard_attn = False,
28 | num_sub_features = 1,
29 | use_gumbel_softmax = False,
30 | alg_type = 'sac'):
31 | super(ICCT, self).__init__()
32 | """
33 | Initialize the Interpretable Continuous Control Tree (ICCT)
34 |
35 | :param input_dim: (observation/feature) input dimensionality
36 | :param weights: the weight vector for each node to initialize
37 | :param comparators: the comparator vector for each node to initialize
38 | :param alpha: the alpha to initialize
39 | :param leaves: the number of leaves of ICCT
40 | :param output_dim: (action) output dimensionality
41 | :param use_individual_alpha: whether to use different alphas for different nodes
42 | (sometimes it helps boost the performance)
43 | :param device: which device should ICCT run on [cpu|cuda]
44 | :param use_submodels: whether to use linear sub-controllers (submodels)
45 | :param hard_node: whether to use differentiable crispification (this arg does not
46 | influence the differentiable crispification procedure in the
47 | sparse linear controllers)
48 | :param argmax_tau: the temperature of the diff_argmax function
49 | :param sparse_submodel_type: the type of the sparse sub-controller, 1 for L1
50 | regularization, 2 for feature selection, other
51 | values (default: 0) for not sparse
52 | :param fs_submodel_version: the version of the feature-selection submodel to use
53 | :param l1_hard_attn: whether to sample only one linear controller to perform L1
54 | regularization for each update when using l1-reg submodels
55 | :param num_sub_features: the number of chosen features for sparse sub-controllers
56 | :param use_gumbel_softmax: whether to use Gumbel-Softmax instead of the differentiable
57 | argmax (diff_argmax) proposed in the paper
58 | :param alg_type: currently supported RL methods [SAC|TD3] (the results in the paper
59 | were obtained by SAC)
60 | """
61 | self.device = device
62 | self.leaf_init_information = leaves
63 | self.hard_node = hard_node
64 | self.argmax_tau = argmax_tau
65 |
66 | self.input_dim = input_dim
67 | self.output_dim = output_dim
68 | self.layers = None
69 | self.comparators = None
70 | self.use_submodels = use_submodels
71 | self.sparse_submodel_type = sparse_submodel_type
72 | self.fs_submodel_version = fs_submodel_version
73 | self.l1_hard_attn = l1_hard_attn
74 | self.num_sub_features = num_sub_features
75 | self.use_gumbel_softmax = use_gumbel_softmax
76 | self.use_individual_alpha = use_individual_alpha
77 | self.alg_type = alg_type
78 |
79 | self.init_comparators(comparators)
80 | self.init_weights(weights)
81 | self.init_alpha(alpha)
82 | self.init_paths()
83 | self.init_leaves()
84 | self.sig = nn.Sigmoid()
85 | self.num_leaves = self.layers.size(0) + 1
86 |
87 | if self.use_submodels:
88 | self.init_submodels()
89 |
90 | if self.alg_type == 'td3':
91 | self.tanh = nn.Tanh()
92 |
93 | def init_submodels(self):
94 | if self.sparse_submodel_type != 2:
95 | self.lin_models = nn.ModuleList([nn.Linear(self.input_dim, self.output_dim) for _ in range(self.num_leaves)])
96 | if self.sparse_submodel_type == 1:
97 | self.leaf_attn = None
98 | else:
99 | self.sub_scalars = nn.Parameter(torch.zeros(self.num_leaves, self.output_dim, self.input_dim).to(self.device), requires_grad=True)
100 | self.sub_weights = nn.Parameter(torch.zeros(self.num_leaves, self.output_dim, self.input_dim).to(self.device), requires_grad=True)
101 | self.sub_biases = nn.Parameter(torch.zeros(self.num_leaves, self.output_dim, self.input_dim).to(self.device), requires_grad=True)
102 |
103 | nn.init.xavier_normal_(self.sub_scalars.data)
104 | nn.init.xavier_normal_(self.sub_weights.data)
105 | nn.init.xavier_normal_(self.sub_biases.data)
106 |
107 | def init_comparators(self, comparators):
108 | if comparators is None:
109 | comparators = []
110 | if type(self.leaf_init_information) is int:
111 | depth = int(np.floor(np.log2(self.leaf_init_information)))
112 | else:
113 | depth = 4
114 | for level in range(depth):
115 | for node in range(2**level):
116 | comparators.append(np.random.normal(0, 1.0, 1))
117 | new_comps = torch.tensor(comparators, dtype=torch.float).to(self.device)
118 | new_comps.requires_grad = True
119 | self.comparators = nn.Parameter(new_comps, requires_grad=True)
120 |
121 | def init_weights(self, weights):
122 | if weights is None:
123 | weights = []
124 | if type(self.leaf_init_information) is int:
125 | depth = int(np.floor(np.log2(self.leaf_init_information)))
126 | else:
127 | depth = 4
128 | for level in range(depth):
129 | for node in range(2**level):
130 | weights.append(np.random.rand(self.input_dim))
131 |
132 | new_weights = torch.tensor(weights, dtype=torch.float).to(self.device)
133 | new_weights.requires_grad = True
134 | self.layers = nn.Parameter(new_weights, requires_grad=True)
135 |
136 | def init_alpha(self, alpha):
137 | if alpha is None:
138 | if self.use_individual_alpha:
139 | alphas = []
140 | if type(self.leaf_init_information) is int:
141 | depth = int(np.floor(np.log2(self.leaf_init_information)))
142 | else:
143 | depth = 4
144 | for level in range(depth):
145 | for node in range(2**level):
146 | alphas.append([1.0])
147 | else:
148 | alphas = [1.0]
149 | else:
150 | alphas = alpha
151 | self.alpha = torch.tensor(alphas, dtype=torch.float).to(self.device)
152 | self.alpha.requires_grad = True
153 | self.alpha = nn.Parameter(self.alpha, requires_grad=True)
154 |
155 | def init_paths(self):
156 | if type(self.leaf_init_information) is list:
157 | left_branches = torch.zeros((len(self.layers), len(self.leaf_init_information)), dtype=torch.float)
158 | right_branches = torch.zeros((len(self.layers), len(self.leaf_init_information)), dtype=torch.float)
159 | for n in range(0, len(self.leaf_init_information)):
160 | for i in self.leaf_init_information[n][0]:
161 | left_branches[i][n] = 1.0
162 | for j in self.leaf_init_information[n][1]:
163 | right_branches[j][n] = 1.0
164 | else:
165 | if type(self.leaf_init_information) is int:
166 | depth = int(np.floor(np.log2(self.leaf_init_information)))
167 | else:
168 | depth = 4
169 | left_branches = torch.zeros((2 ** depth - 1, 2 ** depth), dtype=torch.float)
170 | for n in range(0, depth):
171 | row = 2 ** n - 1
172 | for i in range(0, 2 ** depth):
173 | col = 2 ** (depth - n) * i
174 | end_col = col + 2 ** (depth - 1 - n)
175 | if row + i >= len(left_branches) or end_col >= len(left_branches[row]):
176 | break
177 | left_branches[row + i, col:end_col] = 1.0
178 | right_branches = torch.zeros((2 ** depth - 1, 2 ** depth), dtype=torch.float)
179 | left_turns = np.where(left_branches == 1)
180 | for row in np.unique(left_turns[0]):
181 | cols = left_turns[1][left_turns[0] == row]
182 | start_pos = cols[-1] + 1
183 | end_pos = start_pos + len(cols)
184 | right_branches[row, start_pos:end_pos] = 1.0
185 | left_branches.requires_grad = False
186 | right_branches.requires_grad = False
187 | self.left_path_sigs = nn.Parameter(left_branches.to(self.device), requires_grad=False)
188 | self.right_path_sigs = nn.Parameter(right_branches.to(self.device), requires_grad=False)
189 |
190 | def init_leaves(self):
191 | if type(self.leaf_init_information) is list:
192 | new_leaves = [leaf[-1] for leaf in self.leaf_init_information]
193 | else:
194 | new_leaves = []
195 | if type(self.leaf_init_information) is int:
196 | depth = int(np.floor(np.log2(self.leaf_init_information)))
197 | else:
198 | depth = 4
199 |
200 | last_level = np.arange(2**(depth-1)-1, 2**depth-1)
201 | going_left = True
202 | leaf_index = 0
203 | self.leaf_init_information = []
204 | for level in range(2**depth):
205 | curr_node = last_level[leaf_index]
206 | turn_left = going_left
207 | left_path = []
208 | right_path = []
209 | while curr_node >= 0:
210 | if turn_left:
211 | left_path.append(int(curr_node))
212 | else:
213 | right_path.append(int(curr_node))
214 | prev_node = np.ceil(curr_node / 2) - 1
215 | if curr_node // 2 > prev_node:
216 | turn_left = False
217 | else:
218 | turn_left = True
219 | curr_node = prev_node
220 | if going_left:
221 | going_left = False
222 | else:
223 | going_left = True
224 | leaf_index += 1
225 | new_probs = np.random.uniform(0, 1, self.output_dim) # *(1.0/self.output_dim)
226 | self.leaf_init_information.append([sorted(left_path), sorted(right_path), new_probs])
227 | new_leaves.append(new_probs)
228 |
229 | labels = torch.tensor(new_leaves, dtype=torch.float).to(self.device)
230 | labels.requires_grad = True
231 | if not self.use_submodels:
232 | self.action_mus = nn.Parameter(labels, requires_grad=True)
233 | torch.nn.init.xavier_uniform_(self.action_mus)
234 |
235 | if self.alg_type == 'sac':
236 | self.action_stds = nn.Parameter(labels.detach().clone(), requires_grad=True)
237 | torch.nn.init.xavier_uniform_(self.action_stds)
238 |
239 | def diff_argmax(self, logits, dim=-1):
240 | tau = self.argmax_tau
241 | sample = self.use_gumbel_softmax
242 |
243 | if sample:
244 | gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
245 | logits = logits + gumbels
246 |
247 | y_soft = (logits/tau).softmax(-1)
248 | # straight through
249 | index = y_soft.max(dim, keepdim=True)[1]
250 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
251 | ret = y_hard - y_soft.detach() + y_soft
252 |
253 | return ret
254 |
255 |
256 | def fs_submodels(self, input):
257 | ## feature-selection sparse linear sub-controller
258 | if self.fs_submodel_version == 1:
259 | # version 1 was used for inverted pendulum in our paper's experiments
260 | # we found that version 1 has better performance than version 0 for inverted pendulum
261 | # version 1 and version 0 have the same forward pass procedure, differing only in numerical precision
262 |
263 | # sub_scalars, sub_weights, sub_biases: [num_leaves, output_dim, input_dim]
264 | w = self.sub_weights
265 | batch_size = input.size(0)
266 | num_leaves, output_dim, input_dim = w.size(0), w.size(1), w.size(2)
267 |
268 | new_w = 0
269 | for i in range(self.num_sub_features):
270 | if not i == 0:
271 | w = w - w * onehot_weights
272 | # onehot_weights: [num_leaves, output_dim, input_dim]
273 | onehot_weights = self.diff_argmax(torch.abs(w))
274 | new_w = onehot_weights + new_w
275 |
276 | # new_s: [num_leaves, output_dim, input_dim]
277 | new_s = new_w * self.sub_scalars
278 | # new_i: [batch_size, num_leaves, output_dim, input_dim]
279 | new_i = new_w.unsqueeze(0).expand(batch_size, num_leaves, output_dim, input_dim) * input.unsqueeze(1).unsqueeze(2).expand(batch_size, num_leaves, output_dim, input_dim)
280 | # new_b: [num_leaves, output_dim]
281 | new_b = (new_w * self.sub_biases).sum(-1)
282 | # ret: [batch_size, num_leaves, output_dim]
283 | ret = (new_s * new_i).sum(-1) + new_b
284 | return ret
285 | else:
286 | # version 0 was used for all domains except inverted pendulum in our paper's experiments
287 |
288 | ret = []
289 | w = self.sub_weights
290 |
291 | for i in range(self.num_sub_features):
292 | if not i == 0:
293 | w = w - w * onehot_weights
294 |
295 | # onehot_weights: [num_leaves, output_dim, input_dim]
296 | onehot_weights = self.diff_argmax(torch.abs(w))
297 |
298 | # new_w: [num_leaves, output_dim, input_dim]
299 | # new_s: [num_leaves, output_dim, 1]
300 | # new_b: [num_leaves, output_dim, 1]
301 | new_w = onehot_weights
302 | new_s = (self.sub_scalars * onehot_weights).sum(-1).unsqueeze(-1)
303 | new_b = (self.sub_biases * onehot_weights).sum(-1).unsqueeze(-1)
304 |
305 | # input: [batch_size, input_dim]
306 | # output: [num_leaves, output_dim, batch_size]
307 | output = new_s * torch.matmul(new_w, input.transpose(0, 1)) + new_b
308 | ret.append(output.permute(2, 0, 1))
309 | return torch.sum(torch.stack(ret, dim=0), dim=0)
310 |
311 |
312 | def forward(self, input_data, embedding_list=None):
313 | # self.comparators: [num_node, 1]
314 |
315 | if self.hard_node:
316 | ## node crispification
317 | weights = torch.abs(self.layers)
318 | # onehot_weights: [num_node, input_dim]
319 | onehot_weights = self.diff_argmax(weights)
320 | # divisors: [num_node, 1]
321 | divisors = (weights * onehot_weights).sum(-1).unsqueeze(-1)
322 | # fill 0 with 1
323 | divisors_filler = torch.zeros(divisors.size()).to(divisors.device)
324 | divisors_filler[divisors==0] = 1
325 | divisors = divisors + divisors_filler
326 | new_comps = self.comparators / divisors
327 | new_weights = self.layers * onehot_weights / divisors
328 | new_alpha = self.alpha
329 | else:
330 | new_comps = self.comparators
331 | new_weights = self.layers
332 | new_alpha = self.alpha
333 |
334 | # original input_data dim: [batch_size, input_dim]
335 | input_copy = input_data.clone()
336 | input_data = input_data.t().expand(new_weights.size(0), *input_data.t().size())
337 | # layers: [num_node, input_dim]
338 | # input_data dim: [batch_size, num_node, input_dim]
339 | input_data = input_data.permute(2, 0, 1)
340 | # after discretization, some weights can be -1 depending on their original values
341 | # comp dim: [batch_size, num_node, 1]
342 | comp = new_weights.mul(input_data)
343 | comp = comp.sum(dim=2).unsqueeze(-1)
344 | comp = comp.sub(new_comps.expand(input_data.size(0), *new_comps.size()))
345 | if self.use_individual_alpha:
346 | comp = comp.mul(new_alpha.expand(input_data.size(0), *new_alpha.size()))
347 | else:
348 | comp = comp.mul(new_alpha)
349 | if self.hard_node:
350 | ## outcome crispification
351 | # sig_vals: [batch_size, num_node, 2]
352 | sig_vals = self.diff_argmax(torch.cat((comp, torch.zeros((input_data.size(0), self.layers.size(0), 1)).to(comp.device)), dim=-1))
353 |
354 | sig_vals = torch.narrow(sig_vals, 2, 0, 1).squeeze(-1)
355 | else:
356 | sig_vals = self.sig(comp)
357 | # sig_vals: [batch_size, num_node]
358 | sig_vals = sig_vals.view(input_data.size(0), -1)
359 | # one_minus_sig: [batch_size, num_node]
360 | one_minus_sig = torch.ones(sig_vals.size()).to(sig_vals.device)
361 | one_minus_sig = torch.sub(one_minus_sig, sig_vals)
362 |
363 | # left_path_probs: [num_leaves, num_nodes]
364 | left_path_probs = self.left_path_sigs.t()
365 | right_path_probs = self.right_path_sigs.t()
366 | # left_path_probs: [batch_size, num_leaves, num_nodes]
367 | left_path_probs = left_path_probs.expand(input_data.size(0), *left_path_probs.size()) * sig_vals.unsqueeze(
368 | 1)
369 | right_path_probs = right_path_probs.expand(input_data.size(0),
370 | *right_path_probs.size()) * one_minus_sig.unsqueeze(1)
371 | # left_path_probs: [batch_size, num_nodes, num_leaves]
372 | left_path_probs = left_path_probs.permute(0, 2, 1)
373 | right_path_probs = right_path_probs.permute(0, 2, 1)
374 |
375 | # We don't want 0s to ruin leaf probabilities, so replace them with 1s so they don't affect the product
376 | left_filler = torch.zeros(self.left_path_sigs.size()).to(left_path_probs.device)
377 | left_filler[self.left_path_sigs == 0] = 1
378 | right_filler = torch.zeros(self.right_path_sigs.size()).to(left_path_probs.device)
379 | right_filler[self.right_path_sigs == 0] = 1
380 |
381 | # left_path_probs: [batch_size, num_nodes, num_leaves]
382 | left_path_probs = left_path_probs.add(left_filler)
383 | right_path_probs = right_path_probs.add(right_filler)
384 |
385 | # probs: [batch_size, 2*num_nodes, num_leaves]
386 | probs = torch.cat((left_path_probs, right_path_probs), dim=1)
387 | # probs: [batch_size, num_leaves]
388 | probs = probs.prod(dim=1)
389 |
390 | if self.use_submodels and self.sparse_submodel_type == 1:
391 | # here we choose L1 regularization over L2 because L1 is more likely to push coefficients to zeros
392 | self.leaf_attn = probs.clone().detach().sum(dim=0) / input_data.size(0)
393 | if self.l1_hard_attn:
394 | # if only sample one leaf node's linear controller to perform L1 regularization for each update
395 | # this can be helpful in enforcing sparsity on each linear controller
396 | distribution = torch.distributions.Categorical(self.leaf_attn)
397 | attn_idx = distribution.sample()
398 | self.leaf_attn = torch.zeros(probs.size(1))
399 | self.leaf_attn[attn_idx] = 1
400 |
401 | if self.use_submodels:
402 | if self.sparse_submodel_type != 2:
403 | output = torch.zeros((self.num_leaves, input_data.size(0), self.output_dim)).to(self.device)
404 | # input_copy [batch_size, input_dim]
405 | for e, i in enumerate(self.lin_models):
406 | output[e] = i(input_copy)
407 | else:
408 | output = self.fs_submodels(input_copy).transpose(0, 1)
409 | actions = torch.bmm(probs.reshape(-1, 1, self.num_leaves), output.transpose(0, 1))
410 | mus = actions.squeeze(1)
411 | else:
412 | # self.action_mus: [num_leaves, output_dim]
413 | # mus: [batch_size, output_dim]
414 | mus = probs.mm(self.action_mus)
415 |
416 | if self.alg_type == 'sac':
417 | stds = probs.mm(self.action_stds)
418 | stds = torch.clamp(stds, -20, 2).view(input_data.size(0), -1)
419 | return mus, stds
420 | else:
421 | # TD3 here outputs deterministic policies during training
422 | mus = self.tanh(mus)
423 | return mus
424 |
--------------------------------------------------------------------------------
/icct/core/icct_helpers.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | import numpy as np
4 | import sys
5 | from icct.core.icct import ICCT
6 | import torch
7 |
8 |
9 | def convert_to_crisp(fuzzy_model, training_data):
10 | new_weights = []
11 | new_comps = []
12 | device = fuzzy_model.device
13 |
14 | weights = np.abs(fuzzy_model.layers.cpu().detach().numpy())
15 | most_used = np.argmax(weights, axis=1)
16 | for comp_ind, comparator in enumerate(fuzzy_model.comparators):
17 | comparator = comparator.item()
18 | divisor = abs(fuzzy_model.layers[comp_ind][most_used[comp_ind]].item())
19 | if divisor == 0:
20 | divisor = 1
21 | comparator /= divisor
22 | new_comps.append([comparator])
23 | max_ind = most_used[comp_ind]
24 | new_weight = np.zeros(len(fuzzy_model.layers[comp_ind].data))
25 | new_weight[max_ind] = fuzzy_model.layers[comp_ind][most_used[comp_ind]].item() / divisor
26 | new_weights.append(new_weight)
27 |
28 | new_input_dim = fuzzy_model.input_dim
29 | new_weights = np.array(new_weights)
30 | new_comps = np.array(new_comps)
31 | new_alpha = fuzzy_model.alpha
32 | new_alpha = 9999999. * new_alpha.cpu().detach().numpy() / np.abs(new_alpha.cpu().detach().numpy())
33 | crispy_model = ICCT(input_dim=new_input_dim,
34 | output_dim=fuzzy_model.output_dim,
35 | weights=new_weights,
36 | comparators=new_comps,
37 | leaves=fuzzy_model.leaf_init_information,
38 | alpha=new_alpha,
39 | use_individual_alpha=fuzzy_model.use_individual_alpha,
40 | use_submodels=fuzzy_model.use_submodels,
41 | hard_node=False,
42 | sparse_submodel_type=fuzzy_model.sparse_submodel_type,
43 | l1_hard_attn=False,
44 | num_sub_features=fuzzy_model.num_sub_features,
45 | use_gumbel_softmax=fuzzy_model.use_gumbel_softmax,
46 | device=device).to(device)
47 | if hasattr(fuzzy_model, 'action_mus'):
48 | crispy_model.action_mus.data = fuzzy_model.action_mus.data
49 | crispy_model.action_stds.data = fuzzy_model.action_stds.data
50 | if fuzzy_model.use_submodels:
51 | if fuzzy_model.sparse_submodel_type != 2:
52 | crispy_model.lin_models = fuzzy_model.lin_models
53 | else:
54 | crispy_model.sub_scalars = fuzzy_model.sub_scalars
55 | crispy_model.sub_weights = fuzzy_model.sub_weights
56 | crispy_model.sub_biases = fuzzy_model.sub_biases
57 |
58 | return crispy_model
59 |
--------------------------------------------------------------------------------
/icct/dagger/dagger.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 | # Reference: https://github.com/obastani/viper/blob/master/python/viper/core/rl.py
3 |
4 | import numpy as np
5 | import pickle
6 | import torch
7 |
8 | class DAgger:
9 | def __init__(self,
10 | env,
11 | oracle_model,
12 | dt_model,
13 | n_rollouts,
14 | iterations,
15 | max_samples,
16 | is_reweight,
17 | n_q_samples):
18 | self.env = env
19 | self.oracle = oracle_model
20 | self.dt = dt_model
21 | self.best_dt = None
22 | self.n_rollouts = n_rollouts
23 | self.iterations = iterations
24 | self.max_samples = max_samples
25 | self.is_reweight = is_reweight
26 | self.n_q_samples = n_q_samples
27 | self.dataset_obs = []
28 | self.dataset_act = []
29 | self.dataset_prob = []
30 |
31 | def generate_uniform_combinations(self, n_samples=10, act_dim=2):
32 | # Step 1: Generate n values in [-1, 1]
33 | values = torch.linspace(-1, 1, n_samples)
34 | # Step 2: Create a meshgrid of all combinations (n^act_dim total combinations)
35 | grids = torch.meshgrid([values] * act_dim)
36 | # Step 3: Stack and reshape to shape (n^act_dim, act_dim)
37 | combinations = torch.stack(grids, dim=-1).reshape(-1, act_dim)
38 | return combinations
39 |
40 | def get_rollout(self, execute_dt=True):
41 | obs = self.env.reset()
42 | done = False
43 | rollout = []
44 | act_dim = self.env.action_space.shape[0]
45 | n_q_samples = self.n_q_samples
46 | while not done:
47 | # oracle_act: action that can be directly used by the environment
48 | oracle_act, _ = self.oracle.predict(obs, deterministic=True)
49 | # raw_oracle_act: raw mean action output by the policy network, used to train the DT
50 | processed_obs, raw_oracle_act = self.oracle.actor.get_sa_pair()
51 | sampled_obss = processed_obs.repeat(n_q_samples ** act_dim, 1)
52 | sampled_acts = self.generate_uniform_combinations(n_samples=n_q_samples, act_dim=act_dim)
53 | q1_values = self.oracle.critic.q1_forward(sampled_obss, sampled_acts)
54 | loss = torch.max(q1_values, dim=0).values - torch.min(q1_values, dim=0).values
55 |
56 | if execute_dt:
57 | act = self.dt.predict(processed_obs.cpu().numpy())
58 | else:
59 | act = oracle_act
60 |
61 | next_obs, rwd, done, info = self.env.step(act)
62 | rollout.append((processed_obs.cpu().numpy(), raw_oracle_act.cpu().numpy(), rwd, loss.detach().cpu().numpy()))
63 | obs = next_obs
64 |
65 | return rollout
66 |
67 | def get_rollouts(self, execute_dt=True):
68 | rollouts = []
69 | for n in range(self.n_rollouts):
70 | rollouts.extend(self.get_rollout(execute_dt))
71 | return rollouts
72 |
73 | def sample_batch_idx(self, probs, max_samples, is_reweight):
74 | probs = probs / np.sum(probs)
75 | if is_reweight:
76 | idx = np.random.choice(probs.shape[0], size=min(max_samples, np.sum(probs > 0)), p=probs)
77 | else:
78 | idx = np.random.choice(probs.shape[0], size=min(max_samples, np.sum(probs > 0)), replace=False)
79 | return idx
80 |
81 | def train(self, save_path):
82 | first_batch = self.get_rollouts(execute_dt=False)
83 | self.dataset_obs.extend((obs for obs, _, _, _ in first_batch))
84 | self.dataset_act.extend((act for _, act, _, _ in first_batch))
85 | self.dataset_prob.extend((prob for _, _, _, prob in first_batch))
86 | best_rwd = -9e5
87 |
88 | for i in range(self.iterations):
89 | dataset_obs = np.concatenate(self.dataset_obs, axis=0)
90 | dataset_act = np.concatenate(self.dataset_act, axis=0)
91 | dataset_prob = np.concatenate(self.dataset_prob, axis=0)
92 |
93 | idx = self.sample_batch_idx(probs=dataset_prob, max_samples=self.max_samples, is_reweight=self.is_reweight)
94 |
95 | self.dt.train(dataset_obs[idx], dataset_act[idx])
96 |
97 | later_batch = self.get_rollouts(execute_dt=True)
98 | self.dataset_obs.extend((obs for obs, _, _, _ in later_batch))
99 | self.dataset_act.extend((act for _, act, _, _ in later_batch))
100 | self.dataset_prob.extend((prob for _, _, _, prob in later_batch))
101 |
102 | average_rwd = np.sum([rwd for _, _, rwd, _ in later_batch]) / self.n_rollouts
103 | print('average reward: ', average_rwd)
104 | if average_rwd >= best_rwd:
105 | best_rwd = average_rwd
106 | print('save the best model!')
107 | self.best_dt = self.dt.clone()
108 | self.save_best_dt(save_path)
109 |
110 | def save_best_dt(self, save_path):
111 | pickle.dump(self.best_dt, open(save_path, 'wb'))
112 |
113 | def load_best_dt(self, load_path):
114 | self.best_dt = pickle.load(open(load_path, 'rb'))
115 |
116 | def evaluate(self, n_episodes):
117 | print('number of leaves', self.best_dt.tree.get_n_leaves())
118 | episode_reward_list = []
119 | for _ in range(n_episodes):
120 | obs = self.env.reset()
121 | done = False
122 | episode_reward = 0
123 | while not done:
124 | _, _ = self.oracle.predict(obs, deterministic=True)
125 | obs, _ = self.oracle.actor.get_sa_pair()
126 | action = self.best_dt.predict(obs.cpu().numpy())
127 | obs, reward, done, info = self.env.step(action)
128 | episode_reward += reward
129 | episode_reward_list.append(episode_reward)
130 | print(episode_reward_list)
131 | print(np.mean(episode_reward_list))
132 | print(np.std(episode_reward_list))
--------------------------------------------------------------------------------
/icct/dagger/dt_policy.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | import numpy as np
4 | from sklearn.tree import DecisionTreeRegressor
5 |
6 | class DTPolicy:
7 | def __init__(self, action_space, max_depth):
8 | self.action_space = action_space
9 | self.max_depth = max_depth
10 |
11 | def fit(self, obss, acts):
12 | self.tree = DecisionTreeRegressor(max_depth=self.max_depth)
13 | self.tree.fit(obss, acts)
14 |
15 | def train(self, obss, acts):
16 | self.fit(obss, acts)
17 |
18 | def _predict(self, obs):
19 | return self.tree.predict(obs)
20 |
21 | def predict(self, obs):
22 | raw_act = self._predict(obs)[0]
23 | squashed_act = np.tanh(raw_act)
24 | return self.unscale_action(squashed_act)
25 |
26 | def unscale_action(self, scaled_action):
27 | low, high = self.action_space.low, self.action_space.high
28 | return low + (0.5 * (scaled_action + 1.0) * (high - low))
29 |
30 | def clone(self):
31 | clone = DTPolicy(self.action_space, self.max_depth)
32 | clone.tree = self.tree
33 | return clone
34 |
35 |
36 |
--------------------------------------------------------------------------------
/icct/dagger/oracle_models/fig8_mlp_max.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/icct/dagger/oracle_models/fig8_mlp_max.zip
--------------------------------------------------------------------------------
/icct/dagger/oracle_models/ip_mlp_max.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/icct/dagger/oracle_models/ip_mlp_max.zip
--------------------------------------------------------------------------------
/icct/dagger/oracle_models/lk_mlp_max.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/icct/dagger/oracle_models/lk_mlp_max.zip
--------------------------------------------------------------------------------
/icct/dagger/oracle_models/ll_mlp_max.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/icct/dagger/oracle_models/ll_mlp_max.zip
--------------------------------------------------------------------------------
/icct/dagger/oracle_models/mlr_mlp_max.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/icct/dagger/oracle_models/mlr_mlp_max.zip
--------------------------------------------------------------------------------
/icct/dagger/oracle_models/slr_mlp_max.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CORE-Robotics-Lab/ICCT/a7887bfd824a86381599dc576b9e4c0aeac61092/icct/dagger/oracle_models/slr_mlp_max.zip
--------------------------------------------------------------------------------
/icct/dagger/train_dagger.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | import gym
4 | import os
5 | import numpy as np
6 | import argparse
7 | from icct.dagger.dt_policy import DTPolicy
8 | from icct.dagger.dagger import DAgger
9 |
10 | from icct.rl_helpers.sac import SAC
11 | import highway_env
12 | from flow.utils.registry import make_create_env
13 | from icct.sumo_envs.accel_ring import ring_accel_params
14 | from icct.sumo_envs.accel_ring_multilane import ring_accel_lc_params
15 | from icct.sumo_envs.accel_figure8 import fig8_params
16 | from stable_baselines3.common.utils import set_random_seed
17 |
18 | def make_env(env_name, seed):
19 | set_random_seed(seed)
20 | if env_name == 'lunar':
21 | env = gym.make('LunarLanderContinuous-v2')
22 | name = 'LunarLanderContinuous-v2'
23 | elif env_name == 'cart':
24 | env = gym.make('InvertedPendulum-v2')
25 | name = 'InvertedPendulum-v2'
26 | elif env_name == 'lane_keeping':
27 | env = gym.make('lane-keeping-v0')
28 | name = 'lane-keeping-v0'
29 | elif env_name == 'ring_accel':
30 | create_env, gym_name = make_create_env(params=ring_accel_params, version=0)
31 | env = create_env()
32 | name = gym_name
33 | elif env_name == 'ring_lane_changing':
34 | create_env, gym_name = make_create_env(params=ring_accel_lc_params, version=0)
35 | env = create_env()
36 | name = gym_name
37 | elif env_name == 'figure8':
38 | create_env, gym_name = make_create_env(params=fig8_params, version=0)
39 | env = create_env()
40 | name = gym_name
41 | else:
42 | raise Exception('No valid environment selected')
43 | env.seed(seed)
44 | return env, name
45 |
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser(description='Training Decision Trees with DAgger')
49 | parser.add_argument('--env_name', help='environment to run on', type=str, default='lunar')
50 | parser.add_argument('--max_depth', help='the maximum depth of the decision tree', type=int, default=5)
51 | parser.add_argument('--n_rollouts', help='number of rollouts in a training batch', type=int, default=10)
52 | parser.add_argument('--iterations', help='maximum number of training iterations', type=int, default=80)
53 | parser.add_argument('--max_samples', help='the maximum number of samples for each training batch', type=int, default=1000000)
54 | parser.add_argument('--q_dagger', help='if use q-dagger', action='store_true', default=False)
55 | parser.add_argument('--n_q_samples', help='number of samples along each action dimension for q values', type=int, default=30)
56 | parser.add_argument('--eval_episodes', help='number of episodes to evaluate the trained decision tree', type=int, default=20)
57 | parser.add_argument('--oracle_load_path', help='the path of loading the oracle model', type=str, default='saved_mlp_models')
58 | parser.add_argument('--oracle_load_file', help='which oracle model file to load', type=str, default='best_model')
59 | parser.add_argument('--seed', help='random seed', type=int, default=42)
60 | parser.add_argument('--render', help='if render the tested environment', action='store_true')
61 | parser.add_argument('--gpu', help='if run on a GPU (depending on the loaded file)', action='store_true')
62 | parser.add_argument('--save', help='the path to save the decision tree model', type=str, default='saved_dt_models')
63 | parser.add_argument('--load', help='the path to load the decision tree model', type=str, default=None)
64 |
65 | args = parser.parse_args()
66 | env, env_n = make_env(args.env_name, args.seed)
67 | if not os.path.exists(args.save):
68 | os.makedirs(args.save)
69 | save_dir = args.save + "/" + "best_dt.pkl"
70 |
71 | if args.gpu:
72 | args.device = 'cuda'
73 | else:
74 | args.device = 'cpu'
75 |
76 | oracle_model = SAC.load(args.oracle_load_path + "/" + args.oracle_load_file, device=args.device)
77 | oracle_model.set_random_seed(args.seed)
78 | dt_model = DTPolicy(env.action_space, args.max_depth)
79 | dagger = DAgger(env, oracle_model, dt_model, args.n_rollouts, args.iterations, args.max_samples, args.q_dagger, args.n_q_samples)
80 | if args.load:
81 | dagger.load_best_dt(args.load)
82 | else:
83 | dagger.train(save_dir)
84 | dagger.evaluate(args.eval_episodes)
85 |
86 |
--------------------------------------------------------------------------------
/icct/dagger/train_dagger.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 |
3 | python -u train_dagger.py \
4 | --env_name lunar \
5 | --max_depth 5 \
6 | --n_rollouts 10 \
7 | --iterations 100 \
8 | --max_samples 1000000 \
9 | --eval_episodes 50 \
10 | --oracle_load_path oracle_models \
11 | --oracle_load_file ll_mlp_max \
12 | --save saved_dt_models/ll \
13 | --seed 5 \
14 | | tee train.log
--------------------------------------------------------------------------------
/icct/dagger/train_q_dagger.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 |
3 | python -u train_dagger.py \
4 | --env_name lunar \
5 | --max_depth 5 \
6 | --n_rollouts 10 \
7 | --iterations 100 \
8 | --max_samples 1000000 \
9 | --q_dagger \
10 | --n_q_samples 40 \
11 | --eval_episodes 50 \
12 | --oracle_load_path oracle_models \
13 | --oracle_load_file ll_mlp_max \
14 | --save saved_dt_models/ll \
15 | --seed 5 \
16 | | tee train.log
--------------------------------------------------------------------------------
/icct/plot/learning_curve_plot.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | import argparse
4 | from learning_curve_plotter import *
5 |
6 | if __name__ == "__main__":
7 | parser = argparse.ArgumentParser(description='Arguments for Plotting Learning Curves')
8 | parser.add_argument('--log_dir', help='the path to the data', type=str, default='results')
9 | parser.add_argument('--eval_freq', help='evaluation frequency used during training', type=int, default=1500)
10 | parser.add_argument('--n_eval_episodes', help='the number of episodes for each evaluation during training', type=int, default=5)
11 | parser.add_argument('--eval_smooth_window_size', help='the sliding window size to smooth the evaluation rewards', type=int, default=1)
12 | parser.add_argument('--non_eval_sample_freq', help='the sampling frequency of the rollout rewards for plotting', type=int, default=2000)
13 | parser.add_argument('--non_eval_smooth_window_size', help='the sliding window size to smooth the sampled rollout rewards', type=int, default=1)
14 | parser.add_argument('--env_name', help='the environment name of the raw data', type=str)
15 | parser.add_argument('--show_legend', help='if show the legend in the figure', action='store_true', default=False)
16 |
17 |
18 | args = parser.parse_args()
19 | plotter = Learning_Curve_Plotter(log_dir=args.log_dir,
20 | eval_freq=args.eval_freq,
21 | n_eval_episodes=args.n_eval_episodes,
22 | eval_smooth_window_size=args.eval_smooth_window_size,
23 | non_eval_sample_freq=args.non_eval_sample_freq,
24 | non_eval_smooth_window_size=args.non_eval_smooth_window_size,
25 | env_name=args.env_name,
26 | show_legend=args.show_legend)
27 | plotter.process_data()
28 | plotter.plot()
--------------------------------------------------------------------------------
/icct/plot/learning_curve_plotter.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | from typing import Callable, List, Optional, Tuple
4 |
5 | import csv
6 | import json
7 | import os
8 | import time
9 | from glob import glob
10 | from typing import Dict, List, Optional, Tuple, Union
11 | import gym
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | import matplotlib
15 | import seaborn as sns
16 |
17 |
18 | from stable_baselines3.common.monitor import Monitor
19 | from stable_baselines3.common.results_plotter import load_results, ts2xy
20 | import pandas as pd
21 |
22 |
23 | class Learning_Curve_Plotter(object):
24 | def __init__(self,
25 | log_dir,
26 | eval_freq=1500,
27 | n_eval_episodes=5,
28 | eval_smooth_window_size=10,
29 | non_eval_sample_freq=2000,
30 | non_eval_smooth_window_size=1,
31 | method_names=None,
32 | env_name='random',
33 | show_legend=False) -> None:
34 | self.log_dir = log_dir
35 | self.eval_freq = eval_freq
36 | self.n_eval_episodes = n_eval_episodes
37 | self.eval_smooth_window_size = eval_smooth_window_size
38 | self.non_eval_sample_freq = non_eval_sample_freq
39 | self.non_eval_smooth_window_size = non_eval_smooth_window_size
40 | self.env_name = env_name
41 | self.show_legend = show_legend
42 |
43 | if method_names is None:
44 | self.method_names = {'CDDT': 'm1',
45 | 'CDDT-controllers': 'm2',
46 | 'ICCT-static': 'm3',
47 | 'ICCT-complete': 'm4',
48 | 'ICCT-L1-sparse': 'm5a',
49 | 'ICCT-1-feature': 'm5b_1',
50 | 'ICCT-2-feature': 'm5b_2',
51 | 'ICCT-3-feature': 'm5b_3',
52 | 'MLP': 'mlp_l',
53 | 'MLP-U': 'mlp_m',
54 | 'MLP-L': 'mlp_s'}
55 | else:
56 | self.method_names = method_names
57 |
58 | self.non_eval_monitor_files = {}
59 | self.eval_monitor_files = {}
60 | for method, method_name in self.method_names.items():
61 | self.non_eval_monitor_files[method] = self.get_monitor_files(self.log_dir, method_name)
62 | self.eval_monitor_files[method] = self.get_monitor_files(self.log_dir, method_name, eval=True)
63 |
64 | self.non_eval_data = None
65 | self.eval_data = None
66 |
67 |
68 | def process_data(self):
69 | self.non_eval_data = self._process_non_eval_data(self.non_eval_monitor_files, self.non_eval_sample_freq, self.non_eval_smooth_window_size)
70 | self.eval_data = self._process_eval_data(self.eval_monitor_files, self.eval_freq, self.n_eval_episodes, self.eval_smooth_window_size)
71 |
72 | return
73 |
74 |
75 | def plot(self):
76 | self._plot_non_eval()
77 | self._plot_eval()
78 |
79 | return
80 |
81 |
82 | def _plot_non_eval(self):
83 | sns.set_style("whitegrid")
84 | matplotlib.rcParams.update({'font.size': 25})
85 | plt.rcParams["font.weight"] = "bold"
86 | plt.rcParams['axes.labelweight'] = 'bold'
87 | plt.rcParams['axes.linewidth'] = 2
88 | plt.figure(figsize=(12, 6), dpi=100)
89 | if self.show_legend:
90 | legend_flag = 'auto'
91 | else:
92 | legend_flag = False
93 | hue_order = ['ICCT-complete',
94 | 'ICCT-1-feature',
95 | 'ICCT-2-feature',
96 | 'ICCT-3-feature',
97 | 'ICCT-static',
98 | 'ICCT-L1-sparse',
99 | 'CDDT',
100 | 'CDDT-controllers',
101 | 'MLP',
102 | 'MLP-U',
103 | 'MLP-L']
104 | hue_order.reverse()
105 | color_map = {'CDDT': 'purple',
106 | 'CDDT-controllers': 'brown',
107 | 'ICCT-static': 'gold',
108 | 'ICCT-complete': 'red',
109 | 'ICCT-L1-sparse': 'grey',
110 | 'ICCT-1-feature': 'darkorange',
111 | 'ICCT-2-feature': 'green',
112 | 'ICCT-3-feature': 'blue',
113 | 'MLP': 'darkturquoise',
114 | 'MLP-U': 'skyblue',
115 | 'MLP-L': 'pink'}
116 | sns.lineplot(data=self.non_eval_data, x="timesteps", y="rollout_rewards_mean", hue="method", ci=68, hue_order=hue_order, legend=legend_flag, palette=color_map)
117 | plt.xlabel('Time Step (k)')
118 | plt.ylabel('Reward')
119 | if self.show_legend:
120 | plt.legend(title=None, ncol=1, fontsize=6)
121 | plt.savefig(f'{self.env_name}_rollout_reward_curves.png', bbox_inches='tight')
122 | plt.close()
123 | return
124 |
125 |
126 | def _plot_eval(self):
127 | sns.set_style("whitegrid")
128 | matplotlib.rcParams.update({'font.size': 25})
129 | plt.rcParams["font.weight"] = "bold"
130 | plt.rcParams['axes.labelweight'] = 'bold'
131 | plt.rcParams['axes.linewidth'] = 2
132 | plt.figure(figsize=(12, 6), dpi=100)
133 | if self.show_legend:
134 | legend_flag = 'auto'
135 | else:
136 | legend_flag = False
137 | hue_order = ['ICCT-complete',
138 | 'ICCT-1-feature',
139 | 'ICCT-2-feature',
140 | 'ICCT-3-feature',
141 | 'ICCT-static',
142 | 'ICCT-L1-sparse',
143 | 'CDDT',
144 | 'CDDT-controllers',
145 | 'MLP',
146 | 'MLP-U',
147 | 'MLP-L']
148 | hue_order.reverse()
149 | color_map = {'CDDT': 'purple',
150 | 'CDDT-controllers': 'brown',
151 | 'ICCT-static': 'gold',
152 | 'ICCT-complete': 'red',
153 | 'ICCT-L1-sparse': 'grey',
154 | 'ICCT-1-feature': 'darkorange',
155 | 'ICCT-2-feature': 'green',
156 | 'ICCT-3-feature': 'blue',
157 | 'MLP': 'darkturquoise',
158 | 'MLP-U': 'skyblue',
159 | 'MLP-L': 'pink'}
160 | sns.lineplot(data=self.eval_data, x="eval_timesteps", y="eval_rewards_mean", hue="method", ci=68, hue_order=hue_order, legend=legend_flag, palette=color_map)
161 | plt.xlabel('Time Step (k)')
162 | plt.ylabel('Reward')
163 | if self.show_legend:
164 | plt.legend(title=None, ncol=1, fontsize=6)
165 | plt.savefig(f'{self.env_name}_eval_reward_curves.png', bbox_inches='tight')
166 | plt.close()
167 |
168 | return
169 |
170 | def _process_non_eval_data(self, dict_monitor_files, sample_freq, smooth_window_size):
171 | data_frames = []
172 | for method, file_names in dict_monitor_files.items():
173 | if len(file_names) == 0:
174 | pass
175 | for file_name in file_names:
176 | with open(file_name, "rt") as file_handler:
177 | first_line = file_handler.readline()
178 | assert first_line[0] == "#"
179 | data_frame = pd.read_csv(file_handler, index_col=None)
180 | rewards = data_frame['r'].to_numpy()
181 | timesteps = data_frame['l'].to_numpy().cumsum()
182 | new_timesteps = np.arange(0, timesteps[-1] + 1, sample_freq)
183 | dist = np.tile(new_timesteps.reshape(new_timesteps.shape[0], -1),
184 | timesteps.shape[0]) - timesteps
185 | sample_idx = np.argmin(np.abs(dist), axis=-1)
186 | sampled_rewards = rewards[sample_idx]
187 | sampled_rewards = self.moving_average(sampled_rewards, window=smooth_window_size)
188 | new_timesteps = new_timesteps[new_timesteps.shape[0] - sampled_rewards.shape[0]:]/1000
189 | processed_data = pd.DataFrame(
190 | np.stack([sampled_rewards, new_timesteps], axis=-1),
191 | columns = ['rollout_rewards_mean', 'timesteps'])
192 | method_names = [method] * new_timesteps.shape[0]
193 | processed_data['method'] = method_names
194 | data_frames.append(processed_data)
195 | data_frame = pd.concat(data_frames)
196 | data_frame.reset_index(inplace=True)
197 | return data_frame
198 |
199 |
200 | def _process_eval_data(self, dict_monitor_files, eval_freq, n_eval_episodes, smooth_window_size):
201 | data_frames = []
202 | for method, file_names in dict_monitor_files.items():
203 | if len(file_names) == 0:
204 | pass
205 | for file_name in file_names:
206 | with open(file_name, "rt") as file_handler:
207 | first_line = file_handler.readline()
208 | assert first_line[0] == "#"
209 | data_frame = pd.read_csv(file_handler, index_col=None)
210 | eval_rewards = data_frame['r'].to_numpy().reshape(-1, n_eval_episodes)
211 | eval_rewards_mean = eval_rewards.mean(axis=-1)
212 | eval_rewards_mean = self.moving_average(eval_rewards_mean, window=smooth_window_size)
213 | eval_rewards_std = eval_rewards.std(axis=-1)
214 | eval_rewards_std = self.moving_average(eval_rewards_std, window=smooth_window_size)
215 |
216 | eval_lengths = data_frame['l'].to_numpy().reshape(-1, n_eval_episodes)
217 | eval_lengths_mean = eval_lengths.mean(axis=-1)
218 | eval_lengths_mean = self.moving_average(eval_lengths_mean, window=smooth_window_size)
219 | eval_lengths_std = eval_lengths.std(axis=-1)
220 | eval_lengths_std = self.moving_average(eval_lengths_std, window=smooth_window_size)
221 |
222 | eval_timesteps = np.arange(eval_freq, eval_freq * eval_rewards.shape[0] + 1, eval_freq)
223 | eval_timesteps = eval_timesteps[eval_timesteps.shape[0] - eval_rewards_mean.shape[0]:]/1000
224 |
225 | processed_data = pd.DataFrame(
226 | np.stack([eval_rewards_mean, eval_rewards_std, eval_lengths_mean, eval_lengths_std, eval_timesteps], axis=-1),
227 | columns = ['eval_rewards_mean', 'eval_rewards_std', 'eval_lengths_mean', 'eval_lengths_std', 'eval_timesteps'])
228 | method_names = [method] * eval_timesteps.shape[0]
229 | processed_data['method'] = method_names
230 | data_frames.append(processed_data)
231 | data_frame = pd.concat(data_frames)
232 | data_frame.reset_index(inplace=True)
233 | return data_frame
234 |
235 |
236 |
237 | def get_monitor_files(self, path, method_name, eval=False) -> List[str]:
238 | eval_files = glob(os.path.join(path, '*' + 'eval' + '*' + method_name + '*' + 'monitor.csv'))
239 | all_files = glob(os.path.join(path, '*' + method_name + '*' + 'monitor.csv'))
240 | non_eval_files = list(set(all_files) - set(eval_files))
241 |
242 | if eval:
243 | ret = eval_files
244 | else:
245 | ret = non_eval_files
246 |
247 | return ret
248 |
249 |
250 | def moving_average(self, values, window):
251 | """
252 | Smooth values by doing a moving average
253 | :param values: (numpy array)
254 | :param window: (int)
255 | :return: (numpy array)
256 | """
257 | weights = np.repeat(1.0, window) / window
258 | return np.convolve(values, weights, 'valid')
259 |
260 |
--------------------------------------------------------------------------------
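A quick standalone illustration of the 'valid'-mode moving average used by ``moving_average`` above: the smoothed series is shorter than the input by ``window - 1`` samples, which is why the processing code re-aligns the timestep arrays before assembling the data frames.

    import numpy as np

    values = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
    window = 3
    weights = np.repeat(1.0, window) / window
    smoothed = np.convolve(values, weights, "valid")
    # smoothed -> array([1., 2., 3.]); len(values) - window + 1 points remain
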
/icct/rl_helpers/ddt_sac_policy.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu and Andrew Silva
2 |
3 | from typing import Any, Dict, List, Optional, Tuple, Type, Union
4 |
5 | import gym
6 | import torch as th
7 | from torch import nn
8 | from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
9 | from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy, get_policy_from_name
10 |
11 | from icct.rl_helpers.sac_policies import SACPolicy, LOG_STD_MAX, LOG_STD_MIN
12 | from stable_baselines3.common.preprocessing import get_action_dim
13 | from stable_baselines3.common.torch_layers import (
14 | BaseFeaturesExtractor,
15 | CombinedExtractor,
16 | FlattenExtractor
17 | )
18 | from icct.core.icct import ICCT
19 | from stable_baselines3.common.type_aliases import Schedule
20 |
21 |
22 | class DDTActor(BasePolicy):
23 |
24 | def __init__(
25 | self,
26 | observation_space: gym.spaces.Space,
27 | action_space: gym.spaces.Space,
28 | net_arch: List[int],
29 | features_extractor: nn.Module,
30 | features_dim: int,
31 | activation_fn: Type[nn.Module] = nn.ReLU,
32 | use_sde: bool = False,
33 | log_std_init: float = -3,
34 | full_std: bool = True,
35 | sde_net_arch: Optional[List[int]] = None,
36 | use_expln: bool = False,
37 | clip_mean: float = 2.0,
38 | normalize_images: bool = True,
39 | ddt_kwargs: Dict[str, Any] = None,
40 | ):
41 | super(DDTActor, self).__init__(
42 | observation_space,
43 | action_space,
44 | features_extractor=features_extractor,
45 | normalize_images=normalize_images,
46 | squash_output=True,
47 | )
48 |
49 | # Save arguments to re-create object at loading
50 | self.use_sde = use_sde
51 | self.sde_features_extractor = None
52 | self.sde_net_arch = sde_net_arch
53 | self.net_arch = net_arch
54 | self.features_dim = features_dim
55 | self.activation_fn = activation_fn
56 | self.log_std_init = log_std_init
57 | self.sde_net_arch = sde_net_arch
58 | self.use_expln = use_expln
59 | self.full_std = full_std
60 | self.clip_mean = clip_mean
61 | self.ddt_kwargs = ddt_kwargs
62 |
63 | action_dim = get_action_dim(self.action_space)
64 | last_layer_dim = features_dim
65 | self.ddt = ICCT(input_dim=features_dim,
66 | output_dim=action_dim,
67 | weights=None,
68 | comparators=None,
69 | leaves=ddt_kwargs['num_leaves'],
70 | alpha=None,
71 | use_individual_alpha=ddt_kwargs['use_individual_alpha'],
72 | device=ddt_kwargs['device'],
73 | use_submodels=ddt_kwargs['submodels'],
74 | hard_node=ddt_kwargs['hard_node'],
75 | argmax_tau=ddt_kwargs['argmax_tau'],
76 | sparse_submodel_type=ddt_kwargs['sparse_submodel_type'],
77 |                         fs_submodel_version=ddt_kwargs['fs_submodel_version'],
78 | l1_hard_attn=ddt_kwargs['l1_hard_attn'],
79 | num_sub_features=ddt_kwargs['num_sub_features'],
80 | use_gumbel_softmax=ddt_kwargs['use_gumbel_softmax'],
81 | alg_type=ddt_kwargs['alg_type']).to(ddt_kwargs['device'])
82 | print(self.ddt.state_dict())
83 | if self.use_sde:
84 | latent_sde_dim = last_layer_dim
85 | # Separate features extractor for gSDE
86 | if sde_net_arch is not None:
87 | self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
88 | features_dim, sde_net_arch, activation_fn
89 | )
90 |
91 | self.action_dist = StateDependentNoiseDistribution(
92 | action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
93 | )
94 | self.mu, self.log_std = self.action_dist.proba_distribution_net(
95 | latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
96 | )
97 | # Avoid numerical issues by limiting the mean of the Gaussian
98 | # to be in [-clip_mean, clip_mean]
99 | if clip_mean > 0.0:
100 | self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
101 | else:
102 | self.action_dist = SquashedDiagGaussianDistribution(action_dim)
103 |
104 |
105 | def _get_constructor_parameters(self) -> Dict[str, Any]:
106 | data = super()._get_constructor_parameters()
107 |
108 | data.update(
109 | dict(
110 | net_arch=self.net_arch,
111 | features_dim=self.features_dim,
112 | activation_fn=self.activation_fn,
113 | use_sde=self.use_sde,
114 | log_std_init=self.log_std_init,
115 | full_std=self.full_std,
116 | sde_net_arch=self.sde_net_arch,
117 | use_expln=self.use_expln,
118 | features_extractor=self.features_extractor,
119 | clip_mean=self.clip_mean,
120 | )
121 | )
122 | return data
123 |
124 | def get_std(self) -> th.Tensor:
125 | """
126 | Retrieve the standard deviation of the action distribution.
127 | Only useful when using gSDE.
128 | It corresponds to ``th.exp(log_std)`` in the normal case,
129 | but is slightly different when using ``expln`` function
130 | (cf StateDependentNoiseDistribution doc).
131 |
132 | :return:
133 | """
134 | msg = "get_std() is only available when using gSDE"
135 | assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
136 | return self.action_dist.get_std(self.log_std)
137 |
138 | def reset_noise(self, batch_size: int = 1) -> None:
139 | """
140 | Sample new weights for the exploration matrix, when using gSDE.
141 |
142 | :param batch_size:
143 | """
144 | msg = "reset_noise() is only available when using gSDE"
145 | assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
146 | self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
147 |
148 | def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
149 | """
150 | Get the parameters for the action distribution.
151 |
152 | :param obs:
153 | :return:
154 | Mean, standard deviation and optional keyword arguments.
155 | """
156 | features = self.extract_features(obs)
157 | mean_actions, log_std = self.ddt(features)
158 |
159 | if self.use_sde:
160 | latent_sde = self.ddt # Feature extractor goes here
161 | if self.sde_features_extractor is not None:
162 | latent_sde = self.sde_features_extractor(features)
163 | return mean_actions, self.log_std, dict(latent_sde=latent_sde)
164 | # Unstructured exploration (Original implementation)
165 | # Original Implementation to cap the standard deviation
166 | # log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
167 | return mean_actions, log_std, {}
168 |
169 | def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
170 | mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
171 | # Note: the action is squashed
172 | return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
173 |
174 | def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
175 | mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
176 | # return action and associated log prob
177 | return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
178 |
179 | def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
180 | return self.forward(observation, deterministic)
181 |
182 |
183 | class DDT_SACPolicy(SACPolicy):
184 |
185 | def __init__(
186 | self,
187 | observation_space: gym.spaces.Space,
188 | action_space: gym.spaces.Space,
189 | lr_schedule: Schedule,
190 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
191 | activation_fn: Type[nn.Module] = nn.ReLU,
192 | use_sde: bool = False,
193 | log_std_init: float = -3,
194 | sde_net_arch: Optional[List[int]] = None,
195 | use_expln: bool = False,
196 | clip_mean: float = 2.0,
197 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
198 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
199 | normalize_images: bool = True,
200 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
201 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
202 | n_critics: int = 2,
203 | share_features_extractor: bool = True,
204 | ddt_kwargs: Dict[str, Any] = None,
205 | ):
206 | self.ddt_kwargs = ddt_kwargs
207 | super(DDT_SACPolicy, self).__init__(
208 | observation_space,
209 | action_space,
210 | lr_schedule,
211 | net_arch,
212 | activation_fn,
213 | use_sde,
214 | log_std_init,
215 | sde_net_arch,
216 | use_expln,
217 | clip_mean,
218 | features_extractor_class,
219 | features_extractor_kwargs,
220 | normalize_images,
221 | optimizer_class,
222 | optimizer_kwargs,
223 | n_critics,
224 | share_features_extractor,
225 | )
226 |
227 | def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> DDTActor:
228 | actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
229 | actor_kwargs['ddt_kwargs'] = self.ddt_kwargs
230 | return DDTActor(**actor_kwargs).to(self.device)
231 |
232 | register_policy("DDT_SACPolicy", DDT_SACPolicy)
233 |
--------------------------------------------------------------------------------
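``DDT_SACPolicy`` above is registered with Stable-Baselines3's policy registry, so it can be selected by name when constructing the SAC learner from this repository. A minimal sketch of that wiring follows; the ``ddt_kwargs`` values are illustrative placeholders only, not the settings used in the experiments (see the training entry points under ``runfiles/``).

    import gym
    from icct.rl_helpers.sac import SAC
    import icct.rl_helpers.ddt_sac_policy  # noqa: F401  (import registers "DDT_SACPolicy")

    # Illustrative hyperparameters; not the values used in the paper.
    ddt_kwargs = dict(
        num_leaves=8,              # number of leaves in the ICCT
        hard_node=True,            # crisp decision nodes
        submodels=True,            # sparse linear sub-controllers at the leaves
        sparse_submodel_type=1,    # L1-regularized sub-controllers (handled in sac.py's train())
        num_sub_features=1,
        fs_submodel_version=0,
        l1_hard_attn=False,
        l1_reg_coeff=5e-3,
        l1_reg_bias=False,
        use_individual_alpha=False,
        use_gumbel_softmax=False,
        argmax_tau=1.0,
        ddt_lr=3e-4,               # actor learning rate, read in SACPolicy._build
        device='cpu',
        alg_type='sac',
    )

    env = gym.make('LunarLanderContinuous-v2')
    model = SAC('DDT_SACPolicy', env, policy_kwargs=dict(ddt_kwargs=ddt_kwargs), verbose=1)
    model.learn(total_timesteps=10_000)
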
/icct/rl_helpers/ddt_td3_policy.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | from typing import Any, Dict, List, Optional, Tuple, Type, Union
4 |
5 | import gym
6 | import torch as th
7 | from torch import nn
8 | from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
9 |
10 | from icct.rl_helpers.td3_policies import TD3Policy
11 | from stable_baselines3.common.preprocessing import get_action_dim
12 | from stable_baselines3.common.torch_layers import (
13 | BaseFeaturesExtractor,
14 | CombinedExtractor,
15 | FlattenExtractor,
16 | NatureCNN,
17 | create_mlp,
18 | get_actor_critic_arch,
19 | )
20 | from icct.core.icct import ICCT
21 | from stable_baselines3.common.type_aliases import Schedule
22 |
23 |
24 | class DDTActor(BasePolicy):
25 |
26 | def __init__(
27 | self,
28 | observation_space: gym.spaces.Space,
29 | action_space: gym.spaces.Space,
30 | net_arch: List[int],
31 | features_extractor: nn.Module,
32 | features_dim: int,
33 | activation_fn: Type[nn.Module] = nn.ReLU,
34 | normalize_images: bool = True,
35 | ddt_kwargs: Dict[str, Any] = None,
36 | ):
37 | super(DDTActor, self).__init__(
38 | observation_space,
39 | action_space,
40 | features_extractor=features_extractor,
41 | normalize_images=normalize_images,
42 | squash_output=True,
43 | )
44 |
45 | # Save arguments to re-create object at loading
46 | self.net_arch = net_arch
47 | self.features_dim = features_dim
48 | self.activation_fn = activation_fn
49 | self.ddt_kwargs = ddt_kwargs
50 |
51 | action_dim = get_action_dim(self.action_space)
52 | last_layer_dim = features_dim
53 | self.ddt = ICCT(input_dim=features_dim,
54 | output_dim=action_dim,
55 | weights=None,
56 | comparators=None,
57 | leaves=ddt_kwargs['num_leaves'],
58 | alpha=None,
59 | use_individual_alpha=ddt_kwargs['use_individual_alpha'],
60 | device=ddt_kwargs['device'],
61 | use_submodels=ddt_kwargs['submodels'],
62 | hard_node=ddt_kwargs['hard_node'],
63 | argmax_tau=ddt_kwargs['argmax_tau'],
64 | sparse_submodel_type=ddt_kwargs['sparse_submodel_type'],
65 |                         fs_submodel_version=ddt_kwargs['fs_submodel_version'],
66 | l1_hard_attn=ddt_kwargs['l1_hard_attn'],
67 | num_sub_features=ddt_kwargs['num_sub_features'],
68 | use_gumbel_softmax=ddt_kwargs['use_gumbel_softmax'],
69 | alg_type=ddt_kwargs['alg_type']).to(ddt_kwargs['device'])
70 | print(self.ddt.state_dict())
71 |
72 |
73 | def _get_constructor_parameters(self) -> Dict[str, Any]:
74 | data = super()._get_constructor_parameters()
75 |
76 | data.update(
77 | dict(
78 | net_arch=self.net_arch,
79 | features_dim=self.features_dim,
80 | activation_fn=self.activation_fn,
81 | features_extractor=self.features_extractor,
82 | )
83 | )
84 | return data
85 |
86 | def forward(self, obs: th.Tensor) -> th.Tensor:
87 | # assert deterministic, 'The TD3 actor only outputs deterministic actions'
88 | features = self.extract_features(obs)
89 | return self.ddt(features)
90 |
91 | def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
92 | # Note: the deterministic parameter is ignored in the case of TD3.
93 | # Predictions are always deterministic.
94 | return self.forward(observation)
95 |
96 |
97 | class DDT_TD3Policy(TD3Policy):
98 |
99 | def __init__(
100 | self,
101 | observation_space: gym.spaces.Space,
102 | action_space: gym.spaces.Space,
103 | lr_schedule: Schedule,
104 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
105 | activation_fn: Type[nn.Module] = nn.ReLU,
106 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
107 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
108 | normalize_images: bool = True,
109 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
110 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
111 | n_critics: int = 2,
112 | share_features_extractor: bool = True,
113 | ddt_kwargs: Dict[str, Any] = None,
114 | ):
115 | self.ddt_kwargs = ddt_kwargs
116 | super(DDT_TD3Policy, self).__init__(
117 | observation_space,
118 | action_space,
119 | lr_schedule,
120 | net_arch,
121 | activation_fn,
122 | features_extractor_class,
123 | features_extractor_kwargs,
124 | normalize_images,
125 | optimizer_class,
126 | optimizer_kwargs,
127 | n_critics,
128 | share_features_extractor
129 | )
130 |
131 | def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> DDTActor:
132 | actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
133 | actor_kwargs['ddt_kwargs'] = self.ddt_kwargs
134 | return DDTActor(**actor_kwargs).to(self.device)
135 |
136 | register_policy("DDT_TD3Policy", DDT_TD3Policy)
--------------------------------------------------------------------------------
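``DDT_TD3Policy`` above plugs into the TD3 learner in the same way. A brief sketch under the same caveats (illustrative ``ddt_kwargs``; TD3 is typically paired with Gaussian action noise for exploration):

    import gym
    import numpy as np
    from stable_baselines3.common.noise import NormalActionNoise
    from icct.rl_helpers.td3 import TD3
    import icct.rl_helpers.ddt_td3_policy  # noqa: F401  (import registers "DDT_TD3Policy")

    # Illustrative hyperparameters; not the values used in the paper.
    ddt_kwargs = dict(num_leaves=8, hard_node=True, submodels=False, sparse_submodel_type=0,
                      num_sub_features=1, fs_submodel_version=0, l1_hard_attn=False,
                      use_individual_alpha=False, use_gumbel_softmax=False, argmax_tau=1.0,
                      ddt_lr=3e-4, device='cpu', alg_type='td3')

    env = gym.make('LunarLanderContinuous-v2')
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    model = TD3('DDT_TD3Policy', env, action_noise=action_noise,
                policy_kwargs=dict(ddt_kwargs=ddt_kwargs), verbose=1)
    model.learn(total_timesteps=10_000)
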
/icct/rl_helpers/sac.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 | # Revised from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/sac/sac.py
3 |
4 | from typing import Any, Dict, List, Optional, Tuple, Type, Union
5 |
6 | import gym
7 | import numpy as np
8 | import torch as th
9 | from torch.nn import functional as F
10 |
11 | from stable_baselines3.common.buffers import ReplayBuffer
12 | from stable_baselines3.common.noise import ActionNoise
13 | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
14 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
15 | from stable_baselines3.common.utils import polyak_update
16 | from icct.rl_helpers.sac_policies import SACPolicy
17 |
18 |
19 | class SAC(OffPolicyAlgorithm):
20 | """
21 | Soft Actor-Critic (SAC)
22 |     Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
23 |     This implementation borrows code from the original implementation (https://github.com/haarnoja/sac),
24 | from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
25 | (https://github.com/rail-berkeley/softlearning/)
26 | and from Stable Baselines (https://github.com/hill-a/stable-baselines)
27 | Paper: https://arxiv.org/abs/1801.01290
28 | Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
29 |
30 | Note: we use double q target and not value target as discussed
31 | in https://github.com/hill-a/stable-baselines/issues/270
32 |
33 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
34 | :param env: The environment to learn from (if registered in Gym, can be str)
35 | :param learning_rate: learning rate for adam optimizer,
36 | the same learning rate will be used for all networks (Q-Values, Actor and Value function)
37 | it can be a function of the current progress remaining (from 1 to 0)
38 | :param buffer_size: size of the replay buffer
39 | :param learning_starts: how many steps of the model to collect transitions for before learning starts
40 | :param batch_size: Minibatch size for each gradient update
41 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
42 | :param gamma: the discount factor
43 | :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
44 | like ``(5, "step")`` or ``(2, "episode")``.
45 |     :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``).
46 |         Setting it to ``-1`` means doing as many gradient steps as steps taken in the environment
47 |         during the rollout.
48 |     :param action_noise: the action noise type (None by default), this can help
49 |         with hard exploration problems. See common.noise for the different action noise types.
50 | :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
51 | If ``None``, it will be automatically selected.
52 | :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
53 | :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
54 | at a cost of more complexity.
55 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
56 | :param ent_coef: Entropy regularization coefficient. (Equivalent to
57 |         the inverse of the reward scale in the original SAC paper.) It controls the exploration/exploitation trade-off.
58 |         Set it to 'auto' to learn it automatically (and 'auto_0.1' to use 0.1 as the initial value).
59 |     :param target_update_interval: update the target network every ``target_update_interval``
60 | gradient steps.
61 | :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
62 | :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
63 | instead of action noise exploration (default: False)
64 | :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
65 | Default: -1 (only sample at the beginning of the rollout)
66 | :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
67 | during the warm up phase (before learning starts)
68 | :param create_eval_env: Whether to create a second environment that will be
69 |         used for evaluating the agent periodically. (Only available when passing a string for the environment.)
70 | :param policy_kwargs: additional arguments to be passed to the policy on creation
71 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
72 | :param seed: Seed for the pseudo random generators
73 | :param device: Device (cpu, cuda, ...) on which the code should be run.
74 | Setting it to auto, the code will be run on the GPU if possible.
75 | :param _init_setup_model: Whether or not to build the network at the creation of the instance
76 | """
77 |
78 | def __init__(
79 | self,
80 | policy: Union[str, Type[SACPolicy]],
81 | env: Union[GymEnv, str],
82 | learning_rate: Union[float, Schedule] = 3e-4,
83 | buffer_size: int = 1000000, # 1e6
84 | learning_starts: int = 100,
85 | batch_size: int = 256,
86 | tau: float = 0.005,
87 | gamma: float = 0.99,
88 | train_freq: Union[int, Tuple[int, str]] = 1,
89 | gradient_steps: int = 1,
90 | action_noise: Optional[ActionNoise] = None,
91 | replay_buffer_class: Optional[ReplayBuffer] = None,
92 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
93 | optimize_memory_usage: bool = False,
94 | ent_coef: Union[str, float] = "auto",
95 | target_update_interval: int = 1,
96 | target_entropy: Union[str, float] = "auto",
97 | use_sde: bool = False,
98 | sde_sample_freq: int = -1,
99 | use_sde_at_warmup: bool = False,
100 | tensorboard_log: Optional[str] = None,
101 | create_eval_env: bool = False,
102 | policy_kwargs: Dict[str, Any] = None,
103 | verbose: int = 0,
104 | seed: Optional[int] = None,
105 | device: Union[th.device, str] = "auto",
106 | _init_setup_model: bool = True,
107 | ):
108 |
109 | super(SAC, self).__init__(
110 | policy,
111 | env,
112 | SACPolicy,
113 | learning_rate,
114 | buffer_size,
115 | learning_starts,
116 | batch_size,
117 | tau,
118 | gamma,
119 | train_freq,
120 | gradient_steps,
121 | action_noise,
122 | replay_buffer_class=replay_buffer_class,
123 | replay_buffer_kwargs=replay_buffer_kwargs,
124 | policy_kwargs=policy_kwargs,
125 | tensorboard_log=tensorboard_log,
126 | verbose=verbose,
127 | device=device,
128 | create_eval_env=create_eval_env,
129 | seed=seed,
130 | use_sde=use_sde,
131 | sde_sample_freq=sde_sample_freq,
132 | use_sde_at_warmup=use_sde_at_warmup,
133 | optimize_memory_usage=optimize_memory_usage,
134 | supported_action_spaces=(gym.spaces.Box),
135 | )
136 |
137 | self.target_entropy = target_entropy
138 | self.log_ent_coef = None # type: Optional[th.Tensor]
139 | # Entropy coefficient / Entropy temperature
140 | # Inverse of the reward scale
141 | self.ent_coef = ent_coef
142 | self.target_update_interval = target_update_interval
143 | self.ent_coef_optimizer = None
144 |
145 | if _init_setup_model:
146 | self._setup_model()
147 |
148 | def _setup_model(self) -> None:
149 | super(SAC, self)._setup_model()
150 | self._create_aliases()
151 | # Target entropy is used when learning the entropy coefficient
152 | if self.target_entropy == "auto":
153 | # automatically set target entropy if needed
154 | self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
155 | else:
156 | # Force conversion
157 | # this will also throw an error for unexpected string
158 | self.target_entropy = float(self.target_entropy)
159 |
160 | # The entropy coefficient or entropy can be learned automatically
161 | # see Automating Entropy Adjustment for Maximum Entropy RL section
162 | # of https://arxiv.org/abs/1812.05905
163 | if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
164 | # Default initial value of ent_coef when learned
165 | init_value = 1.0
166 | if "_" in self.ent_coef:
167 | init_value = float(self.ent_coef.split("_")[1])
168 | assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
169 |
170 | # Note: we optimize the log of the entropy coeff which is slightly different from the paper
171 | # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
172 | self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
173 | self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
174 | else:
175 | # Force conversion to float
176 | # this will throw an error if a malformed string (different from 'auto')
177 | # is passed
178 | self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
179 |
180 | def _create_aliases(self) -> None:
181 | self.actor = self.policy.actor
182 | self.critic = self.policy.critic
183 | self.critic_target = self.policy.critic_target
184 |
185 | def train(self, gradient_steps: int, batch_size: int = 64) -> None:
186 | # Update optimizers learning rate
187 | optimizers = [self.actor.optimizer, self.critic.optimizer]
188 | if self.ent_coef_optimizer is not None:
189 | optimizers += [self.ent_coef_optimizer]
190 |
191 | # # Update learning rate according to lr schedule
192 | # self._update_learning_rate(optimizers)
193 |
194 | ent_coef_losses, ent_coefs = [], []
195 | actor_losses, critic_losses = [], []
196 | l1_reg_losses = []
197 |
198 | for gradient_step in range(gradient_steps):
199 | # Sample replay buffer
200 | replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
201 |
202 | # We need to sample because `log_std` may have changed between two gradient steps
203 | if self.use_sde:
204 | self.actor.reset_noise()
205 |
206 | # Action by the current actor for the sampled state
207 | actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
208 | log_prob = log_prob.reshape(-1, 1)
209 |
210 | if hasattr(self.actor, 'ddt'):
211 | if self.actor.ddt.use_submodels and self.actor.ddt.sparse_submodel_type == 1:
212 | attn = self.actor.ddt.leaf_attn.repeat_interleave(2)
213 |
214 | ent_coef_loss = None
215 | if self.ent_coef_optimizer is not None:
216 | # Important: detach the variable from the graph
217 | # so we don't change it with other losses
218 | # see https://github.com/rail-berkeley/softlearning/issues/60
219 | ent_coef = th.exp(self.log_ent_coef.detach())
220 | ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
221 | ent_coef_losses.append(ent_coef_loss.item())
222 | else:
223 | ent_coef = self.ent_coef_tensor
224 |
225 | ent_coefs.append(ent_coef.item())
226 |
227 | # Optimize entropy coefficient, also called
228 | # entropy temperature or alpha in the paper
229 | if ent_coef_loss is not None:
230 | self.ent_coef_optimizer.zero_grad()
231 | ent_coef_loss.backward()
232 | self.ent_coef_optimizer.step()
233 |
234 | with th.no_grad():
235 | # Select action according to policy
236 | next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
237 | # Compute the next Q values: min over all critics targets
238 | next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
239 | next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
240 | # add entropy term
241 | next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
242 | # td error + entropy term
243 | target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
244 |
245 | # Get current Q-values estimates for each critic network
246 | # using action from the replay buffer
247 | current_q_values = self.critic(replay_data.observations, replay_data.actions)
248 |
249 | # Compute critic loss
250 | critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
251 | critic_losses.append(critic_loss.item())
252 |
253 | # Optimize the critic
254 | self.critic.optimizer.zero_grad()
255 | critic_loss.backward()
256 | self.critic.optimizer.step()
257 |
258 | # Compute actor loss
259 | # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
260 | # Mean over all critic networks
261 | q_values_pi = th.cat(self.critic.forward(replay_data.observations, actions_pi), dim=1)
262 | min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
263 | actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
264 |
265 | if hasattr(self.actor, 'ddt'):
266 | if self.actor.ddt.use_submodels and self.actor.ddt.sparse_submodel_type == 1:
267 | l1_reg_loss = 0
268 | if self.actor.ddt_kwargs['l1_reg_bias']:
269 | for i, (name, p) in enumerate(self.actor.ddt.lin_models.named_parameters()):
270 | l1_reg_loss += th.sum(abs(p)) * attn[i]
271 | else:
272 | for i, (name, p) in enumerate(self.actor.ddt.lin_models.named_parameters()):
273 | if not 'bias' in name:
274 | l1_reg_loss += th.sum(abs(p)) * attn[i]
275 | l1_reg_loss *= self.actor.ddt_kwargs['l1_reg_coeff'] * self.actor.ddt.leaf_attn.size(0)
276 | l1_reg_losses.append(l1_reg_loss.item())
277 | actor_loss += l1_reg_loss
278 |
279 | actor_losses.append(actor_loss.item())
280 |
281 | # Optimize the actor
282 | self.actor.optimizer.zero_grad()
283 | actor_loss.backward()
284 | self.actor.optimizer.step()
285 |
286 | # Update target networks
287 | if gradient_step % self.target_update_interval == 0:
288 | polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
289 |
290 | self._n_updates += gradient_steps
291 |
292 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
293 | self.logger.record("train/ent_coef", np.mean(ent_coefs))
294 | self.logger.record("train/actor_loss", np.mean(actor_losses))
295 | self.logger.record("train/critic_loss", np.mean(critic_losses))
296 | if len(ent_coef_losses) > 0:
297 | self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
298 | if len(l1_reg_losses) > 0:
299 | self.logger.record("train/l1_reg_loss", np.mean(l1_reg_losses))
300 |
301 | def learn(
302 | self,
303 | total_timesteps: int,
304 | callback: MaybeCallback = None,
305 | log_interval: int = 4,
306 | eval_env: Optional[GymEnv] = None,
307 | eval_freq: int = -1,
308 | n_eval_episodes: int = 5,
309 | tb_log_name: str = "SAC",
310 | eval_log_path: Optional[str] = None,
311 | reset_num_timesteps: bool = True,
312 | ) -> OffPolicyAlgorithm:
313 |
314 | return super(SAC, self).learn(
315 | total_timesteps=total_timesteps,
316 | callback=callback,
317 | log_interval=log_interval,
318 | eval_env=eval_env,
319 | eval_freq=eval_freq,
320 | n_eval_episodes=n_eval_episodes,
321 | tb_log_name=tb_log_name,
322 | eval_log_path=eval_log_path,
323 | reset_num_timesteps=reset_num_timesteps,
324 | )
325 |
326 | def _excluded_save_params(self) -> List[str]:
327 | return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
328 |
329 | def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
330 | state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
331 | if self.ent_coef_optimizer is not None:
332 | saved_pytorch_variables = ["log_ent_coef"]
333 | state_dicts.append("ent_coef_optimizer")
334 | else:
335 | saved_pytorch_variables = ["ent_coef_tensor"]
336 | return state_dicts, saved_pytorch_variables
337 |
--------------------------------------------------------------------------------
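For reference, the critic update in ``SAC.train()`` above regresses every Q-network onto the entropy-augmented Bellman target y = r + gamma * (1 - done) * (min(Q'_1(s', a'), Q'_2(s', a')) - alpha * log pi(a'|s')), with a' sampled from the current actor. A small standalone restatement of that target, detached from the replay-buffer plumbing:

    import torch as th

    def soft_bellman_target(rewards, dones, next_q1, next_q2, next_log_prob, ent_coef, gamma=0.99):
        """Entropy-augmented TD target used for both critics (all inputs shaped (batch, 1))."""
        next_q = th.min(next_q1, next_q2) - ent_coef * next_log_prob
        return rewards + (1.0 - dones) * gamma * next_q
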
/icct/rl_helpers/sac_policies.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 | # Revised from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/sac/policies.py
3 |
4 | from typing import Any, Dict, List, Optional, Tuple, Type, Union
5 |
6 | import gym
7 | import torch as th
8 | from torch import nn
9 |
10 | from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
11 | from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
12 | from stable_baselines3.common.preprocessing import get_action_dim
13 | from stable_baselines3.common.torch_layers import (
14 | BaseFeaturesExtractor,
15 | CombinedExtractor,
16 | FlattenExtractor,
17 | NatureCNN,
18 | create_mlp,
19 | get_actor_critic_arch,
20 | )
21 | from stable_baselines3.common.type_aliases import Schedule
22 |
23 | # CAP the standard deviation of the actor
24 | LOG_STD_MAX = 2
25 | LOG_STD_MIN = -20
26 |
27 |
28 | class Actor(BasePolicy):
29 | """
30 | Actor network (policy) for SAC.
31 |
32 |     :param observation_space: Observation space
33 | :param action_space: Action space
34 | :param net_arch: Network architecture
35 | :param features_extractor: Network to extract features
36 | (a CNN when using images, a nn.Flatten() layer otherwise)
37 | :param features_dim: Number of features
38 | :param activation_fn: Activation function
39 | :param use_sde: Whether to use State Dependent Exploration or not
40 | :param log_std_init: Initial value for the log standard deviation
41 | :param full_std: Whether to use (n_features x n_actions) parameters
42 | for the std instead of only (n_features,) when using gSDE.
43 | :param sde_net_arch: Network architecture for extracting features
44 | when using gSDE. If None, the latent features from the policy will be used.
45 | Pass an empty list to use the states as features.
46 | :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
47 |         a positive standard deviation (cf. paper). It keeps the variance
48 |         above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
49 | :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
50 | :param normalize_images: Whether to normalize images or not,
51 | dividing by 255.0 (True by default)
52 | """
53 |
54 | def __init__(
55 | self,
56 | observation_space: gym.spaces.Space,
57 | action_space: gym.spaces.Space,
58 | net_arch: List[int],
59 | features_extractor: nn.Module,
60 | features_dim: int,
61 | activation_fn: Type[nn.Module] = nn.ReLU,
62 | use_sde: bool = False,
63 | log_std_init: float = -3,
64 | full_std: bool = True,
65 | sde_net_arch: Optional[List[int]] = None,
66 | use_expln: bool = False,
67 | clip_mean: float = 2.0,
68 | normalize_images: bool = True,
69 | ):
70 | super(Actor, self).__init__(
71 | observation_space,
72 | action_space,
73 | features_extractor=features_extractor,
74 | normalize_images=normalize_images,
75 | squash_output=True,
76 | )
77 |
78 | # Save arguments to re-create object at loading
79 | self.use_sde = use_sde
80 | self.sde_features_extractor = None
81 | self.sde_net_arch = sde_net_arch
82 | self.net_arch = net_arch
83 | self.features_dim = features_dim
84 | self.activation_fn = activation_fn
85 | self.log_std_init = log_std_init
86 | self.sde_net_arch = sde_net_arch
87 | self.use_expln = use_expln
88 | self.full_std = full_std
89 | self.clip_mean = clip_mean
90 | self.processed_obs = None
91 | self.raw_act = None
92 |
93 | action_dim = get_action_dim(self.action_space)
94 | latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
95 | self.latent_pi = nn.Sequential(*latent_pi_net)
96 | last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
97 |
98 | if self.use_sde:
99 | latent_sde_dim = last_layer_dim
100 | # Separate features extractor for gSDE
101 | if sde_net_arch is not None:
102 | self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
103 | features_dim, sde_net_arch, activation_fn
104 | )
105 |
106 | self.action_dist = StateDependentNoiseDistribution(
107 | action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
108 | )
109 | self.mu, self.log_std = self.action_dist.proba_distribution_net(
110 | latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
111 | )
112 | # Avoid numerical issues by limiting the mean of the Gaussian
113 | # to be in [-clip_mean, clip_mean]
114 | if clip_mean > 0.0:
115 | self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
116 | else:
117 | self.action_dist = SquashedDiagGaussianDistribution(action_dim)
118 | self.mu = nn.Linear(last_layer_dim, action_dim)
119 | self.log_std = nn.Linear(last_layer_dim, action_dim)
120 |
121 | def _get_constructor_parameters(self) -> Dict[str, Any]:
122 | data = super()._get_constructor_parameters()
123 |
124 | data.update(
125 | dict(
126 | net_arch=self.net_arch,
127 | features_dim=self.features_dim,
128 | activation_fn=self.activation_fn,
129 | use_sde=self.use_sde,
130 | log_std_init=self.log_std_init,
131 | full_std=self.full_std,
132 | sde_net_arch=self.sde_net_arch,
133 | use_expln=self.use_expln,
134 | features_extractor=self.features_extractor,
135 | clip_mean=self.clip_mean,
136 | )
137 | )
138 | return data
139 |
140 | def get_std(self) -> th.Tensor:
141 | """
142 | Retrieve the standard deviation of the action distribution.
143 | Only useful when using gSDE.
144 | It corresponds to ``th.exp(log_std)`` in the normal case,
145 | but is slightly different when using ``expln`` function
146 | (cf StateDependentNoiseDistribution doc).
147 |
148 | :return:
149 | """
150 | msg = "get_std() is only available when using gSDE"
151 | assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
152 | return self.action_dist.get_std(self.log_std)
153 |
154 | def reset_noise(self, batch_size: int = 1) -> None:
155 | """
156 | Sample new weights for the exploration matrix, when using gSDE.
157 |
158 | :param batch_size:
159 | """
160 | msg = "reset_noise() is only available when using gSDE"
161 | assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
162 | self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
163 |
164 | def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
165 | """
166 | Get the parameters for the action distribution.
167 |
168 | :param obs:
169 | :return:
170 | Mean, standard deviation and optional keyword arguments.
171 | """
172 | features = self.extract_features(obs)
173 | latent_pi = self.latent_pi(features)
174 | mean_actions = self.mu(latent_pi)
175 |
176 | self.processed_obs = features
177 | self.raw_act = mean_actions
178 |
179 | if self.use_sde:
180 | latent_sde = latent_pi
181 | if self.sde_features_extractor is not None:
182 | latent_sde = self.sde_features_extractor(features)
183 | return mean_actions, self.log_std, dict(latent_sde=latent_sde)
184 | # Unstructured exploration (Original implementation)
185 | log_std = self.log_std(latent_pi)
186 | # Original Implementation to cap the standard deviation
187 | log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
188 | return mean_actions, log_std, {}
189 |
190 | def get_sa_pair(self):
191 | """
192 | Get real-time state and action pairs
193 | """
194 | return [self.processed_obs, self.raw_act]
195 |
196 | def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
197 | mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
198 | # Note: the action is squashed
199 | return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
200 |
201 | def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
202 | mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
203 | # return action and associated log prob
204 | return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
205 |
206 | def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
207 | return self.forward(observation, deterministic)
208 |
209 |
210 | class SACPolicy(BasePolicy):
211 | """
212 | Policy class (with both actor and critic) for SAC.
213 |
214 | :param observation_space: Observation space
215 | :param action_space: Action space
216 | :param lr_schedule: Learning rate schedule (could be constant)
217 | :param net_arch: The specification of the policy and value networks.
218 | :param activation_fn: Activation function
219 | :param use_sde: Whether to use State Dependent Exploration or not
220 | :param log_std_init: Initial value for the log standard deviation
221 | :param sde_net_arch: Network architecture for extracting features
222 | when using gSDE. If None, the latent features from the policy will be used.
223 | Pass an empty list to use the states as features.
224 | :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
225 |         a positive standard deviation (cf. paper). It keeps the variance
226 |         above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
227 | :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
228 | :param features_extractor_class: Features extractor to use.
229 | :param features_extractor_kwargs: Keyword arguments
230 | to pass to the features extractor.
231 | :param normalize_images: Whether to normalize images or not,
232 | dividing by 255.0 (True by default)
233 | :param optimizer_class: The optimizer to use,
234 | ``th.optim.Adam`` by default
235 | :param optimizer_kwargs: Additional keyword arguments,
236 | excluding the learning rate, to pass to the optimizer
237 | :param n_critics: Number of critic networks to create.
238 | :param share_features_extractor: Whether to share or not the features extractor
239 | between the actor and the critic (this saves computation time)
240 | """
241 |
242 | def __init__(
243 | self,
244 | observation_space: gym.spaces.Space,
245 | action_space: gym.spaces.Space,
246 | lr_schedule: Schedule,
247 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
248 | activation_fn: Type[nn.Module] = nn.ReLU,
249 | use_sde: bool = False,
250 | log_std_init: float = -3,
251 | sde_net_arch: Optional[List[int]] = None,
252 | use_expln: bool = False,
253 | clip_mean: float = 2.0,
254 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
255 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
256 | normalize_images: bool = True,
257 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
258 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
259 | n_critics: int = 2,
260 | share_features_extractor: bool = True,
261 | ):
262 | super(SACPolicy, self).__init__(
263 | observation_space,
264 | action_space,
265 | features_extractor_class,
266 | features_extractor_kwargs,
267 | optimizer_class=optimizer_class,
268 | optimizer_kwargs=optimizer_kwargs,
269 | squash_output=True,
270 | )
271 |
272 | if net_arch is None:
273 | if features_extractor_class == NatureCNN:
274 | net_arch = []
275 | else:
276 | net_arch = [256, 256]
277 |
278 | actor_arch, critic_arch = get_actor_critic_arch(net_arch)
279 |
280 | self.net_arch = net_arch
281 | self.activation_fn = activation_fn
282 | self.net_args = {
283 | "observation_space": self.observation_space,
284 | "action_space": self.action_space,
285 | "net_arch": actor_arch,
286 | "activation_fn": self.activation_fn,
287 | "normalize_images": normalize_images,
288 | }
289 | self.actor_kwargs = self.net_args.copy()
290 | sde_kwargs = {
291 | "use_sde": use_sde,
292 | "log_std_init": log_std_init,
293 | "sde_net_arch": sde_net_arch,
294 | "use_expln": use_expln,
295 | "clip_mean": clip_mean,
296 | }
297 | self.actor_kwargs.update(sde_kwargs)
298 | self.critic_kwargs = self.net_args.copy()
299 | self.critic_kwargs.update(
300 | {
301 | "n_critics": n_critics,
302 | "net_arch": critic_arch,
303 | "share_features_extractor": share_features_extractor,
304 | }
305 | )
306 |
307 | self.actor, self.actor_target = None, None
308 | self.critic, self.critic_target = None, None
309 | self.share_features_extractor = share_features_extractor
310 |
311 | self._build(lr_schedule)
312 |
313 | def _build(self, lr_schedule: Schedule) -> None:
314 | self.actor = self.make_actor()
315 | if hasattr(self.actor, 'ddt'):
316 | self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=self.actor.ddt_kwargs['ddt_lr'], **self.optimizer_kwargs)
317 | else:
318 | self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
319 |
320 | if self.share_features_extractor:
321 | self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
322 | # Do not optimize the shared features extractor with the critic loss
323 | # otherwise, there are gradient computation issues
324 | critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
325 | else:
326 | # Create a separate features extractor for the critic
327 | # this requires more memory and computation
328 | self.critic = self.make_critic(features_extractor=None)
329 | critic_parameters = self.critic.parameters()
330 |
331 | # Critic target should not share the features extractor with critic
332 | self.critic_target = self.make_critic(features_extractor=None)
333 | self.critic_target.load_state_dict(self.critic.state_dict())
334 |
335 | self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
336 |
337 | def _get_constructor_parameters(self) -> Dict[str, Any]:
338 | data = super()._get_constructor_parameters()
339 |
340 | data.update(
341 | dict(
342 | net_arch=self.net_arch,
343 | activation_fn=self.net_args["activation_fn"],
344 | use_sde=self.actor_kwargs["use_sde"],
345 | log_std_init=self.actor_kwargs["log_std_init"],
346 | sde_net_arch=self.actor_kwargs["sde_net_arch"],
347 | use_expln=self.actor_kwargs["use_expln"],
348 | clip_mean=self.actor_kwargs["clip_mean"],
349 | n_critics=self.critic_kwargs["n_critics"],
350 | lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
351 | optimizer_class=self.optimizer_class,
352 | optimizer_kwargs=self.optimizer_kwargs,
353 | features_extractor_class=self.features_extractor_class,
354 | features_extractor_kwargs=self.features_extractor_kwargs,
355 | )
356 | )
357 | return data
358 |
359 | def reset_noise(self, batch_size: int = 1) -> None:
360 | """
361 | Sample new weights for the exploration matrix, when using gSDE.
362 |
363 | :param batch_size:
364 | """
365 | self.actor.reset_noise(batch_size=batch_size)
366 |
367 | def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
368 | actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
369 | return Actor(**actor_kwargs).to(self.device)
370 |
371 | def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
372 | critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
373 | return ContinuousCritic(**critic_kwargs).to(self.device)
374 |
375 | def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
376 | return self._predict(obs, deterministic=deterministic)
377 |
378 | def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
379 | return self.actor(observation, deterministic)
380 |
381 |
382 | MlpPolicy = SACPolicy
383 |
384 |
385 | class MultiInputPolicy(SACPolicy):
386 | """
387 | Policy class (with both actor and critic) for SAC.
388 |
389 | :param observation_space: Observation space
390 | :param action_space: Action space
391 | :param lr_schedule: Learning rate schedule (could be constant)
392 | :param net_arch: The specification of the policy and value networks.
393 | :param activation_fn: Activation function
394 | :param use_sde: Whether to use State Dependent Exploration or not
395 | :param log_std_init: Initial value for the log standard deviation
396 | :param sde_net_arch: Network architecture for extracting features
397 | when using gSDE. If None, the latent features from the policy will be used.
398 | Pass an empty list to use the states as features.
399 | :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
400 |         a positive standard deviation (cf. paper). It keeps the variance
401 |         above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
402 | :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
403 | :param features_extractor_class: Features extractor to use.
404 | :param normalize_images: Whether to normalize images or not,
405 | dividing by 255.0 (True by default)
406 | :param optimizer_class: The optimizer to use,
407 | ``th.optim.Adam`` by default
408 | :param optimizer_kwargs: Additional keyword arguments,
409 | excluding the learning rate, to pass to the optimizer
410 | :param n_critics: Number of critic networks to create.
411 | :param share_features_extractor: Whether to share or not the features extractor
412 | between the actor and the critic (this saves computation time)
413 | """
414 |
415 | def __init__(
416 | self,
417 | observation_space: gym.spaces.Space,
418 | action_space: gym.spaces.Space,
419 | lr_schedule: Schedule,
420 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
421 | activation_fn: Type[nn.Module] = nn.ReLU,
422 | use_sde: bool = False,
423 | log_std_init: float = -3,
424 | sde_net_arch: Optional[List[int]] = None,
425 | use_expln: bool = False,
426 | clip_mean: float = 2.0,
427 | features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
428 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
429 | normalize_images: bool = True,
430 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
431 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
432 | n_critics: int = 2,
433 | share_features_extractor: bool = True,
434 | ):
435 | super(MultiInputPolicy, self).__init__(
436 | observation_space,
437 | action_space,
438 | lr_schedule,
439 | net_arch,
440 | activation_fn,
441 | use_sde,
442 | log_std_init,
443 | sde_net_arch,
444 | use_expln,
445 | clip_mean,
446 | features_extractor_class,
447 | features_extractor_kwargs,
448 | normalize_images,
449 | optimizer_class,
450 | optimizer_kwargs,
451 | n_critics,
452 | share_features_extractor,
453 | )
454 |
455 |
456 | register_policy("MlpPolicy", MlpPolicy)
457 | register_policy("MultiInputPolicy", MultiInputPolicy)
458 |
--------------------------------------------------------------------------------
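One addition relative to the upstream Stable-Baselines3 policy above is that ``Actor`` caches its most recent processed observation and pre-squash mean action and exposes them through ``get_sa_pair()``, which is handy when collecting state-action data from a policy (e.g., for imitation-style training). A minimal sketch of reading the cached pair after a prediction (environment choice is illustrative):

    import gym
    from icct.rl_helpers.sac import SAC

    env = gym.make('LunarLanderContinuous-v2')
    model = SAC('MlpPolicy', env)            # or SAC.load(...) for a trained model
    obs = env.reset()
    action, _ = model.predict(obs, deterministic=True)
    processed_obs, raw_act = model.actor.get_sa_pair()   # tensors cached by the last forward pass
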
/icct/rl_helpers/save_after_ep_callback.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu and Andrew Silva
2 |
3 | from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
4 | import os
5 | from typing import Any, Callable, Dict, List, Optional, Union
6 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
7 | from stable_baselines3.common.evaluation import evaluate_policy
8 | import gym
9 | import numpy as np
10 | import warnings
11 |
12 |
13 | class EpCheckPointCallback(EvalCallback):
14 | """
15 | Callback for evaluating an agent.
16 |
17 |     :param eval_env: The environment used for evaluation
18 | :param callback_on_new_best: Callback to trigger
19 | when there is a new best model according to the ``mean_reward``
20 | :param n_eval_episodes: The number of episodes to test the agent
21 | :param eval_freq: Evaluate the agent every eval_freq call of the callback.
22 | :param log_path: Path to a folder where the evaluations
23 | will be saved. It will be updated at each evaluation.
24 | :param minimum_reward: The minimum reward to reach to save a model
25 | :param best_model_save_path: Path to a folder where the best model
26 | according to performance on the eval env will be saved.
27 | :param deterministic: Whether the evaluation should
28 |         use stochastic or deterministic actions.
29 |     :param render: Whether or not to render the environment during evaluation
30 | :param verbose:
31 | :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
32 | wrapped with a Monitor wrapper)
33 | """
34 |
35 | def __init__(
36 | self,
37 | eval_env: Union[gym.Env, VecEnv],
38 | callback_on_new_best: Optional[BaseCallback] = None,
39 | n_eval_episodes: int = 5,
40 | eval_freq: int = 10000,
41 | log_path: str = None,
42 | minimum_reward: int = 200,
43 | best_model_save_path: str = None,
44 | deterministic: bool = True,
45 | render: bool = False,
46 | verbose: int = 1,
47 | warn: bool = True,
48 | ):
49 | super(EpCheckPointCallback, self).__init__(eval_env=eval_env,
50 | callback_on_new_best=callback_on_new_best,
51 | n_eval_episodes=n_eval_episodes,
52 | eval_freq=eval_freq,
53 | log_path=log_path,
54 | best_model_save_path=best_model_save_path,
55 | deterministic=deterministic,
56 | render=render,
57 | warn=warn,
58 | verbose=verbose)
59 | self.minimum_reward = minimum_reward
60 |
61 | def _on_step(self) -> bool:
62 |
63 | if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
64 | # Sync training and eval env if there is VecNormalize
65 | sync_envs_normalization(self.training_env, self.eval_env)
66 |
67 | # Reset success rate buffer
68 | self._is_success_buffer = []
69 |
70 | episode_rewards, episode_lengths = evaluate_policy(
71 | self.model,
72 | self.eval_env,
73 | n_eval_episodes=self.n_eval_episodes,
74 | render=self.render,
75 | deterministic=self.deterministic,
76 | return_episode_rewards=True,
77 | warn=self.warn,
78 | callback=self._log_success_callback,
79 | )
80 |
81 | if self.log_path is not None:
82 | self.evaluations_timesteps.append(self.num_timesteps)
83 | self.evaluations_results.append(episode_rewards)
84 | self.evaluations_length.append(episode_lengths)
85 |
86 | kwargs = {}
87 | # Save success log if present
88 | if len(self._is_success_buffer) > 0:
89 | self.evaluations_successes.append(self._is_success_buffer)
90 | kwargs = dict(successes=self.evaluations_successes)
91 |
92 | np.savez(
93 | self.log_path,
94 | timesteps=self.evaluations_timesteps,
95 | results=self.evaluations_results,
96 | ep_lengths=self.evaluations_length,
97 | **kwargs,
98 | )
99 |
100 | mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
101 | mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
102 | self.last_mean_reward = mean_reward
103 |
104 | if self.verbose > 0:
105 | print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
106 | print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
107 | # Add to current Logger
108 | self.logger.record("eval/mean_reward", float(mean_reward))
109 | self.logger.record("eval/mean_ep_length", mean_ep_length)
110 |
111 | if len(self._is_success_buffer) > 0:
112 | success_rate = np.mean(self._is_success_buffer)
113 | if self.verbose > 0:
114 | print(f"Success rate: {100 * success_rate:.2f}%")
115 | self.logger.record("eval/success_rate", success_rate)
116 |
117 | if mean_reward > self.best_mean_reward:
118 | if self.verbose > 0:
119 | print("New best mean reward!")
120 | if self.best_model_save_path is not None:
121 | self.model.save(os.path.join(self.best_model_save_path, "best_model"))
122 | self.best_mean_reward = mean_reward
123 | # Trigger callback if needed
124 | if self.callback is not None:
125 | return self._on_event()
126 | if mean_reward > self.minimum_reward:
127 | if self.best_model_save_path is not None:
128 | self.model.save(os.path.join(self.best_model_save_path, f"callback_{self.n_calls}"))
129 |
130 | return True
131 |
--------------------------------------------------------------------------------
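A brief sketch of attaching ``EpCheckPointCallback`` to training so that, besides the usual best-model checkpoint, a snapshot is also saved whenever the evaluation return clears ``minimum_reward`` (environment, paths, and thresholds below are illustrative):

    import gym
    from icct.rl_helpers.sac import SAC
    from icct.rl_helpers.save_after_ep_callback import EpCheckPointCallback

    eval_env = gym.make('LunarLanderContinuous-v2')
    callback = EpCheckPointCallback(eval_env=eval_env,
                                    n_eval_episodes=5,
                                    eval_freq=10_000,
                                    minimum_reward=200,
                                    best_model_save_path='./checkpoints')
    model = SAC('MlpPolicy', gym.make('LunarLanderContinuous-v2'), verbose=1)
    model.learn(total_timesteps=100_000, callback=callback)
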
/icct/rl_helpers/td3.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 | # Revised from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/td3/td3.py
3 |
4 | from typing import Any, Dict, List, Optional, Tuple, Type, Union
5 |
6 | import gym
7 | import numpy as np
8 | import torch as th
9 | from torch.nn import functional as F
10 |
11 | from stable_baselines3.common.buffers import ReplayBuffer
12 | from stable_baselines3.common.noise import ActionNoise
13 | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
14 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
15 | from stable_baselines3.common.utils import polyak_update
16 | from icct.rl_helpers.td3_policies import TD3Policy
17 |
18 |
19 | class TD3(OffPolicyAlgorithm):
20 | """
21 | Twin Delayed DDPG (TD3)
22 | Addressing Function Approximation Error in Actor-Critic Methods.
23 |
24 | Original implementation: https://github.com/sfujim/TD3
25 | Paper: https://arxiv.org/abs/1802.09477
26 | Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
27 |
28 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
29 | :param env: The environment to learn from (if registered in Gym, can be str)
30 | :param learning_rate: learning rate for adam optimizer,
31 | the same learning rate will be used for all networks (Q-Values, Actor and Value function)
32 | it can be a function of the current progress remaining (from 1 to 0)
33 | :param buffer_size: size of the replay buffer
34 | :param learning_starts: how many steps of the model to collect transitions for before learning starts
35 | :param batch_size: Minibatch size for each gradient update
36 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
37 | :param gamma: the discount factor
38 | :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
39 | like ``(5, "step")`` or ``(2, "episode")``.
40 |     :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``).
41 |         Setting it to ``-1`` means doing as many gradient steps as steps taken in the environment
42 |         during the rollout.
43 |     :param action_noise: the action noise type (None by default), this can help
44 |         with hard exploration problems. See common.noise for the different action noise types.
45 | :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
46 | If ``None``, it will be automatically selected.
47 | :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
48 | :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
49 | at a cost of more complexity.
50 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
51 |     :param policy_delay: Policy and target networks will only be updated once every policy_delay
52 |         training steps. The Q-values will be updated policy_delay times more often (i.e., once every training step).
53 | :param target_policy_noise: Standard deviation of Gaussian noise added to target policy
54 | (smoothing noise)
55 | :param target_noise_clip: Limit for absolute value of target policy smoothing noise.
56 | :param create_eval_env: Whether to create a second environment that will be
57 | used for evaluating the agent periodically. (Only available when passing string for the environment)
58 | :param policy_kwargs: additional arguments to be passed to the policy on creation
59 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
60 | :param seed: Seed for the pseudo random generators
61 | :param device: Device (cpu, cuda, ...) on which the code should be run.
62 |         If set to auto, the code will run on the GPU if possible.
63 | :param _init_setup_model: Whether or not to build the network at the creation of the instance
64 | """
65 |
66 | def __init__(
67 | self,
68 | policy: Union[str, Type[TD3Policy]],
69 | env: Union[GymEnv, str],
70 | learning_rate: Union[float, Schedule] = 1e-3,
71 | buffer_size: int = 1000000, # 1e6
72 | learning_starts: int = 100,
73 | batch_size: int = 100,
74 | tau: float = 0.005,
75 | gamma: float = 0.99,
76 | train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
77 | gradient_steps: int = -1,
78 | action_noise: Optional[ActionNoise] = None,
79 | replay_buffer_class: Optional[ReplayBuffer] = None,
80 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
81 | optimize_memory_usage: bool = False,
82 | policy_delay: int = 2,
83 | target_policy_noise: float = 0.2,
84 | target_noise_clip: float = 0.5,
85 | tensorboard_log: Optional[str] = None,
86 | create_eval_env: bool = False,
87 |         policy_kwargs: Optional[Dict[str, Any]] = None,
88 | verbose: int = 0,
89 | seed: Optional[int] = None,
90 | device: Union[th.device, str] = "auto",
91 | _init_setup_model: bool = True,
92 | ):
93 |
94 | super(TD3, self).__init__(
95 | policy,
96 | env,
97 | TD3Policy,
98 | learning_rate,
99 | buffer_size,
100 | learning_starts,
101 | batch_size,
102 | tau,
103 | gamma,
104 | train_freq,
105 | gradient_steps,
106 | action_noise=action_noise,
107 | replay_buffer_class=replay_buffer_class,
108 | replay_buffer_kwargs=replay_buffer_kwargs,
109 | policy_kwargs=policy_kwargs,
110 | tensorboard_log=tensorboard_log,
111 | verbose=verbose,
112 | device=device,
113 | create_eval_env=create_eval_env,
114 | seed=seed,
115 | sde_support=False,
116 | optimize_memory_usage=optimize_memory_usage,
117 |             supported_action_spaces=(gym.spaces.Box,),
118 | )
119 |
120 | self.policy_delay = policy_delay
121 | self.target_noise_clip = target_noise_clip
122 | self.target_policy_noise = target_policy_noise
123 |
124 | if _init_setup_model:
125 | self._setup_model()
126 |
127 | def _setup_model(self) -> None:
128 | super(TD3, self)._setup_model()
129 | self._create_aliases()
130 |
131 | def _create_aliases(self) -> None:
132 | self.actor = self.policy.actor
133 | self.actor_target = self.policy.actor_target
134 | self.critic = self.policy.critic
135 | self.critic_target = self.policy.critic_target
136 |
137 | def train(self, gradient_steps: int, batch_size: int = 100) -> None:
138 |
139 | # # Update learning rate according to lr schedule
140 | # self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
141 |
142 | actor_losses, critic_losses = [], []
143 | l1_reg_losses = []
144 |
145 | for _ in range(gradient_steps):
146 |
147 | self._n_updates += 1
148 | # Sample replay buffer
149 | replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
150 |
151 | with th.no_grad():
152 | # Select action according to policy and add clipped noise
153 | noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
154 | noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
155 | next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
156 |
157 | # Compute the next Q-values: min over all critics targets
158 | next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
159 | next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
160 | target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
161 |
162 | # Get current Q-values estimates for each critic network
163 | current_q_values = self.critic(replay_data.observations, replay_data.actions)
164 |
165 | # Compute critic loss
166 | critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
167 | critic_losses.append(critic_loss.item())
168 |
169 | # Optimize the critics
170 | self.critic.optimizer.zero_grad()
171 | critic_loss.backward()
172 | self.critic.optimizer.step()
173 |
174 | # Delayed policy updates
175 | if self._n_updates % self.policy_delay == 0:
176 | # Compute actor loss
177 | actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
178 |
179 | if hasattr(self.actor, 'ddt'):
180 | if self.actor.ddt.use_submodels and self.actor.ddt.sparse_submodel_type == 1:
181 | attn = self.actor.ddt.leaf_attn.repeat_interleave(2)
182 | l1_reg_loss = 0
183 | if self.actor.ddt_kwargs['l1_reg_bias']:
184 | for i, (name, p) in enumerate(self.actor.ddt.lin_models.named_parameters()):
185 | l1_reg_loss += th.sum(abs(p)) * attn[i]
186 | else:
187 | for i, (name, p) in enumerate(self.actor.ddt.lin_models.named_parameters()):
188 |                             if 'bias' not in name:
189 | l1_reg_loss += th.sum(abs(p)) * attn[i]
190 | l1_reg_loss *= self.actor.ddt_kwargs['l1_reg_coeff'] * self.actor.ddt.leaf_attn.size(0)
191 | l1_reg_losses.append(l1_reg_loss.item())
192 | actor_loss += l1_reg_loss
193 |
194 | actor_losses.append(actor_loss.item())
195 |
196 | # Optimize the actor
197 | self.actor.optimizer.zero_grad()
198 | actor_loss.backward()
199 | self.actor.optimizer.step()
200 |
201 | polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
202 | polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
203 |
204 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
205 | if len(actor_losses) > 0:
206 | self.logger.record("train/actor_loss", np.mean(actor_losses))
207 | self.logger.record("train/critic_loss", np.mean(critic_losses))
208 | if len(l1_reg_losses) > 0:
209 | self.logger.record("train/l1_reg_loss", np.mean(l1_reg_losses))
210 |
211 | def learn(
212 | self,
213 | total_timesteps: int,
214 | callback: MaybeCallback = None,
215 | log_interval: int = 4,
216 | eval_env: Optional[GymEnv] = None,
217 | eval_freq: int = -1,
218 | n_eval_episodes: int = 5,
219 | tb_log_name: str = "TD3",
220 | eval_log_path: Optional[str] = None,
221 | reset_num_timesteps: bool = True,
222 | ) -> OffPolicyAlgorithm:
223 |
224 | return super(TD3, self).learn(
225 | total_timesteps=total_timesteps,
226 | callback=callback,
227 | log_interval=log_interval,
228 | eval_env=eval_env,
229 | eval_freq=eval_freq,
230 | n_eval_episodes=n_eval_episodes,
231 | tb_log_name=tb_log_name,
232 | eval_log_path=eval_log_path,
233 | reset_num_timesteps=reset_num_timesteps,
234 | )
235 |
236 | def _excluded_save_params(self) -> List[str]:
237 | return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
238 |
239 | def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
240 | state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
241 | return state_dicts, []
--------------------------------------------------------------------------------
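Note: the following is a minimal usage sketch (not a file in the repository) showing how the TD3 variant above is constructed with the DDT actor, mirroring the call in `icct/runfiles/train.py` further below. The `ddt_kwargs` values are illustrative placeholders modeled on the training script's command-line options.

```python
import gym
from stable_baselines3.common.torch_layers import FlattenExtractor

from icct.rl_helpers import ddt_td3_policy  # importing this module registers 'DDT_TD3Policy' (as in train.py)
from icct.rl_helpers.td3 import TD3

env = gym.make('LunarLanderContinuous-v2')

# Illustrative DDT configuration; train.py builds the same dict from its command-line arguments.
ddt_kwargs = {
    'num_leaves': 8,
    'submodels': True,
    'hard_node': True,
    'device': 'cpu',
    'argmax_tau': 1.0,
    'ddt_lr': 5e-4,
    'use_individual_alpha': True,
    'sparse_submodel_type': 2,
    'fs_submodel_version': 0,
    'l1_reg_coeff': 5e-3,
    'l1_reg_bias': False,
    'l1_hard_attn': False,
    'num_sub_features': 2,
    'use_gumbel_softmax': False,
    'alg_type': 'td3',
}
policy_kwargs = {
    'features_extractor_class': FlattenExtractor,
    'ddt_kwargs': ddt_kwargs,
    'net_arch': {'pi': [16, 16], 'qf': [400, 300]},  # [400, 300] is the SB3 default critic size for TD3
}

model = TD3('DDT_TD3Policy', env,
            learning_rate=5e-4,
            buffer_size=1_000_000,
            batch_size=256,
            gamma=0.99,
            tau=0.01,
            learning_starts=10_000,
            target_policy_noise=0.1,
            policy_kwargs=policy_kwargs,
            verbose=1,
            device='cpu',
            seed=0)
model.learn(total_timesteps=500_000)
```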
/icct/rl_helpers/td3_policies.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 | # Revised from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/td3/policies.py
3 |
4 | from typing import Any, Dict, List, Optional, Type, Union
5 |
6 | import gym
7 | import torch as th
8 | from torch import nn
9 |
10 | from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
11 | from stable_baselines3.common.preprocessing import get_action_dim
12 | from stable_baselines3.common.torch_layers import (
13 | BaseFeaturesExtractor,
14 | CombinedExtractor,
15 | FlattenExtractor,
16 | NatureCNN,
17 | create_mlp,
18 | get_actor_critic_arch,
19 | )
20 | from stable_baselines3.common.type_aliases import Schedule
21 |
22 |
23 | class Actor(BasePolicy):
24 | """
25 | Actor network (policy) for TD3.
26 |
27 |     :param observation_space: Observation space
28 | :param action_space: Action space
29 | :param net_arch: Network architecture
30 | :param features_extractor: Network to extract features
31 | (a CNN when using images, a nn.Flatten() layer otherwise)
32 | :param features_dim: Number of features
33 | :param activation_fn: Activation function
34 | :param normalize_images: Whether to normalize images or not,
35 | dividing by 255.0 (True by default)
36 | """
37 |
38 | def __init__(
39 | self,
40 | observation_space: gym.spaces.Space,
41 | action_space: gym.spaces.Space,
42 | net_arch: List[int],
43 | features_extractor: nn.Module,
44 | features_dim: int,
45 | activation_fn: Type[nn.Module] = nn.ReLU,
46 | normalize_images: bool = True,
47 | ):
48 | super(Actor, self).__init__(
49 | observation_space,
50 | action_space,
51 | features_extractor=features_extractor,
52 | normalize_images=normalize_images,
53 | squash_output=True,
54 | )
55 |
56 | self.net_arch = net_arch
57 | self.features_dim = features_dim
58 | self.activation_fn = activation_fn
59 |
60 | action_dim = get_action_dim(self.action_space)
61 | actor_net = create_mlp(features_dim, action_dim, net_arch, activation_fn, squash_output=True)
62 | # Deterministic action
63 | self.mu = nn.Sequential(*actor_net)
64 |
65 | def _get_constructor_parameters(self) -> Dict[str, Any]:
66 | data = super()._get_constructor_parameters()
67 |
68 | data.update(
69 | dict(
70 | net_arch=self.net_arch,
71 | features_dim=self.features_dim,
72 | activation_fn=self.activation_fn,
73 | features_extractor=self.features_extractor,
74 | )
75 | )
76 | return data
77 |
78 | def forward(self, obs: th.Tensor) -> th.Tensor:
79 | # assert deterministic, 'The TD3 actor only outputs deterministic actions'
80 | features = self.extract_features(obs)
81 | return self.mu(features)
82 |
83 | def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
84 |         # Note: the deterministic parameter is ignored in the case of TD3.
85 | # Predictions are always deterministic.
86 | return self.forward(observation)
87 |
88 |
89 | class TD3Policy(BasePolicy):
90 | """
91 | Policy class (with both actor and critic) for TD3.
92 |
93 | :param observation_space: Observation space
94 | :param action_space: Action space
95 | :param lr_schedule: Learning rate schedule (could be constant)
96 | :param net_arch: The specification of the policy and value networks.
97 | :param activation_fn: Activation function
98 | :param features_extractor_class: Features extractor to use.
99 | :param features_extractor_kwargs: Keyword arguments
100 | to pass to the features extractor.
101 | :param normalize_images: Whether to normalize images or not,
102 | dividing by 255.0 (True by default)
103 | :param optimizer_class: The optimizer to use,
104 | ``th.optim.Adam`` by default
105 | :param optimizer_kwargs: Additional keyword arguments,
106 | excluding the learning rate, to pass to the optimizer
107 | :param n_critics: Number of critic networks to create.
108 | :param share_features_extractor: Whether to share or not the features extractor
109 | between the actor and the critic (this saves computation time)
110 | """
111 |
112 | def __init__(
113 | self,
114 | observation_space: gym.spaces.Space,
115 | action_space: gym.spaces.Space,
116 | lr_schedule: Schedule,
117 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
118 | activation_fn: Type[nn.Module] = nn.ReLU,
119 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
120 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
121 | normalize_images: bool = True,
122 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
123 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
124 | n_critics: int = 2,
125 | share_features_extractor: bool = True,
126 | ):
127 | super(TD3Policy, self).__init__(
128 | observation_space,
129 | action_space,
130 | features_extractor_class,
131 | features_extractor_kwargs,
132 | optimizer_class=optimizer_class,
133 | optimizer_kwargs=optimizer_kwargs,
134 | squash_output=True,
135 | )
136 |
137 | # Default network architecture, from the original paper
138 | if net_arch is None:
139 | if features_extractor_class == NatureCNN:
140 | net_arch = []
141 | else:
142 | net_arch = [400, 300]
143 |
144 | actor_arch, critic_arch = get_actor_critic_arch(net_arch)
145 |
146 | self.net_arch = net_arch
147 | self.activation_fn = activation_fn
148 | self.net_args = {
149 | "observation_space": self.observation_space,
150 | "action_space": self.action_space,
151 | "net_arch": actor_arch,
152 | "activation_fn": self.activation_fn,
153 | "normalize_images": normalize_images,
154 | }
155 | self.actor_kwargs = self.net_args.copy()
156 | self.critic_kwargs = self.net_args.copy()
157 | self.critic_kwargs.update(
158 | {
159 | "n_critics": n_critics,
160 | "net_arch": critic_arch,
161 | "share_features_extractor": share_features_extractor,
162 | }
163 | )
164 |
165 | self.actor, self.actor_target = None, None
166 | self.critic, self.critic_target = None, None
167 | self.share_features_extractor = share_features_extractor
168 |
169 | self._build(lr_schedule)
170 |
171 | def _build(self, lr_schedule: Schedule) -> None:
172 | # Create actor and target
173 | # the features extractor should not be shared
174 | self.actor = self.make_actor(features_extractor=None)
175 | self.actor_target = self.make_actor(features_extractor=None)
176 | # Initialize the target to have the same weights as the actor
177 | self.actor_target.load_state_dict(self.actor.state_dict())
178 |
179 | if hasattr(self.actor, 'ddt'):
180 | self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=self.actor.ddt_kwargs['ddt_lr'], **self.optimizer_kwargs)
181 | else:
182 | self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
183 |
184 | if self.share_features_extractor:
185 | self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
186 |             # Critic target should not share the features extractor with critic
187 | # but it can share it with the actor target as actor and critic are sharing
188 | # the same features_extractor too
189 |             # NOTE: as a result the effective polyak (soft-copy) coefficient for the features extractor
190 |             # will be 2 * tau instead of tau (updated once with the actor and once with the critic)
191 | self.critic_target = self.make_critic(features_extractor=self.actor_target.features_extractor)
192 | else:
193 | # Create new features extractor for each network
194 | self.critic = self.make_critic(features_extractor=None)
195 | self.critic_target = self.make_critic(features_extractor=None)
196 |
197 | self.critic_target.load_state_dict(self.critic.state_dict())
198 | self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
199 |
200 | def _get_constructor_parameters(self) -> Dict[str, Any]:
201 | data = super()._get_constructor_parameters()
202 |
203 | data.update(
204 | dict(
205 | net_arch=self.net_arch,
206 | activation_fn=self.net_args["activation_fn"],
207 | n_critics=self.critic_kwargs["n_critics"],
208 | lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
209 | optimizer_class=self.optimizer_class,
210 | optimizer_kwargs=self.optimizer_kwargs,
211 | features_extractor_class=self.features_extractor_class,
212 | features_extractor_kwargs=self.features_extractor_kwargs,
213 | share_features_extractor=self.share_features_extractor,
214 | )
215 | )
216 | return data
217 |
218 | def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
219 | actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
220 | return Actor(**actor_kwargs).to(self.device)
221 |
222 | def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
223 | critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
224 | return ContinuousCritic(**critic_kwargs).to(self.device)
225 |
226 | def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
227 | return self._predict(observation, deterministic=deterministic)
228 |
229 | def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
230 | # Note: the deterministic parameter is ignored in the case of TD3.
231 | # Predictions are always deterministic.
232 | return self.actor(observation)
233 |
234 |
235 | MlpPolicy = TD3Policy
236 |
237 |
238 | class CnnPolicy(TD3Policy):
239 | """
240 | Policy class (with both actor and critic) for TD3.
241 |
242 | :param observation_space: Observation space
243 | :param action_space: Action space
244 | :param lr_schedule: Learning rate schedule (could be constant)
245 | :param net_arch: The specification of the policy and value networks.
246 | :param activation_fn: Activation function
247 | :param features_extractor_class: Features extractor to use.
248 | :param features_extractor_kwargs: Keyword arguments
249 | to pass to the features extractor.
250 | :param normalize_images: Whether to normalize images or not,
251 | dividing by 255.0 (True by default)
252 | :param optimizer_class: The optimizer to use,
253 | ``th.optim.Adam`` by default
254 | :param optimizer_kwargs: Additional keyword arguments,
255 | excluding the learning rate, to pass to the optimizer
256 | :param n_critics: Number of critic networks to create.
257 | :param share_features_extractor: Whether to share or not the features extractor
258 | between the actor and the critic (this saves computation time)
259 | """
260 |
261 | def __init__(
262 | self,
263 | observation_space: gym.spaces.Space,
264 | action_space: gym.spaces.Space,
265 | lr_schedule: Schedule,
266 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
267 | activation_fn: Type[nn.Module] = nn.ReLU,
268 | features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
269 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
270 | normalize_images: bool = True,
271 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
272 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
273 | n_critics: int = 2,
274 | share_features_extractor: bool = True,
275 | ):
276 | super(CnnPolicy, self).__init__(
277 | observation_space,
278 | action_space,
279 | lr_schedule,
280 | net_arch,
281 | activation_fn,
282 | features_extractor_class,
283 | features_extractor_kwargs,
284 | normalize_images,
285 | optimizer_class,
286 | optimizer_kwargs,
287 | n_critics,
288 | share_features_extractor,
289 | )
290 |
291 |
292 | class MultiInputPolicy(TD3Policy):
293 | """
294 | Policy class (with both actor and critic) for TD3 to be used with Dict observation spaces.
295 |
296 | :param observation_space: Observation space
297 | :param action_space: Action space
298 | :param lr_schedule: Learning rate schedule (could be constant)
299 | :param net_arch: The specification of the policy and value networks.
300 | :param activation_fn: Activation function
301 | :param features_extractor_class: Features extractor to use.
302 | :param features_extractor_kwargs: Keyword arguments
303 | to pass to the features extractor.
304 | :param normalize_images: Whether to normalize images or not,
305 | dividing by 255.0 (True by default)
306 | :param optimizer_class: The optimizer to use,
307 | ``th.optim.Adam`` by default
308 | :param optimizer_kwargs: Additional keyword arguments,
309 | excluding the learning rate, to pass to the optimizer
310 | :param n_critics: Number of critic networks to create.
311 | :param share_features_extractor: Whether to share or not the features extractor
312 | between the actor and the critic (this saves computation time)
313 | """
314 |
315 | def __init__(
316 | self,
317 | observation_space: gym.spaces.Dict,
318 | action_space: gym.spaces.Space,
319 | lr_schedule: Schedule,
320 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
321 | activation_fn: Type[nn.Module] = nn.ReLU,
322 | features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
323 | features_extractor_kwargs: Optional[Dict[str, Any]] = None,
324 | normalize_images: bool = True,
325 | optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
326 | optimizer_kwargs: Optional[Dict[str, Any]] = None,
327 | n_critics: int = 2,
328 | share_features_extractor: bool = True,
329 | ):
330 | super(MultiInputPolicy, self).__init__(
331 | observation_space,
332 | action_space,
333 | lr_schedule,
334 | net_arch,
335 | activation_fn,
336 | features_extractor_class,
337 | features_extractor_kwargs,
338 | normalize_images,
339 | optimizer_class,
340 | optimizer_kwargs,
341 | n_critics,
342 | share_features_extractor,
343 | )
344 |
345 |
346 | register_policy("MlpPolicy", MlpPolicy)
347 | register_policy("CnnPolicy", CnnPolicy)
348 | register_policy("MultiInputPolicy", MultiInputPolicy)
349 |
--------------------------------------------------------------------------------
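Note: a short standalone sketch (not part of the repository) of how `TD3Policy` consumes a `net_arch` dict: `get_actor_critic_arch` splits the `'pi'` and `'qf'` entries into actor and critic architectures, and `_build` then creates the actor/critic pairs using the learning-rate schedule. The spaces and sizes below are illustrative.

```python
import gym
import numpy as np

from icct.rl_helpers.td3_policies import TD3Policy

# Illustrative continuous observation/action spaces.
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

# SB3 schedules map remaining progress (1 -> 0) to a learning rate; here it is constant.
lr_schedule = lambda progress_remaining: 1e-3

policy = TD3Policy(
    obs_space,
    act_space,
    lr_schedule,
    net_arch={'pi': [400, 300], 'qf': [400, 300]},  # split by get_actor_critic_arch into actor_arch / critic_arch
)
print(policy.actor)   # plain MLP Actor built by make_actor
print(policy.critic)  # ContinuousCritic with n_critics=2 (the twin critics of TD3)
```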
/icct/runfiles/test.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | import gym
4 | import numpy as np
5 | import copy
6 | import argparse
7 | import random
8 | import os
9 | import torch
10 | from icct.rl_helpers import ddt_sac_policy
11 | from icct.core.icct_helpers import convert_to_crisp
12 | from icct.rl_helpers.save_after_ep_callback import EpCheckPointCallback
13 | from stable_baselines3.common.torch_layers import (
14 | BaseFeaturesExtractor,
15 | CombinedExtractor,
16 | FlattenExtractor
17 | )
18 |
19 | from stable_baselines3 import SAC
20 | import highway_env
21 | from flow.utils.registry import make_create_env
22 | from icct.sumo_envs.accel_ring import ring_accel_params
23 | from icct.sumo_envs.accel_figure8 import fig8_params
24 | from icct.sumo_envs.accel_ring_multilane import ring_accel_lc_params
25 | from stable_baselines3.common.utils import set_random_seed
26 |
27 | def make_env(env_name, seed):
28 | set_random_seed(seed)
29 | if env_name == 'lunar':
30 | env = gym.make('LunarLanderContinuous-v2')
31 | name = 'LunarLanderContinuous-v2'
32 | elif env_name == 'cart':
33 | env = gym.make('InvertedPendulum-v2')
34 | name = 'InvertedPendulum-v2'
35 | elif env_name == 'lane_keeping':
36 | env = gym.make('lane-keeping-v0')
37 | name = 'lane-keeping-v0'
38 | elif env_name == 'ring_accel':
39 | create_env, gym_name = make_create_env(params=ring_accel_params, version=0)
40 | env = create_env()
41 | name = gym_name
42 | elif env_name == 'ring_lane_changing':
43 | create_env, gym_name = make_create_env(params=ring_accel_lc_params, version=0)
44 | env = create_env()
45 | name = gym_name
46 | elif env_name == 'figure8':
47 | create_env, gym_name = make_create_env(params=fig8_params, version=0)
48 | env = create_env()
49 | name = gym_name
50 | else:
51 | raise Exception('No valid environment selected')
52 | env.seed(seed)
53 | return env, name
54 |
55 |
56 | if __name__ == "__main__":
57 | parser = argparse.ArgumentParser(description='ICCT Testing')
58 | parser.add_argument('--env_name', help='environment to run on', type=str, default='lunar')
59 | parser.add_argument('--seed', help='random seed', type=int, default=42)
60 |     parser.add_argument('--load_path', help='the path of the saved model', type=str, default='test')
61 | parser.add_argument('--num_episodes', help='number of episodes to test', type=int, default=20)
62 |     parser.add_argument('--render', help='whether to render the tested environment', action='store_true')
63 |     parser.add_argument('--gpu', help='whether to run on a GPU', action='store_true')
64 | parser.add_argument('--load_file', help='which model file to load and test', type=str, default='best_model')
65 |
66 |
67 | args = parser.parse_args()
68 | env, env_n = make_env(args.env_name, args.seed)
69 |
70 | if args.gpu:
71 | args.device = 'cuda'
72 | else:
73 | args.device = 'cpu'
74 |
75 | model = SAC.load("../../" + args.load_path + "/" + args.load_file, device=args.device)
76 | obs = env.reset()
77 | episode_reward_for_reg = []
78 | for _ in range(args.num_episodes):
79 | done = False
80 | episode_reward = 0
81 | while not done:
82 | action, _states = model.predict(obs, deterministic=True)
83 | obs, reward, done, info = env.step(action)
84 |             episode_reward += reward
85 | if args.render:
86 | env.render()
87 | if done:
88 | obs = env.reset()
89 | episode_reward_for_reg.append(episode_reward)
90 | break
91 | print('fuzzy results:')
92 | print(episode_reward_for_reg)
93 | print(np.mean(episode_reward_for_reg))
94 | print(np.std(episode_reward_for_reg))
95 |
96 | env, env_n = make_env(args.env_name, args.seed)
97 | if hasattr(model.actor, 'ddt'):
98 | model.actor.ddt = convert_to_crisp(model.actor.ddt, training_data=None)
99 | obs = env.reset()
100 | discrete_episode_reward_for_reg = []
101 | for _ in range(args.num_episodes):
102 | done = False
103 | episode_reward = 0
104 | while not done:
105 | action, _states = model.predict(obs, deterministic=True)
106 | obs, reward, done, info = env.step(action)
107 | episode_reward += reward
108 | if args.render:
109 | env.render()
110 | if done:
111 | obs = env.reset()
112 | discrete_episode_reward_for_reg.append(episode_reward)
113 | break
114 | print('crisp results:')
115 | print(discrete_episode_reward_for_reg)
116 | print(np.mean(discrete_episode_reward_for_reg))
117 | print(np.std(discrete_episode_reward_for_reg))
--------------------------------------------------------------------------------
/icct/runfiles/test.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u test.py \
6 | --env_name lunar \
7 | --seed 42 \
8 | --load_path log/ll \
9 | --num_episodes 5 \
10 | --load_file best_model \
11 | --gpu \
12 | | tee test.log
13 |
14 |
--------------------------------------------------------------------------------
/icct/runfiles/train.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu
2 |
3 | import gym
4 | import numpy as np
5 | import copy
6 | import argparse
7 | import random
8 | import os
9 | import torch
10 | from icct.rl_helpers import ddt_sac_policy
11 | from icct.rl_helpers import ddt_td3_policy
12 | from icct.core.icct_helpers import convert_to_crisp
13 | from icct.rl_helpers.save_after_ep_callback import EpCheckPointCallback
14 | from stable_baselines3.common.torch_layers import (
15 | BaseFeaturesExtractor,
16 | CombinedExtractor,
17 | FlattenExtractor
18 | )
19 |
20 | from icct.rl_helpers.sac import SAC
21 | from icct.rl_helpers.td3 import TD3
22 | import highway_env
23 | from flow.utils.registry import make_create_env
24 | from icct.sumo_envs.accel_ring import ring_accel_params
25 | from icct.sumo_envs.accel_ring_multilane import ring_accel_lc_params
26 | from icct.sumo_envs.accel_figure8 import fig8_params
27 | from stable_baselines3.common.utils import set_random_seed
28 | from stable_baselines3.common.monitor import Monitor
29 |
30 |
31 | def make_env(env_name, seed):
32 | set_random_seed(seed)
33 | if env_name == 'lunar':
34 | env = gym.make('LunarLanderContinuous-v2')
35 | name = 'LunarLanderContinuous-v2'
36 | elif env_name == 'cart':
37 | env = gym.make('InvertedPendulum-v2')
38 | name = 'InvertedPendulum-v2'
39 | elif env_name == 'lane_keeping':
40 | env = gym.make('lane-keeping-v0')
41 | name = 'lane-keeping-v0'
42 | elif env_name == 'ring_accel':
43 | create_env, gym_name = make_create_env(params=ring_accel_params, version=0)
44 | env = create_env()
45 | name = gym_name
46 | elif env_name == 'ring_lane_changing':
47 | create_env, gym_name = make_create_env(params=ring_accel_lc_params, version=0)
48 | env = create_env()
49 | name = gym_name
50 | elif env_name == 'figure8':
51 | create_env, gym_name = make_create_env(params=fig8_params, version=0)
52 | env = create_env()
53 | name = gym_name
54 | else:
55 | raise Exception('No valid environment selected')
56 | env.seed(seed)
57 | return env, name
58 |
59 |
60 | if __name__ == "__main__":
61 | parser = argparse.ArgumentParser(description='ICCT Training')
62 | parser.add_argument('--env_name', help='environment to run on', type=str, default='lunar')
63 | parser.add_argument('--alg_type', help='sac or td3', type=str, default='sac')
64 | parser.add_argument('--policy_type', help='mlp or ddt', type=str, default='ddt')
65 | parser.add_argument('--mlp_size', help='the size of mlp (small|medium|large)', type=str, default='medium')
66 | parser.add_argument('--seed', help='the seed number to use', type=int, default=42)
67 | parser.add_argument('--num_leaves', help='number of leaves used in ddt (2^n)', type=int, default=16)
68 |     parser.add_argument('--submodels', help='whether to use sub-models in the ddt', action='store_true', default=False)
69 |     parser.add_argument('--sparse_submodel_type', help='the type of the sparse submodel: 1 for L1 regularization, 2 for feature selection, any other value for non-sparse', type=int, default=0)
70 |     parser.add_argument('--hard_node', help='whether to use differentiable crispification', action='store_true', default=False)
71 |     parser.add_argument('--gpu', help='whether to run on a GPU', action='store_true', default=False)
72 | parser.add_argument('--lr', help='learning rate', type=float, default=3e-4)
73 | parser.add_argument('--buffer_size', help='buffer size', type=int, default=1000000)
74 | parser.add_argument('--batch_size', help='batch size', type=int, default=256)
75 | parser.add_argument('--gamma', help='the discount factor', type=float, default=0.9999)
76 | parser.add_argument('--tau', help='the soft update coefficient (between 0 and 1)', type=float, default=0.01)
77 | parser.add_argument('--learning_starts', help='how many steps of the model to collect transitions for before learning starts', type=int, default=10000)
78 | parser.add_argument('--training_steps', help='total steps for training the model', type=int, default=500000)
79 | parser.add_argument('--argmax_tau', help='the temperature of the diff_argmax function', type=float, default=1.0)
80 | parser.add_argument('--ddt_lr', help='the learning rate of the ddt', type=float, default=3e-4)
81 |     parser.add_argument('--use_individual_alpha', help='whether to use different alphas for different nodes', action='store_true', default=False)
82 | parser.add_argument('--l1_reg_coeff', help='the coefficient of the l1 regularization when using l1-reg submodels', type=float, default=5e-3)
83 |     parser.add_argument('--l1_reg_bias', help='whether to consider biases in the l1 loss when using l1-reg submodels', action='store_true', default=False)
84 |     parser.add_argument('--l1_hard_attn', help='whether to sample only one linear controller to perform L1 regularization for each update when using l1-reg submodels', action='store_true', default=False)
85 | parser.add_argument('--num_sub_features', help='the number of chosen features for submodels', type=int, default=1)
86 |     parser.add_argument('--use_gumbel_softmax', help='whether to use Gumbel softmax instead of the differentiable argmax proposed in the paper', action='store_true', default=False)
87 | # evaluation and model saving
88 | parser.add_argument('--min_reward', help='minimum reward to save the model', type=int)
89 | parser.add_argument('--save_path', help='the path of saving the model', type=str, default='test')
90 | parser.add_argument('--n_eval_episodes', help='the number of episodes for each evaluation during training', type=int, default=5)
91 |     parser.add_argument('--eval_freq', help='evaluation frequency of the model', type=int, default=1500)
92 | parser.add_argument('--log_interval', help='the number of episodes before logging', type=int, default=4)
93 |
94 |
95 | args = parser.parse_args()
96 | env, env_n = make_env(args.env_name, args.seed)
97 | eval_env = gym.make(env_n)
98 | eval_env.seed(args.seed)
99 | save_folder = args.save_path
100 | log_dir = '../../' + save_folder + '/'
101 | if not os.path.exists(log_dir):
102 | os.makedirs(log_dir)
103 |
104 | if args.policy_type == 'ddt':
105 | if not args.submodels and not args.hard_node:
106 | method = 'm1'
107 | elif args.submodels and not args.hard_node:
108 | method = 'm2'
109 | if args.sparse_submodel_type == 1 or args.sparse_submodel_type == 2:
110 | raise Exception('Not a method we want to test')
111 | elif not args.submodels and args.hard_node:
112 | method = 'm3'
113 | else:
114 | if args.sparse_submodel_type != 1 and args.sparse_submodel_type != 2:
115 | method = 'm4'
116 | elif args.sparse_submodel_type == 1:
117 | method = 'm5a'
118 | else:
119 | method = f'm5b_{args.num_sub_features}'
120 | elif args.policy_type == 'mlp':
121 | if args.mlp_size == 'small':
122 | method = 'mlp_s'
123 | elif args.mlp_size == 'medium':
124 | method = 'mlp_m'
125 | elif args.mlp_size == 'large':
126 | method = 'mlp_l'
127 | else:
128 | raise Exception('Not a valid MLP size')
129 | else:
130 | raise Exception('Not a valid policy type')
131 |
132 | monitor_file_path = log_dir + method + f'_seed{args.seed}'
133 | env = Monitor(env, monitor_file_path)
134 | eval_monitor_file_path = log_dir + 'eval_' + method + f'_seed{args.seed}'
135 | eval_env = Monitor(eval_env, eval_monitor_file_path)
136 | callback = EpCheckPointCallback(eval_env=eval_env, best_model_save_path=log_dir, n_eval_episodes=args.n_eval_episodes,
137 | eval_freq=args.eval_freq, minimum_reward=args.min_reward)
138 |
139 | if args.gpu:
140 | args.device = 'cuda'
141 | else:
142 | args.device = 'cpu'
143 |
144 | if args.env_name == 'lane_keeping':
145 | features_extractor = CombinedExtractor
146 | else:
147 | features_extractor = FlattenExtractor
148 |
149 | if args.env_name == 'cart':
150 | args.fs_submodel_version = 1
151 | else:
152 | args.fs_submodel_version = 0
153 |
154 | if args.alg_type != 'sac' and args.alg_type != 'td3':
155 | raise Exception('Not a valid RL algorithm type')
156 |
157 | if args.policy_type == 'ddt':
158 | ddt_kwargs = {
159 | 'num_leaves': args.num_leaves,
160 | 'submodels': args.submodels,
161 | 'hard_node': args.hard_node,
162 | 'device': args.device,
163 | 'argmax_tau': args.argmax_tau,
164 | 'ddt_lr': args.ddt_lr,
165 | 'use_individual_alpha': args.use_individual_alpha,
166 | 'sparse_submodel_type': args.sparse_submodel_type,
167 | 'fs_submodel_version': args.fs_submodel_version,
168 | 'l1_reg_coeff': args.l1_reg_coeff,
169 | 'l1_reg_bias': args.l1_reg_bias,
170 | 'l1_hard_attn': args.l1_hard_attn,
171 | 'num_sub_features': args.num_sub_features,
172 | 'use_gumbel_softmax': args.use_gumbel_softmax,
173 | 'alg_type': args.alg_type
174 | }
175 | policy_kwargs = {
176 | 'features_extractor_class': features_extractor,
177 | 'ddt_kwargs': ddt_kwargs
178 | }
179 | if args.alg_type == 'sac':
180 | policy_name = 'DDT_SACPolicy'
181 | policy_kwargs['net_arch'] = {'pi': [16, 16], 'qf': [256, 256]} # [256, 256] is a default setting in SB3 for SAC
182 | else:
183 | policy_name = 'DDT_TD3Policy'
184 | policy_kwargs['net_arch'] = {'pi': [16, 16], 'qf': [400, 300]} # [400, 300] is a default setting in SB3 for TD3
185 |
186 | elif args.policy_type == 'mlp':
187 | if args.env_name == 'lane_keeping':
188 | policy_name = 'MultiInputPolicy'
189 | else:
190 | policy_name = 'MlpPolicy'
191 |
192 | if args.mlp_size == 'small':
193 | if args.env_name == 'cart':
194 | pi_size = [6, 6]
195 | elif args.env_name == 'lunar':
196 | pi_size = [6, 6]
197 | elif args.env_name == 'lane_keeping':
198 | pi_size = [6, 6]
199 | elif args.env_name == 'ring_accel':
200 | pi_size = [3, 3]
201 | elif args.env_name == 'ring_lane_changing':
202 | pi_size = [3, 3]
203 | else:
204 | pi_size = [3, 3]
205 | elif args.mlp_size == 'medium':
206 | if args.env_name == 'cart':
207 | pi_size = [8, 8]
208 | elif args.env_name == 'lunar':
209 | pi_size = [10, 10]
210 | elif args.env_name == 'lane_keeping':
211 | pi_size = [14, 14]
212 | elif args.env_name == 'ring_accel':
213 | pi_size = [12, 12]
214 | elif args.env_name == 'ring_lane_changing':
215 | pi_size = [32, 32]
216 | else:
217 | pi_size = [20, 20]
218 | elif args.mlp_size == 'large':
219 | if args.alg_type == 'sac':
220 | pi_size = [256, 256]
221 | else:
222 | pi_size = [400, 300]
223 | else:
224 | raise Exception('Not a valid MLP size')
225 | if args.alg_type == 'sac':
226 | policy_kwargs = {
227 | 'net_arch': {'pi': pi_size, 'qf': [256, 256]},
228 | 'features_extractor_class': features_extractor,
229 | }
230 | else:
231 | policy_kwargs = {
232 | 'net_arch': {'pi': pi_size, 'qf': [400, 300]},
233 | 'features_extractor_class': features_extractor,
234 | }
235 | else:
236 | raise Exception('Not a valid policy type')
237 |
238 | if args.alg_type == 'sac':
239 | model = SAC(policy_name, env,
240 | learning_rate=args.lr,
241 | buffer_size=args.buffer_size,
242 | batch_size=args.batch_size,
243 | ent_coef='auto',
244 | train_freq=1,
245 | gradient_steps=1,
246 | gamma=args.gamma,
247 | tau=args.tau,
248 | learning_starts=args.learning_starts,
249 | policy_kwargs=policy_kwargs,
250 | tensorboard_log=log_dir,
251 | verbose=1,
252 | device=args.device,
253 | seed=args.seed)
254 | else:
255 | model = TD3(policy_name, env,
256 | learning_rate=args.lr,
257 | buffer_size=args.buffer_size,
258 | batch_size=args.batch_size,
259 | gamma=args.gamma,
260 | tau=args.tau,
261 | learning_starts=args.learning_starts,
262 | target_policy_noise=0.1,
263 | policy_kwargs=policy_kwargs,
264 | tensorboard_log=log_dir,
265 | verbose=1,
266 | device=args.device,
267 | seed=args.seed)
268 |
269 | model.learn(total_timesteps=args.training_steps, log_interval=args.log_interval, callback=callback)
270 |
271 |
--------------------------------------------------------------------------------
/icct/runfiles/train_fig8.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u train.py \
6 | --env_name figure8 \
7 | --policy_type ddt \
8 | --seed 0 \
9 | --num_leaves 16 \
10 | --lr 6e-4 \
11 | --ddt_lr 6e-4 \
12 | --buffer_size 1000000 \
13 | --batch_size 1024 \
14 | --gamma 0.99 \
15 | --tau 0.01 \
16 | --learning_starts 5000 \
17 | --eval_freq 2500 \
18 | --min_reward 900 \
19 | --training_steps 500000 \
20 | --log_interval 4 \
21 | --save_path log/fig8 \
22 | --use_individual_alpha \
23 | --submodels \
24 | --hard_node \
25 | --argmax_tau 1.0 \
26 | --sparse_submodel_type 2 \
27 | --num_sub_features 2 \
28 | | tee train_fig8.log
29 |
--------------------------------------------------------------------------------
/icct/runfiles/train_ip.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u train.py \
6 | --env_name cart \
7 | --policy_type ddt \
8 | --seed 0 \
9 | --num_leaves 4 \
10 | --lr 5e-4 \
11 | --ddt_lr 5e-4 \
12 | --buffer_size 1000000 \
13 | --batch_size 1024 \
14 | --gamma 0.99 \
15 | --tau 0.01 \
16 | --learning_starts 10000 \
17 | --eval_freq 1500 \
18 | --min_reward 900 \
19 | --training_steps 500000 \
20 | --log_interval 4 \
21 | --save_path log/ip \
22 | --use_individual_alpha \
23 | --submodels \
24 | --hard_node \
25 | --argmax_tau 1.0 \
26 | --sparse_submodel_type 2 \
27 | --num_sub_features 2 \
28 | | tee train_ip.log
--------------------------------------------------------------------------------
/icct/runfiles/train_lk.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u train.py \
6 | --env_name lane_keeping \
7 | --policy_type ddt \
8 | --seed 0 \
9 | --num_leaves 16 \
10 | --lr 3e-4 \
11 | --ddt_lr 3e-4 \
12 | --buffer_size 1000000 \
13 | --batch_size 1024 \
14 | --gamma 0.99 \
15 | --tau 0.01 \
16 | --learning_starts 10000 \
17 | --eval_freq 1500 \
18 | --min_reward 420 \
19 | --training_steps 500000 \
20 | --log_interval 4 \
21 | --save_path log/lk \
22 | --use_individual_alpha \
23 | --submodels \
24 | --hard_node \
25 | --argmax_tau 1.0 \
26 | --sparse_submodel_type 2 \
27 | --num_sub_features 2 \
28 | | tee train_lk.log
29 |
--------------------------------------------------------------------------------
/icct/runfiles/train_ll.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u train.py \
6 | --env_name lunar \
7 | --policy_type ddt \
8 | --seed 0 \
9 | --num_leaves 8 \
10 | --lr 5e-4 \
11 | --ddt_lr 5e-4 \
12 | --buffer_size 1000000 \
13 | --batch_size 256 \
14 | --gamma 0.99 \
15 | --tau 0.01 \
16 | --learning_starts 10000 \
17 | --eval_freq 1500 \
18 | --min_reward 225 \
19 | --training_steps 500000 \
20 | --log_interval 4 \
21 | --save_path log/ll \
22 | --use_individual_alpha \
23 | --submodels \
24 | --hard_node \
25 | --gpu \
26 | --argmax_tau 1.0 \
27 | --sparse_submodel_type 2 \
28 | --num_sub_features 2 \
29 | | tee train_ll.log
--------------------------------------------------------------------------------
/icct/runfiles/train_ring_accel.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u train.py \
6 | --env_name ring_accel \
7 | --policy_type ddt \
8 | --seed 0 \
9 | --num_leaves 16 \
10 | --lr 5e-4 \
11 | --ddt_lr 5e-4 \
12 | --buffer_size 1000000 \
13 | --batch_size 1024 \
14 | --gamma 0.99 \
15 | --tau 0.01 \
16 | --learning_starts 5000 \
17 | --eval_freq 1500 \
18 | --min_reward 120 \
19 | --training_steps 100000 \
20 | --log_interval 4 \
21 | --save_path log/ring_accel \
22 | --use_individual_alpha \
23 | --submodels \
24 | --hard_node \
25 | --gpu \
26 | --argmax_tau 1.0 \
27 | --sparse_submodel_type 2 \
28 | --num_sub_features 2 \
29 | | tee train_ring_accel.log
--------------------------------------------------------------------------------
/icct/runfiles/train_ring_accel_lc.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | # export MKL_NUM_THREADS=1
3 | # export CUDA_VISIBLE_DEVICES=1
4 |
5 | python -u train.py \
6 | --env_name ring_lane_changing \
7 | --policy_type ddt \
8 | --seed 0 \
9 | --num_leaves 16 \
10 | --lr 5e-4 \
11 | --ddt_lr 5e-4 \
12 | --buffer_size 1000000 \
13 | --batch_size 1024 \
14 | --gamma 0.99 \
15 | --tau 0.01 \
16 | --learning_starts 5000 \
17 | --eval_freq 2500 \
18 | --min_reward 1200 \
19 | --training_steps 500000 \
20 | --log_interval 4 \
21 | --save_path log/ring_lc \
22 | --use_individual_alpha \
23 | --submodels \
24 | --hard_node \
25 | --argmax_tau 1.0 \
26 | --sparse_submodel_type 2 \
27 | --num_sub_features 3 \
28 | | tee train_ring_accel_lc.log
--------------------------------------------------------------------------------
/icct/sumo_envs/accel_figure8.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu and Chace Ritchie
2 | # Revised from https://github.com/flow-project/flow/blob/master/examples/exp_configs/rl/singleagent/singleagent_figure_eight.py
3 |
4 | from flow.envs import AccelEnv
5 | from flow.networks import FigureEightNetwork
6 | from copy import deepcopy
7 | from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \
8 | SumoCarFollowingParams
9 | from flow.core.params import VehicleParams
10 | from flow.controllers import IDMController, ContinuousRouter, RLController
11 | from flow.networks.figure_eight import ADDITIONAL_NET_PARAMS
12 |
13 | # time horizon of a single rollout
14 | HORIZON = 1500
15 | ADDITIONAL_NET_PARAMS["speed_limit"] = 12
16 |
17 | # We place 1 autonomous vehicle and 13 human-driven vehicles in the network
18 | vehicles = VehicleParams()
19 | vehicles.add(
20 | veh_id="human",
21 | acceleration_controller=(IDMController, {
22 | "noise": 0.2
23 | }),
24 | routing_controller=(ContinuousRouter, {}),
25 | car_following_params=SumoCarFollowingParams(
26 | speed_mode="obey_safe_speed",
27 | max_speed=12
28 | ),
29 | num_vehicles=13)
30 | vehicles.add(
31 | veh_id="rl",
32 | acceleration_controller=(RLController, {}),
33 | routing_controller=(ContinuousRouter, {}),
34 | car_following_params=SumoCarFollowingParams(
35 | speed_mode="obey_safe_speed",
36 | ),
37 | num_vehicles=1)
38 |
39 | fig8_params = dict(
40 | # name of the experiment
41 | exp_tag="figure_eight",
42 |
43 | # name of the flow environment the experiment is running on
44 | env_name=AccelEnv,
45 |
46 | # name of the network class the experiment is running on
47 | network=FigureEightNetwork,
48 |
49 | # simulator that is used by the experiment
50 | simulator='traci',
51 |
52 | # sumo-related parameters (see flow.core.params.SumoParams)
53 | sim=SumoParams(
54 | sim_step=0.1,
55 | render=False,
56 | ),
57 |
58 | # environment related parameters (see flow.core.params.EnvParams)
59 | env=EnvParams(
60 | horizon=HORIZON,
61 | additional_params={
62 | "target_velocity": 5,
63 | "max_accel": 3,
64 | "max_decel": 3,
65 | "sort_vehicles": False
66 | },
67 | ),
68 |
69 | # network-related parameters (see flow.core.params.NetParams and the
70 | # network's documentation or ADDITIONAL_NET_PARAMS component)
71 | net=NetParams(
72 | additional_params=deepcopy(ADDITIONAL_NET_PARAMS),
73 | ),
74 |
75 | # vehicles to be placed in the network at the start of a rollout (see
76 | # flow.core.params.VehicleParams)
77 | veh=vehicles,
78 |
79 | # parameters specifying the positioning of vehicles upon initialization/
80 | # reset (see flow.core.params.InitialConfig)
81 | initial=InitialConfig(),
82 | )
--------------------------------------------------------------------------------
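Note: a minimal sketch (not a file in the repository) of how `fig8_params` is turned into a trainable gym environment; this mirrors what `icct/runfiles/train.py` and `test.py` do for the figure-eight domain.

```python
from flow.utils.registry import make_create_env

from icct.sumo_envs.accel_figure8 import fig8_params

# make_create_env registers the flow environment under gym_name and returns a constructor for it.
create_env, gym_name = make_create_env(params=fig8_params, version=0)
env = create_env()  # SUMO-backed AccelEnv with the 13 IDM vehicles and 1 RL vehicle defined above

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
```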
/icct/sumo_envs/accel_ring.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu and Chace Ritchie
2 | # Revised from https://github.com/flow-project/flow/blob/master/examples/exp_configs/rl/singleagent/singleagent_ring.py
3 |
4 | from flow.envs import AccelEnv as RingAccelEnv
5 | from flow.networks.ring import RingNetwork, ADDITIONAL_NET_PARAMS
6 | from flow.utils.registry import make_create_env
7 | from flow.utils.rllib import FlowParamsEncoder
8 | from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams
9 | from flow.core.params import VehicleParams, SumoCarFollowingParams
10 | from flow.controllers import RLController, IDMController, ContinuousRouter
11 | from flow.controllers.car_following_models import SimCarFollowingController
12 |
13 | # time horizon of a single rollout
14 | HORIZON = 750
15 |
16 | # We place one autonomous vehicle and 21 human-driven vehicles in the network
17 | vehicles = VehicleParams()
18 | vehicles.add(
19 | veh_id="human",
20 | acceleration_controller=(IDMController, {
21 | "noise": 0.2
22 | }),
23 | car_following_params=SumoCarFollowingParams(
24 | min_gap=0
25 | ),
26 | routing_controller=(ContinuousRouter, {}),
27 | num_vehicles=21)
28 | vehicles.add(
29 | veh_id="rl",
30 | acceleration_controller=(RLController, {}),
31 | routing_controller=(ContinuousRouter, {}),
32 | num_vehicles=1)
33 |
34 | ring_accel_params = dict(
35 | # name of the experiment
36 | exp_tag="stabilizing_the_ring",
37 |
38 | # name of the flow environment the experiment is running on
39 | env_name=RingAccelEnv,
40 |
41 | # name of the network class the experiment is running on
42 | network=RingNetwork,
43 |
44 | # simulator that is used by the experiment
45 | simulator='traci',
46 |
47 | # sumo-related parameters (see flow.core.params.SumoParams)
48 | sim=SumoParams(
49 | sim_step=0.1, # seconds per simulation step
50 | render=False,
51 | ),
52 |
53 | # environment related parameters (see flow.core.params.EnvParams)
54 | env=EnvParams(
55 | horizon=HORIZON,
56 | warmup_steps=100,
57 | clip_actions=False,
58 | additional_params={
59 | "target_velocity": 20,
60 | "sort_vehicles": False,
61 | "max_accel": 1,
62 | "max_decel": 1,
63 | },
64 | ),
65 |
66 | # network-related parameters (see flow.core.params.NetParams and the
67 | # network's documentation or ADDITIONAL_NET_PARAMS component)
68 | net=NetParams(
69 | additional_params=ADDITIONAL_NET_PARAMS.copy()
70 | ),
71 |
72 | # vehicles to be placed in the network at the start of a rollout (see
73 | # flow.core.params.VehicleParams)
74 | veh=vehicles,
75 |
76 | # parameters specifying the positioning of vehicles upon initialization/
77 | # reset (see flow.core.params.InitialConfig)
78 | initial=InitialConfig(
79 | bunching=20,
80 | ),
81 | )
--------------------------------------------------------------------------------
/icct/sumo_envs/accel_ring_multilane.py:
--------------------------------------------------------------------------------
1 | # Created by Yaru Niu and Chace Ritchie
2 | # Reference: https://github.com/flow-project/flow/tree/master/tutorials
3 |
4 | from flow.controllers.car_following_models import SimCarFollowingController
5 | from flow.core import rewards
6 | from flow.envs import AccelEnv as RingAccelEnv
7 | from flow.networks.ring import RingNetwork, ADDITIONAL_NET_PARAMS
8 | from flow.utils.registry import make_create_env
9 | from flow.utils.rllib import FlowParamsEncoder
10 | from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams
11 | from flow.envs.ring.lane_change_accel import LaneChangeAccelEnv
12 | #from flow.networks.figure_eight import FigureEightNetwork, ADDITIONAL_NET_PARAMS
13 | from flow.core.params import VehicleParams, SumoCarFollowingParams, SumoLaneChangeParams
14 | from flow.controllers import RLController, IDMController, ContinuousRouter, SimLaneChangeController
15 | import numpy as np
16 |
17 | class LaneChangeAccelEnv_Wrapper(LaneChangeAccelEnv):
18 | def __init__(self, env_params, sim_params, network, simulator='traci'):
19 | super().__init__(env_params, sim_params, network, simulator)
20 |
21 | def _apply_rl_actions(self, actions):
22 | acceleration = actions[::2]
23 | direction = actions[1::2]
24 |
25 | # re-arrange actions according to mapping in observation space
26 | sorted_rl_ids = [
27 | veh_id for veh_id in self.sorted_ids
28 | if veh_id in self.k.vehicle.get_rl_ids()
29 | ]
30 |
31 | # discretize the direction values
32 | lane_changing_plus = \
33 | [direction[i] >= 0.5 and direction[i] <= 1 for i, veh_id in enumerate(sorted_rl_ids)]
34 | direction[lane_changing_plus] = \
35 | np.array([1] * sum(lane_changing_plus))
36 |
37 | lane_changing_minus = \
38 | [direction[i] >= -1 and direction[i] <= -0.5 for i, veh_id in enumerate(sorted_rl_ids)]
39 | direction[lane_changing_minus] = \
40 | np.array([-1] * sum(lane_changing_minus))
41 |
42 | lane_keeping = \
43 | [direction[i] > -0.5 and direction[i] < 0.5 for i, veh_id in enumerate(sorted_rl_ids)]
44 | direction[lane_keeping] = \
45 | np.array([0] * sum(lane_keeping))
46 |
47 |         # marks vehicles that are not yet allowed to change lanes (still within
48 |         # lane_change_duration of their last lane change)
48 | non_lane_changing_veh = \
49 | [self.time_counter <=
50 | self.env_params.additional_params["lane_change_duration"]
51 | + self.k.vehicle.get_last_lc(veh_id)
52 | for veh_id in sorted_rl_ids]
53 |         # vehicles that are not allowed to change lanes have their direction set to 0
54 | direction[non_lane_changing_veh] = \
55 | np.array([0] * sum(non_lane_changing_veh))
56 | if direction[0] != 0 and direction[0] != 1 and direction[0] != -1:
57 | print('wrong value of direction!', direction[0])
58 | direction = np.array([0.])
59 | self.k.vehicle.apply_acceleration(sorted_rl_ids, acc=acceleration)
60 | self.k.vehicle.apply_lane_change(sorted_rl_ids, direction=direction)
61 |
62 | def compute_reward(self, rl_actions, **kwargs):
63 | """See class definition."""
64 | # compute the system-level performance of vehicles from a velocity
65 | # perspective
66 | reward = rewards.desired_velocity(self, fail=kwargs["fail"])
67 |
68 | return reward
69 |
70 |
71 |
72 | # time horizon of a single rollout
73 | HORIZON = 1500
74 |
75 | ADDITIONAL_NET_PARAMS["lanes"] = 2
76 | ADDITIONAL_NET_PARAMS["speed_limit"] = 12
77 |
78 |
79 | # We place one autonomous vehicle and 21 human-driven vehicles in the network
80 | vehicles = VehicleParams()
81 | vehicles.add(
82 | veh_id="human",
83 | acceleration_controller=(SimCarFollowingController, {
84 | "noise": 0.2
85 | }),
86 | car_following_params=SumoCarFollowingParams(
87 | min_gap=0,
88 | max_speed=12
89 | ),
90 | routing_controller=(ContinuousRouter, {}),
91 | num_vehicles=21)
92 |
93 | vehicles.add(
94 | veh_id='rl',
95 | acceleration_controller=(RLController, {}),
96 | routing_controller=(ContinuousRouter, {}),
97 | lane_change_params=SumoLaneChangeParams(lane_change_mode="no_lc_safe",),
98 | car_following_params=SumoCarFollowingParams(
99 | speed_mode="obey_safe_speed",
100 | decel=1.5,
101 | ),
102 | num_vehicles=1)
103 |
104 |
105 | ring_accel_lc_params = dict(
106 | # name of the experiment
107 | exp_tag="stabilizing_the_ring",
108 |
109 | # name of the flow environment the experiment is running on
110 | env_name=LaneChangeAccelEnv_Wrapper,
111 |
112 | # name of the network class the experiment is running on
113 | network=RingNetwork,
114 |
115 | # simulator that is used by the experiment
116 | simulator='traci',
117 |
118 | # sumo-related parameters (see flow.core.params.SumoParams)
119 | sim=SumoParams(
120 | sim_step=0.1, # seconds per simulation step
121 | render=False,
122 | ),
123 |
124 | # environment related parameters (see flow.core.params.EnvParams)
125 | env=EnvParams(
126 | horizon=HORIZON,
127 | warmup_steps=750,
128 | clip_actions=False,
129 | additional_params={
130 | "target_velocity": 5,
131 | "sort_vehicles": False,
132 | "max_accel": 3,
133 | "max_decel": 3,
134 | "lane_change_duration": 5
135 | },
136 | ),
137 |
138 | # network-related parameters (see flow.core.params.NetParams and the
139 | # network's documentation or ADDITIONAL_NET_PARAMS component)
140 | net=NetParams(
141 | additional_params=ADDITIONAL_NET_PARAMS.copy()
142 | ),
143 |
144 | # vehicles to be placed in the network at the start of a rollout (see
145 | # flow.core.params.VehicleParams)
146 | veh=vehicles,
147 |
148 | # parameters specifying the positioning of vehicles upon initialization/
149 | # reset (see flow.core.params.InitialConfig)
150 | initial=InitialConfig(
151 | bunching=20,
152 | ),
153 | )
--------------------------------------------------------------------------------
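Note: a small standalone numpy sketch (not part of the repository) of the action layout and direction-discretization rule used in `LaneChangeAccelEnv_Wrapper._apply_rl_actions` above: actions are interleaved as `[accel, direction, ...]`, direction values in [0.5, 1] map to +1, values in [-1, -0.5] map to -1, and everything in between maps to 0 (keep lane).

```python
import numpy as np

def discretize_direction(direction: np.ndarray) -> np.ndarray:
    """Map continuous direction outputs in [-1, 1] to {-1, 0, +1}, using the same
    thresholds as LaneChangeAccelEnv_Wrapper._apply_rl_actions."""
    out = np.zeros_like(direction)
    out[direction >= 0.5] = 1.0
    out[direction <= -0.5] = -1.0
    return out

# Interleaved [accel, direction] pairs for two hypothetical RL vehicles.
actions = np.array([0.3, 0.7, -0.1, -0.9])
acceleration = actions[::2]
direction = discretize_direction(actions[1::2])
print(acceleration)  # [ 0.3 -0.1]
print(direction)     # [ 1. -1.]
```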
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='icct', version='1.0', packages=find_packages(), install_requires=['torch', 'numpy'])
4 |
5 |
--------------------------------------------------------------------------------