├── readme.md ├── src ├── .ipynb_checkpoints │ └── load_model_weights-checkpoint.ipynb ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── cdt │ ├── CDT.py │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── CDT.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── cdt_discretization.cpython-36.pyc │ ├── cdt_discretization.py │ ├── cdt_evaluation.py │ ├── cdt_il_train.json │ ├── cdt_plot.py │ ├── cdt_rl_train.json │ ├── cdt_rl_train_compare.json │ └── deprecated │ │ ├── cdt_discretization.py │ │ ├── cdt_il_train.py │ │ └── cdt_rl_train.py ├── hdt │ ├── HDT_lunarlander.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── heuristic_agents.cpython-36.pyc │ │ └── heuristic_lunarlander.cpython-36.pyc │ ├── heuristic_agents.py │ └── heuristic_lunarlander.py ├── il │ └── il.json ├── il_data_collect.py ├── il_eval.py ├── il_train.py ├── il_train.sh ├── mlp │ └── mlp_rl_train.json ├── rl │ ├── .ipynb_checkpoints │ │ ├── state_statistics-checkpoint.ipynb │ │ ├── stats-checkpoint.ipynb │ │ └── test-checkpoint.ipynb │ ├── PPO.py │ ├── PPO.pyc │ ├── __init__.py │ ├── __init__.pyc │ ├── env_wrapper.py │ ├── rl.json │ └── stats.ipynb ├── rl_data_collect.py ├── rl_eval.py ├── rl_train.py ├── rl_train.sh ├── rl_train_compare_cdt.py ├── rl_train_compare_cdt.sh ├── rl_train_compare_sdt.py ├── rl_train_compare_sdt.sh ├── sdt │ ├── SDT.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── SDT.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── sdt_discretization.cpython-36.pyc │ ├── deprecated │ │ ├── sdt_discretization.py │ │ ├── sdt_il_train.py │ │ └── sdt_rl_train.py │ ├── sdt_discretization.py │ ├── sdt_evaluation.py │ ├── sdt_il_train.json │ ├── sdt_plot.py │ ├── sdt_rl_train.json │ └── sdt_rl_train_compare.json ├── utils │ ├── __pycache__ │ │ ├── common_func.cpython-36.pyc │ │ ├── dataset.cpython-36.pyc │ │ └── heuristic_evaluation.cpython-36.pyc │ ├── common_func.py │ ├── dataset.py │ └── heuristic_evaluation.py └── viper │ ├── .gitignore │ ├── algorithms │ ├── __init__.py │ ├── agents.py │ ├── algorithm.py │ ├── config │ │ ├── Dagger.py │ │ ├── QLearning_Atari.py │ │ ├── SAC.py │ │ ├── __init__.py │ │ └── heuristic.py │ ├── envs │ │ ├── Wrapper.py │ │ └── __init__.py │ ├── models.py │ └── utils.py │ ├── dagger.py │ ├── gym.ipynb │ ├── lunar_lander.py │ └── rl.py └── visual ├── .ipynb_checkpoints ├── load_model_weights-checkpoint.ipynb ├── params-checkpoint.ipynb ├── plot-checkpoint.ipynb └── stability_analysis-checkpoint.ipynb ├── params.ipynb ├── plot.ipynb └── stability_analysis.ipynb /readme.md: -------------------------------------------------------------------------------- 1 | # Cascading Decision Tree (CDT) for Explainable Reinforcement Learning 2 | 3 | Open-source code for paper *CDT: Cascading Decision Trees for Explainable Reinforcement Learning* (https://arxiv.org/abs/2011.07553). 4 | 5 | Data folder: the data folders (`data` and `data_no_norm`) should be put at the root of the repo to run the code. See issues: [#4](https://github.com/quantumiracle/Cascading-Decision-Tree/issues/4). The data folders are stored at the [Google Drive](https://drive.google.com/drive/folders/18GGBNZhugIAQXJ1TXtwJBwI6HkDWIJay?usp=sharing). 
6 | 7 | # File Structure 8 | 9 | * data: all data for experiments (not maintained in the repo, but can be collected with the given scripts below) 10 | * mlp: data for MLP model; 11 | * cdt: data for CDT model; 12 | * sdt: data for SDT model; 13 | * il: data for general Imitation Learning (IL); 14 | * rl: data for general Reinforcement Learning (RL); 15 | * cdt_compare_depth: data for CDT with different depths in RL; 16 | * sdt_compare_depth: data for SDT with different depths in RL; 17 | * src: source code 18 | * mlp: training configurations for MLP as policy function approximator; 19 | * cdt: the Cascading Decision Tree (CDT) class and necessary functions; 20 | * sdt: the Soft Decision Tree (SDT) class and necessary functions; 21 | * hdt: the heuristic agents; 22 | * il: configurations for Imitation Learning (IL); 23 | * rl: configurations for Reinforcement Learning (RL) and RL agents (e.g., PPO), etc.; 24 | * utils: some common functions 25 | * `il_data_collect.py`: collect the dataset (state-action pairs from a heuristic or well-trained policy) for IL; 26 | * `rl_data_collect.py`: collect the dataset (states during training, for calculating normalization statistics) for RL; 27 | * `il_train.py`: train the IL agent with different function approximators (e.g., SDT, CDT); 28 | * `rl_train.py`: train the RL agent with different function approximators (e.g., SDT, CDT, MLP); 29 | * `il_eval.py`: evaluate the trained IL agents before and after tree discretization, based on prediction accuracy; 30 | * `rl_eval.py`: evaluate the trained RL agents before and after tree discretization, based on episodic reward; 31 | * `il_train.sh`: bash script to run IL training with different models on a server; 32 | * `rl_train.sh`: bash script to run RL training with different models on a server; 33 | * `rl_train_compare_sdt.py`: train the RL agent with SDT; 34 | * `rl_train_compare_cdt.py`: train the RL agent with CDT; 35 | * `rl_train_compare_sdt.sh`: bash script to run RL training with SDTs of different depths on a server; 36 | * `rl_train_compare_cdt.sh`: bash script to run RL training with CDTs of different depths on a server; 37 | * visual 38 | * `plot.ipynb`: plot learning curves, etc. 39 | * `params.ipynb`: quantitative analysis of model parameters (SDT and CDT). 40 | * `stability_analysis.ipynb`: the stability analysis in the paper; compares the tree weights. 41 | 42 | # To Run 43 | 44 | To fully replicate the experiments in the paper, the code needs to be run in several stages. 45 | 46 | ### A. Reinforcement Learning Comparison with SDT, CDT and MLP 47 | 48 | 1. Collect the dataset (for state normalization): 49 | 50 | ```bash 51 | cd ./src 52 | python rl_data_collect.py 53 | ``` 54 | 55 | 2. Get statistics on the dataset: 56 | 57 | ````bash 58 | cd rl 59 | jupyter notebook 60 | ```` 61 | 62 | Open `stats.ipynb` and run its cells to generate files with the dataset statistics. 63 | 64 | Steps 1 and 2 can be skipped if not using state normalization. 65 | 66 | 3. Train RL agents with different policy function approximators (SDT, CDT, MLP): 67 | 68 | ```bash 69 | cd .. 70 | python rl_train.py --train --env='CartPole-v1' --method='sdt' --id=0 71 | python rl_train.py --train --env='LunarLander-v2' --method='cdt' --id=0 72 | python rl_train.py --train --env='MountainCar-v0' --method='mlp' --id=0 73 | ``` 74 | 75 | or simply run: 76 | 77 | ````bash 78 | ./rl_train.sh 79 | ```` 80 | 81 | 4. Evaluate the trained agents (with the discretization operation): 82 | 83 | ````bash 84 | python rl_eval.py --env='CartPole-v1' --method='sdt' 85 | python rl_eval.py --env='LunarLander-v2' --method='cdt' 86 | ```` 87 | 88 | 89 | 5.
Results visualization 90 | 91 | ```bash 92 | cd ../visual 93 | jupyter notebook 94 | ``` 95 | 96 | See `plot.ipynb`. 97 | 98 | ### B. Imitation Learning Comparison with SDT and CDT 99 | 100 | 1. Collect the dataset, used for (1) state normalization and (2) as the imitation learning dataset: 101 | 102 | ```bash 103 | cd ./src 104 | python il_data_collect.py 105 | ``` 106 | 107 | 2. Train IL agents with different policy function approximators (SDT, CDT): 108 | 109 | ```bash 110 | python il_train.py --train --env='CartPole-v1' --method='sdt' --id=0 111 | python il_train.py --train --env='LunarLander-v2' --method='cdt' --id=0 112 | ``` 113 | 114 | or simply run: 115 | 116 | ```bash 117 | ./il_train.sh 118 | ``` 119 | 120 | 3. Evaluate the trained agents: 121 | 122 | ```bash 123 | python il_eval.py --env='CartPole-v1' --method='sdt' 124 | python il_eval.py --env='LunarLander-v2' --method='cdt' 125 | ``` 126 | 127 | 4. Results visualization 128 | 129 | ```bash 130 | cd ../visual 131 | jupyter notebook 132 | ``` 133 | 134 | See `plot.ipynb`. 135 | 136 | ### B'. Imitation Learning with DAGGER and Q-DAGGER 137 | The DAGGER and Q-DAGGER methods in [VIPER](https://arxiv.org/abs/1805.08328) are also compared in the paper under the imitation learning setting. Code is in `./src/viper/`. Credit goes to [Hangrui (Henry) Bi 138 | ](https://github.com/20171130). 139 | 140 | ### C. Tree Depths for SDT and CDT in Reinforcement Learning 141 | 142 | Run the comparison with different tree depths. 143 | 144 | For SDT: 145 | 146 | ```bash 147 | ./rl_train_compare_sdt.sh 148 | ``` 149 | 150 | For CDT: 151 | 152 | ```bash 153 | ./rl_train_compare_cdt.sh 154 | ``` 155 | 156 | ### D. Stability Analysis 157 | 158 | Compare the tree weights of different agents in IL: 159 | 160 | ```bash 161 | cd ./visual 162 | jupyter notebook 163 | ``` 164 | 165 | See `stability_analysis.ipynb`. 166 | 167 | ### E. Model Simplicity 168 | 169 | Quantitative analysis of the number of model parameters: 170 | 171 | ```bash 172 | cd ./visual 173 | jupyter notebook 174 | ``` 175 | 176 | See `params.ipynb`.
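### Programmatic Usage of CDT (sketch)

Besides the training/evaluation scripts above, the `CDT` class and the `discretize_cdt` function (both exported from `src/cdt/__init__.py`) can be used directly. The snippet below is a minimal sketch only, assuming it is run from `./src`; the configuration values mirror the CartPole-v1 settings in `src/cdt/cdt_rl_train.json`, and the random states are placeholders for real environment observations.

```python
import torch
from cdt import CDT, discretize_cdt

# Illustrative configuration; keys follow the learner_args used in CDT.py and the JSON configs.
learner_args = {
    'num_intermediate_variables': 2,
    'feature_learning_depth': 2,
    'decision_depth': 2,
    'input_dim': 4,                     # CartPole-v1 observation dimension
    'output_dim': 2,                    # CartPole-v1 action dimension
    'lr': 1e-3,
    'weight_decay': 0.0,
    'exp_scheduler_gamma': 1.0,
    'device': 'cpu',
    'greatest_path_probability': True,  # predict along the most probable path
    'beta_fl': False,                   # fixed (non-learnable) temperature for feature learning
    'beta_dc': False,                   # fixed (non-learnable) temperature for decision making
}

tree = CDT(learner_args)

# Dummy batch of states standing in for real observations.
states = torch.randn(5, learner_args['input_dim'])
log_probs, _, _ = tree.forward(states)   # log action probabilities, shape (batch, output_dim)
actions = log_probs.argmax(dim=1)

# Discretize both the feature-learning and decision-making trees for interpretation.
hard_tree = discretize_cdt(tree, FL=True, DC=True)
hard_log_probs, _, _ = hard_tree.forward(states)
```

A model previously saved with `tree.save_model(model_path)` can be restored with `tree.load_model(model_path)` before discretization and inspection.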
177 | 178 | ## Citation: 179 | ``` 180 | @article{ding2020cdt, 181 | title={Cdt: Cascading decision trees for explainable reinforcement learning}, 182 | author={Ding, Zihan and Hernandez-Leal, Pablo and Ding, Gavin Weiguang and Li, Changjian and Huang, Ruitong}, 183 | journal={arXiv preprint arXiv:2011.07553}, 184 | year={2020} 185 | } 186 | ``` 187 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/cdt/CDT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Cascade a feature learning tree and a soft decision tree (sparse in features) """ 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | class CDT(nn.Module): 8 | def __init__(self, args): 9 | super(CDT, self).__init__() 10 | self.args = args 11 | print('CDT parameters: ', args) 12 | self.device = torch.device(self.args['device']) 13 | 14 | self.sigmoid = nn.Sigmoid() 15 | self.softmax = nn.Softmax(dim=1) 16 | 17 | self.feature_learning_init() 18 | self.decision_init() 19 | 20 | self.max_leaf_idx = None 21 | 22 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.args['lr'], weight_decay=self.args['weight_decay']) 23 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.args['exp_scheduler_gamma']) 24 | 25 | def get_tree_weights(self, Bias=False): 26 | """Return tree weights as a list""" 27 | if Bias: 28 | return self.state_dict()['fl_inner_nodes.weight'].detach().cpu().numpy(), self.state_dict()['dc_inner_nodes.weight'].detach().cpu().numpy() 29 | else: # no bias 30 | return self.state_dict()['fl_inner_nodes.weight'][:, 1:].detach().cpu().numpy(), self.state_dict()['dc_inner_nodes.weight'][:, 1:].detach().cpu().numpy() 31 | 32 | def get_feature_weights(self,): 33 | return self.state_dict()['fl_leaf_weights'].detach().cpu().numpy().reshape(self.num_fl_leaves, self.args['num_intermediate_variables'], self.args['input_dim']) 34 | 35 | 36 | def feature_learning_init(self): 37 | self.num_fl_inner_nodes = 2**self.args['feature_learning_depth'] -1 38 | self.num_fl_leaves = self.num_fl_inner_nodes + 1 39 | self.fl_inner_nodes = nn.Linear(self.args['input_dim']+1, self.num_fl_inner_nodes, bias=False) 40 | # coefficients of feature combinations 41 | fl_leaf_weights = torch.randn(self.num_fl_leaves*self.args['num_intermediate_variables'], self.args['input_dim']) 42 | self.fl_leaf_weights = nn.Parameter(fl_leaf_weights) 43 | 44 | # temperature term 45 | if self.args['beta_fl'] is True or self.args['beta_fl']==1: # learnable 46 | beta_fl = torch.randn(self.num_fl_inner_nodes) # use different beta_fl for each node 47 | # beta_fl = torch.randn(1) # or use one beta_fl across all nodes 48 | self.beta_fl = nn.Parameter(beta_fl) 49 | elif self.args['beta_fl'] is False or self.args['beta_fl']==0: 50 | self.beta_fl = 
torch.ones(1).to(self.device) # or use one beta_fl across all nodes 51 | else: # pass in value for beta_fl 52 | self.beta_fl = torch.tensor(self.args['beta_fl']).to(self.device) 53 | 54 | def feature_learning_forward(self): 55 | """ 56 | Forward the tree for feature learning. 57 | Return the probabilities for reaching each leaf. 58 | """ 59 | path_prob = self.sigmoid(self.beta_fl*self.fl_inner_nodes(self.aug_data)) 60 | 61 | path_prob = torch.unsqueeze(path_prob, dim=2) 62 | path_prob = torch.cat((path_prob, 1-path_prob), dim=2) 63 | _mu = self.aug_data.data.new(self.batch_size,1,1).fill_(1.) 64 | 65 | begin_idx = 0 66 | end_idx = 1 67 | for layer_idx in range(0, self.args['feature_learning_depth']): 68 | _path_prob = path_prob[:, begin_idx:end_idx, :] 69 | 70 | _mu = _mu.view(self.batch_size, -1, 1).repeat(1, 1, 2) 71 | _mu = _mu * _path_prob 72 | begin_idx = end_idx # index for each layer 73 | end_idx = begin_idx + 2 ** (layer_idx+1) 74 | mu = _mu.view(self.batch_size, self.num_fl_leaves) 75 | 76 | return mu 77 | 78 | 79 | def decision_init(self): 80 | self.num_dc_inner_nodes = 2**self.args['decision_depth'] -1 81 | self.num_dc_leaves = self.num_dc_inner_nodes + 1 82 | self.dc_inner_nodes = nn.Linear(self.args['num_intermediate_variables']+1, self.num_dc_inner_nodes, bias=False) 83 | 84 | dc_leaves = torch.randn(self.num_dc_leaves, self.args['output_dim']) 85 | self.dc_leaves = nn.Parameter(dc_leaves) 86 | 87 | # temperature term 88 | if self.args['beta_dc'] is True or self.args['beta_dc'] == 1: # learnable 89 | beta_dc = torch.randn(self.num_dc_inner_nodes) # use different beta_dc for each node 90 | # beta_dc = torch.randn(1) # or use one beta_dc across all nodes 91 | self.beta_dc = nn.Parameter(beta_dc) 92 | elif self.args['beta_dc'] is False or self.args['beta_dc'] == 0: 93 | self.beta_dc = torch.ones(1).to(self.device) # or use one beta_dc across all nodes 94 | else: # pass in value for beta_dc 95 | self.beta_dc = torch.tensor(self.args['beta_dc']).to(self.device) 96 | 97 | def decision_forward(self): 98 | """ 99 | Forward the differentiable decision tree 100 | """ 101 | self.intermediate_features_construct() 102 | aug_features = self._data_augment_(self.features) 103 | path_prob = self.sigmoid(self.beta_dc*self.dc_inner_nodes(aug_features)) 104 | feature_batch_size = self.features.shape[0] 105 | 106 | path_prob = torch.unsqueeze(path_prob, dim=2) 107 | path_prob = torch.cat((path_prob, 1-path_prob), dim=2) 108 | _mu = aug_features.data.new(feature_batch_size,1,1).fill_(1.) 109 | 110 | begin_idx = 0 111 | end_idx = 1 112 | for layer_idx in range(0, self.args['decision_depth']): 113 | _path_prob = path_prob[:, begin_idx:end_idx, :] 114 | 115 | _mu = _mu.view(feature_batch_size, -1, 1).repeat(1, 1, 2) 116 | _mu = _mu * _path_prob 117 | begin_idx = end_idx # index for each layer 118 | end_idx = begin_idx + 2 ** (layer_idx+1) 119 | mu = _mu.view(feature_batch_size, self.num_dc_leaves) 120 | 121 | return mu 122 | 123 | def intermediate_features_construct(self): 124 | """ 125 | Construct the intermediate features for decision making, with learned feature combinations from feature learning module. 
126 | """ 127 | features = self.fl_leaf_weights.view(-1, self.args['input_dim']) @ self.data.transpose(0,1) # data: (batch_size, feature_dim); return: (num_fl_leaves*num_intermediate_variables, batch) 128 | self.features = features.contiguous().view(self.num_fl_leaves, self.args['num_intermediate_variables'], -1).permute(2,0,1).contiguous().view(-1, self.args['num_intermediate_variables']) # return: (N, num_intermediate_variables) where N=batch_size*num_fl_leaves 129 | 130 | def decision_leaves(self, p): 131 | distribution_per_leaf = self.softmax(self.dc_leaves) 132 | average_distribution = torch.mm(p, distribution_per_leaf) # sum(probability of each leaf * leaf distribution) 133 | return average_distribution 134 | 135 | def forward(self, data, LogProb=True): 136 | self.data = data 137 | self.batch_size = data.size()[0] 138 | self.aug_data = self._data_augment_(data) 139 | fl_probs = self.feature_learning_forward() # (batch_size, num_fl_leaves) 140 | dc_probs = self.decision_forward() 141 | dc_probs = dc_probs.view(self.batch_size, self.num_fl_leaves, -1) # (batch_size, num_fl_leaves, num_dc_leaves) 142 | 143 | _mu = torch.bmm(fl_probs.unsqueeze(1), dc_probs).squeeze(1) # (batch_size, num_dc_leaves) 144 | output = self.decision_leaves(_mu) 145 | 146 | if self.args['greatest_path_probability']: 147 | vs, ids = torch.max(fl_probs, 1) # ids is the leaf index with maximal path probability 148 | # get the path with greatest probability, get index of it, feature vector and feature value on that leaf 149 | self.max_leaf_idx_fl = ids 150 | self.max_feature_vector = self.fl_leaf_weights.view(self.num_fl_leaves, self.args['num_intermediate_variables'], self.args['input_dim'])[ids] 151 | self.max_feature_value = self.features.view(-1, self.num_fl_leaves, self.args['num_intermediate_variables'])[:, ids, :] 152 | 153 | one_dc_probs = dc_probs[torch.arange(dc_probs.shape[0]), ids, :] # select decision path probabilities of learned features with largest probability 154 | one_hot_path_probability_dc = torch.zeros(one_dc_probs.shape).to(self.device) 155 | vs_dc, ids_dc = torch.max(one_dc_probs, 1) # ids is the leaf index with maximal path probability 156 | self.max_leaf_idx_dc = ids_dc 157 | one_hot_path_probability_dc.scatter_(1, ids_dc.view(-1,1), 1.) 
158 | prediction = self.decision_leaves(one_hot_path_probability_dc) 159 | 160 | else: # prediction value equals to the average distribution 161 | prediction = output 162 | 163 | if LogProb: 164 | output = torch.log(output) 165 | prediction = torch.log(prediction) 166 | 167 | return prediction, output, 0 168 | 169 | 170 | """ Add constant 1 onto the front of each instance, serving as the bias """ 171 | def _data_augment_(self, input): 172 | batch_size = input.size()[0] 173 | input = input.view(batch_size, -1) 174 | bias = torch.ones(batch_size, 1).to(self.device) 175 | input = torch.cat((bias, input), 1) 176 | return input 177 | 178 | def save_model(self, model_path, id=''): 179 | torch.save(self.state_dict(), model_path+id) 180 | 181 | def load_model(self, model_path, id=''): 182 | self.load_state_dict(torch.load(model_path+id, map_location='cpu')) 183 | self.eval() 184 | 185 | 186 | if __name__ == '__main__': 187 | learner_args = { 188 | 'num_intermediate_variables': 3, 189 | 'feature_learning_depth': 2, 190 | 'decision_depth': 2, 191 | 'input_dim': 8, 192 | 'output_dim': 4, 193 | } 194 | -------------------------------------------------------------------------------- /src/cdt/__init__.py: -------------------------------------------------------------------------------- 1 | from .CDT import CDT 2 | from .cdt_discretization import discretize_cdt -------------------------------------------------------------------------------- /src/cdt/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/cdt/__init__.pyc -------------------------------------------------------------------------------- /src/cdt/__pycache__/CDT.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/cdt/__pycache__/CDT.cpython-36.pyc -------------------------------------------------------------------------------- /src/cdt/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/cdt/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/cdt/__pycache__/cdt_discretization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/cdt/__pycache__/cdt_discretization.cpython-36.pyc -------------------------------------------------------------------------------- /src/cdt/cdt_discretization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Discretize the (soft) differentiable tree into normal decision tree according to DDT paper""" 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | import numpy as np 7 | import copy 8 | 9 | def discretize_cdt(original_tree, FL=True, DC=True): 10 | """ 11 | Discretize the cascading tree 12 | if FL: discretize the feature learning tree; 13 | if DC: discretize the decision making tree. 
14 | """ 15 | tree = copy.deepcopy(original_tree) 16 | for name, parameter in tree.named_parameters(): 17 | 18 | # discretize feature learning tree and decision making tree separately 19 | if FL: 20 | if name == 'beta_fl': 21 | setattr(tree, name, nn.Parameter(100*torch.ones(parameter.shape))) # 100 is a large enough value to make soft decision hard 22 | 23 | elif name == 'fl_inner_nodes.weight': 24 | parameters=[] 25 | for weights in parameter: 26 | bias = weights[0] 27 | max_id = np.argmax(np.abs(weights[1:].detach().cpu().numpy()))+1 28 | max_v = weights[max_id].detach().cpu().numpy() 29 | new_weights = torch.zeros(weights.shape) 30 | if max_v>0: 31 | new_weights[max_id] = torch.tensor(1) 32 | else: 33 | new_weights[max_id] = torch.tensor(-1) 34 | new_weights[0] = bias/np.abs(max_v) 35 | parameters.append(new_weights) 36 | 37 | tree.fl_inner_nodes.weight = nn.Parameter(torch.stack(parameters)) 38 | 39 | if DC: 40 | if name == 'beta_dc': 41 | setattr(tree, name, nn.Parameter(100*torch.ones(parameter.shape))) 42 | 43 | elif name == 'dc_inner_nodes.weight': 44 | parameters=[] 45 | # print(parameter) 46 | for weights in parameter: 47 | bias = weights[0] 48 | max_id = np.argmax(np.abs(weights[1:].detach().cpu().numpy()))+1 49 | max_v = weights[max_id].detach().cpu().numpy() 50 | new_weights = torch.zeros(weights.shape) 51 | if max_v>0: 52 | new_weights[max_id] = torch.tensor(1) 53 | else: 54 | new_weights[max_id] = torch.tensor(-1) 55 | new_weights[0] = bias/np.abs(max_v) 56 | parameters.append(new_weights) 57 | 58 | tree.dc_inner_nodes.weight = nn.Parameter(torch.stack(parameters)) 59 | 60 | return tree 61 | -------------------------------------------------------------------------------- /src/cdt/cdt_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import gym 6 | from torch.distributions import Categorical 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from cdt_plot import draw_tree, get_path 11 | import sys 12 | sys.path.append("..") 13 | from heuristic_evaluation import normalize 14 | import os 15 | import copy 16 | 17 | EnvName = 'CartPole-v1' # LunarLander-v2 18 | 19 | __all__ = ['evaluate', 20 | 'plot_importance_single_episode' 21 | ] 22 | 23 | def evaluate(model, tree, episodes=1, frameskip=1, seed=None, DrawTree=None, img_path = 'img/eval_tree'): 24 | env = gym.make(EnvName) 25 | if seed: 26 | env.seed(seed) 27 | state_dim = env.observation_space.shape[0] 28 | action_dim = env.action_space.n # discrete 29 | if not os.path.exists(img_path): 30 | os.makedirs(img_path) 31 | average_weight_list = [] 32 | 33 | # show values on tree nodes 34 | print(tree.state_dict()) 35 | # show probs on tree leaves 36 | softmax = nn.Softmax(dim=-1) 37 | print(softmax(tree.state_dict()['dc_leaves']).detach().cpu().numpy()) 38 | 39 | for n_epi in range(episodes): 40 | print('Episode: ', n_epi) 41 | average_weight_list_epi = [] 42 | s = env.reset() 43 | done = False 44 | reward = 0.0 45 | step=0 46 | while not done: 47 | a = model(torch.Tensor([s])) 48 | if step%frameskip==0: 49 | if DrawTree is not None: 50 | draw_tree(tree, input_img=s, DrawTree=DrawTree, savepath=img_path+'_'+DrawTree+'/{:04}.png'.format(step)) 51 | 52 | s_prime, r, done, info = env.step(a) 53 | # env.render() 54 | s = s_prime 55 | 56 | reward += r 57 | step+=1 58 | if done: 59 | break 60 | 61 | average_weight_list.append(average_weight_list_epi) 62 | print("# of episode :{}, 
reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 63 | # np.save('data/s/dt_importance.npy', average_weight_list) 64 | 65 | env.close() 66 | 67 | 68 | def plot_importance_single_episode(data_path='./data/sdt_importance.npy', save_path='./img/sdt_importance.png', epi_id=0): 69 | data = np.load(data_path, allow_pickle=True)[epi_id] 70 | for i, weights_per_feature in enumerate(np.array(data).T): 71 | plt.plot(weights_per_feature, label='Dim: {}'.format(i)) 72 | plt.legend(loc=4) 73 | if save_path: 74 | plt.savefig(save_path) 75 | plt.close() 76 | else: 77 | plt.show() 78 | 79 | if __name__ == '__main__': 80 | from CDT import Cascade_DDT 81 | learner_args = { 82 | 'num_intermediate_variables': 2, 83 | 'feature_learning_depth': 2, 84 | 'decision_depth': 2, 85 | 'input_dim': 4, 86 | 'output_dim': 2, 87 | 'lr': 1e-3, 88 | 'weight_decay': 0., # 5e-4 89 | 'batch_size': 1280, 90 | 'exp_scheduler_gamma': 1., 91 | 'cuda': True, 92 | 'epochs': 40, 93 | 'log_interval': 100, 94 | 'greatest_path_probability': True, 95 | 'beta_fl' : False, # temperature for feature learning 96 | 'beta_dc' : False, # temperature for decision making 97 | } 98 | learner_args['model_path'] = './model/cdt/'+str(learner_args['feature_learning_depth'])+'_'\ 99 | +str(learner_args['decision_depth'])+'_var'+str(learner_args['num_intermediate_variables'])+'_id'+str(4) 100 | 101 | 102 | # for reproduciblility 103 | seed=3 104 | if seed: 105 | torch.manual_seed(seed) 106 | np.random.seed(seed) 107 | learner_args['cuda'] = False # cpu 108 | 109 | tree = Cascade_DDT(learner_args) 110 | Discretized=True # whether load the discretized tree 111 | if Discretized: 112 | tree.load_model(learner_args['model_path']+'_discretized') 113 | else: 114 | tree.load_model(learner_args['model_path']) 115 | 116 | num_params = 0 117 | for key, v in tree.state_dict().items(): 118 | print(key, v.reshape(-1).shape[0]) 119 | num_params+=v.reshape(-1).shape[0] 120 | print('Total number of parameters in model: ', num_params) 121 | 122 | model = lambda x: tree.forward(x)[0].data.max(1)[1].squeeze().detach().numpy() 123 | img_path = 'img/eval_tree_{}_{}'.format(tree.args['feature_learning_depth'], tree.args['decision_depth']) 124 | if Discretized: 125 | img_path += '_discretized' 126 | evaluate(model, tree, episodes=10, frameskip=1, seed=seed, DrawTree=None, DrawImportance=False, \ 127 | img_path=img_path) 128 | 129 | # plot_importance_single_episode(epi_id=0) -------------------------------------------------------------------------------- /src/cdt/cdt_il_train.json: -------------------------------------------------------------------------------- 1 | { "General": { 2 | "policy_approx" : "CDT" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "num_intermediate_variables" : 2, 7 | "feature_learning_depth" : 2, 8 | "decision_depth" : 2, 9 | "input_dim" : 4, 10 | "output_dim" : 2, 11 | "lr" : 1e-3, 12 | "weight_decay" : 0.0, 13 | "batch_size" : 1280, 14 | "exp_scheduler_gamma" : 1.0, 15 | "device" : "cuda", 16 | "epochs" : 80, 17 | "log_interval" : 100, 18 | "greatest_path_probability" : 1, 19 | "beta_fl" : 0, 20 | "beta_dc" : 0, 21 | "model_path" : "../data/cdt/model/cartpole/il_model", 22 | "log_path" : "../data/cdt/log/cartpole/il_log" 23 | } 24 | }, 25 | 26 | "LunarLander-v2": { 27 | "learner_args": { 28 | "num_intermediate_variables" : 2, 29 | "feature_learning_depth" : 3, 30 | "decision_depth" : 3, 31 | "input_dim" : 8, 32 | "output_dim" : 4, 33 | "lr" : 1e-3, 34 | "weight_decay" : 0.0, 35 | "batch_size" : 1280, 36 | "exp_scheduler_gamma" : 
1.0, 37 | "device" : "cuda", 38 | "epochs" : 80, 39 | "log_interval" : 100, 40 | "greatest_path_probability" : 1, 41 | "beta_fl" : 0, 42 | "beta_dc" : 0, 43 | "model_path" : "../data/cdt/model/lunarlander/il_model", 44 | "log_path" : "../data/cdt/log/lunarlander/il_log" 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/cdt/cdt_rl_train.json: -------------------------------------------------------------------------------- 1 | { "General": { 2 | "policy_approx" : "CDT" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "num_intermediate_variables" : 2, 7 | "feature_learning_depth" : 2, 8 | "decision_depth" : 2, 9 | "input_dim" : 4, 10 | "output_dim" : 2, 11 | "lr" : 1e-3, 12 | "weight_decay" : 0.0, 13 | "batch_size" : 1280, 14 | "exp_scheduler_gamma" : 1.0, 15 | "device" : "cuda", 16 | "episodes" : 40, 17 | "log_interval" : 100, 18 | "greatest_path_probability" : 1, 19 | "beta_fl" : 0, 20 | "beta_dc" : 0 21 | }, 22 | "alg_confs": { 23 | "learning_rate" : 0.0005, 24 | "gamma" : 0.98, 25 | "lmbda" : 0.95, 26 | "eps_clip" : 0.1, 27 | "K_epoch" : 3, 28 | "hidden_dim" : 128 29 | }, 30 | "train_confs": { 31 | "episodes" : 3000, 32 | "t_horizon" : 1000, 33 | "model_path" : "../data/cdt/model/cartpole/rl_ppo", 34 | "log_path" : "../data/cdt/log/cartpole/rl_reward" 35 | } 36 | }, 37 | 38 | "LunarLander-v2": { 39 | "learner_args": { 40 | "num_intermediate_variables" : 2, 41 | "feature_learning_depth" : 3, 42 | "decision_depth" : 3, 43 | "input_dim" : 8, 44 | "output_dim" : 4, 45 | "lr" : 1e-3, 46 | "weight_decay" : 0.0, 47 | "batch_size" : 1280, 48 | "exp_scheduler_gamma" : 1.0, 49 | "device" : "cuda", 50 | "episodes" : 40, 51 | "log_interval" : 100, 52 | "greatest_path_probability" : 1, 53 | "beta_fl" : 0, 54 | "beta_dc" : 0 55 | }, 56 | "alg_confs": { 57 | "learning_rate" : 0.0005, 58 | "gamma" : 0.98, 59 | "lmbda" : 0.95, 60 | "eps_clip" : 0.1, 61 | "K_epoch" : 3, 62 | "hidden_dim" : 128 63 | }, 64 | "train_confs": { 65 | "episodes" : 5000, 66 | "t_horizon" : 1000, 67 | "model_path" : "../data/cdt/model/lunarlander/rl_ppo", 68 | "log_path" : "../data/cdt/log/lunarlander/rl_reward" 69 | } 70 | }, 71 | 72 | 73 | "MountainCar-v0": { 74 | "learner_args": { 75 | "num_intermediate_variables" : 1, 76 | "feature_learning_depth" : 2, 77 | "decision_depth" : 2, 78 | "input_dim" : 2, 79 | "output_dim" : 3, 80 | "lr" : 1e-3, 81 | "weight_decay" : 0.0, 82 | "batch_size" : 128, 83 | "exp_scheduler_gamma" : 1.0, 84 | "device" : "cuda", 85 | "episodes" : 40, 86 | "log_interval" : 100, 87 | "greatest_path_probability" : 1, 88 | "beta_fl" : 0, 89 | "beta_dc" : 0 90 | }, 91 | "alg_confs": { 92 | "learning_rate" : 0.005, 93 | "gamma" : 0.999, 94 | "lmbda" : 0.98, 95 | "eps_clip" : 0.1, 96 | "K_epoch" : 10, 97 | "hidden_dim" : 32 98 | }, 99 | "train_confs": { 100 | "episodes" : 5000, 101 | "t_horizon" : 1000, 102 | "model_path" : "../data/cdt/model/mountaincar/rl_ppo", 103 | "log_path" : "../data/cdt/log/mountaincar/rl_reward" 104 | } 105 | }, 106 | 107 | "Acrobot-v1": { 108 | "learner_args": { 109 | "num_intermediate_variables" : 2, 110 | "feature_learning_depth" : 2, 111 | "decision_depth" : 2, 112 | "input_dim" : 6, 113 | "output_dim" : 3, 114 | "lr" : 1e-3, 115 | "weight_decay" : 0.0, 116 | "batch_size" : 128, 117 | "exp_scheduler_gamma" : 1.0, 118 | "device" : "cuda", 119 | "episodes" : 40, 120 | "log_interval" : 100, 121 | "greatest_path_probability" : 1, 122 | "beta_fl" : 0, 123 | "beta_dc" : 0 124 | }, 125 | "alg_confs": { 126 | "learning_rate" : 
0.0005, 127 | "gamma" : 0.98, 128 | "lmbda" : 0.95, 129 | "eps_clip" : 0.1, 130 | "K_epoch" : 3, 131 | "hidden_dim" : 128 132 | }, 133 | "train_confs": { 134 | "episodes" : 7000, 135 | "t_horizon" : 1000, 136 | "model_path" : "../data_no_norm/cdt/model/acrobot/rl_ppo", 137 | "log_path" : "../data_no_norm/cdt/log/acrobot/rl_reward" 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/cdt/cdt_rl_train_compare.json: -------------------------------------------------------------------------------- 1 | { "General": { 2 | "policy_approx" : "CDT" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "num_intermediate_variables" : 2, 7 | "feature_learning_depth" : 2, 8 | "decision_depth" : 2, 9 | "input_dim" : 4, 10 | "output_dim" : 2, 11 | "lr" : 1e-3, 12 | "weight_decay" : 0.0, 13 | "batch_size" : 1280, 14 | "exp_scheduler_gamma" : 1.0, 15 | "device" : "cuda", 16 | "episodes" : 40, 17 | "log_interval" : 100, 18 | "greatest_path_probability" : 1, 19 | "beta_fl" : 0, 20 | "beta_dc" : 0 21 | }, 22 | "alg_confs": { 23 | "learning_rate" : 0.0005, 24 | "gamma" : 0.98, 25 | "lmbda" : 0.95, 26 | "eps_clip" : 0.1, 27 | "K_epoch" : 3, 28 | "hidden_dim" : 128 29 | }, 30 | "train_confs": { 31 | "episodes" : 3000, 32 | "t_horizon" : 1000, 33 | "model_path" : "../data/cdt_compare_depth/model/cartpole/rl_ppo", 34 | "log_path" : "../data/cdt_compare_depth/log/cartpole/rl_reward" 35 | } 36 | }, 37 | 38 | "LunarLander-v2": { 39 | "learner_args": { 40 | "num_intermediate_variables" : 2, 41 | "feature_learning_depth" : 3, 42 | "decision_depth" : 3, 43 | "input_dim" : 8, 44 | "output_dim" : 4, 45 | "lr" : 1e-3, 46 | "weight_decay" : 0.0, 47 | "batch_size" : 1280, 48 | "exp_scheduler_gamma" : 1.0, 49 | "device" : "cuda", 50 | "episodes" : 40, 51 | "log_interval" : 100, 52 | "greatest_path_probability" : 1, 53 | "beta_fl" : 0, 54 | "beta_dc" : 0 55 | }, 56 | "alg_confs": { 57 | "learning_rate" : 0.0005, 58 | "gamma" : 0.98, 59 | "lmbda" : 0.95, 60 | "eps_clip" : 0.1, 61 | "K_epoch" : 3, 62 | "hidden_dim" : 128 63 | }, 64 | "train_confs": { 65 | "episodes" : 5000, 66 | "t_horizon" : 1000, 67 | "model_path" : "../data/cdt_compare_depth/model/lunarlander/rl_ppo", 68 | "log_path" : "../data/cdt_compare_depth/log/lunarlander/rl_reward" 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /src/cdt/deprecated/cdt_discretization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Discretize the (soft) differentiable tree into normal decision tree according to DDT paper""" 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | sys.path.append("..") 7 | from utils.dataset import Dataset 8 | import numpy as np 9 | import copy 10 | 11 | def discretize_tree(original_tree, FL=True, DC=True): 12 | """ 13 | Discretize the cascading tree 14 | if FL: discretize the feature learning tree; 15 | if DC: discretize the decision making tree. 
16 | """ 17 | tree = copy.deepcopy(original_tree) 18 | for name, parameter in tree.named_parameters(): 19 | 20 | # discretize feature learning tree and decision making tree separately 21 | if FL: 22 | if name == 'beta_fl': 23 | setattr(tree, name, nn.Parameter(100*torch.ones(parameter.shape))) 24 | 25 | elif name == 'fl_inner_nodes.weight': 26 | parameters=[] 27 | for weights in parameter: 28 | bias = weights[0] 29 | max_id = np.argmax(np.abs(weights[1:].detach()))+1 30 | max_v = weights[max_id].detach() 31 | new_weights = torch.zeros(weights.shape) 32 | if max_v>0: 33 | new_weights[max_id] = torch.tensor(1) 34 | else: 35 | new_weights[max_id] = torch.tensor(-1) 36 | new_weights[0] = bias/np.abs(max_v) 37 | parameters.append(new_weights) 38 | 39 | tree.fl_inner_nodes.weight = nn.Parameter(torch.stack(parameters)) 40 | 41 | if DC: 42 | if name == 'beta_dc': 43 | setattr(tree, name, nn.Parameter(100*torch.ones(parameter.shape))) 44 | 45 | elif name == 'dc_inner_nodes.weight': 46 | parameters=[] 47 | # print(parameter) 48 | for weights in parameter: 49 | bias = weights[0] 50 | max_id = np.argmax(np.abs(weights[1:].detach()))+1 51 | max_v = weights[max_id].detach() 52 | new_weights = torch.zeros(weights.shape) 53 | if max_v>0: 54 | new_weights[max_id] = torch.tensor(1) 55 | else: 56 | new_weights[max_id] = torch.tensor(-1) 57 | new_weights[0] = bias/np.abs(max_v) 58 | parameters.append(new_weights) 59 | 60 | tree.dc_inner_nodes.weight = nn.Parameter(torch.stack(parameters)) 61 | 62 | return tree 63 | 64 | def onehot_coding(target, output_dim): 65 | target_onehot = torch.FloatTensor(target.size()[0], output_dim) 66 | target_onehot.data.zero_() 67 | target_onehot.scatter_(1, target.view(-1, 1), 1.) 68 | return target_onehot 69 | 70 | def discretization_evaluation(tree, discretized_tree): 71 | # Load data 72 | # data_dir = '../data/discrete_' 73 | data_dir = '../data/cartpole_greedy_ppo_' 74 | data_path = data_dir+'state.npy' 75 | label_path = data_dir+'action.npy' 76 | 77 | # a data loader with all data in dataset 78 | test_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='test', ToTensor=True), 79 | batch_size=int(1e4), 80 | shuffle=True) 81 | accuracy_list=[] 82 | accuracy_list_=[] 83 | correct=0. 84 | correct_=0. 85 | for batch_idx, (data, target) in enumerate(test_loader): 86 | # data, target = data.to(device), target.to(device) 87 | target_onehot = onehot_coding(target, tree.args['output_dim']) 88 | prediction, _, _ = tree.forward(data) 89 | prediction_, _, _ = discretized_tree.forward(data) 90 | with torch.no_grad(): 91 | pred = prediction.data.max(1)[1] 92 | correct += pred.eq(target.view(-1).data).sum() 93 | pred_ = prediction_.data.max(1)[1] 94 | correct_ += pred_.eq(target.view(-1).data).sum() 95 | accuracy = 100. * float(correct) / len(test_loader.dataset) 96 | accuracy_ = 100. 
* float(correct_) / len(test_loader.dataset) 97 | print('Original Tree Accuracy: {:.4f} | Discretized Tree Accuracy: {:.4f}'.format(accuracy, accuracy_)) 98 | 99 | 100 | if __name__ == '__main__': 101 | from CDT import Cascade_DDT 102 | learner_args = { 103 | 'num_intermediate_variables': 2, 104 | 'feature_learning_depth': 1, 105 | 'decision_depth': 2, 106 | 'input_dim': 4, 107 | 'output_dim': 2, 108 | 'lr': 1e-3, 109 | 'weight_decay': 0., # 5e-4 110 | 'batch_size': 1280, 111 | 'exp_scheduler_gamma': 1., 112 | 'cuda': True, 113 | 'epochs': 40, 114 | 'log_interval': 100, 115 | 'greatest_path_probability': True, 116 | 'beta_fl' : False, # temperature for feature learning 117 | 'beta_dc' : False, # temperature for decision making 118 | } 119 | # discretize_type=[[True, False], [False, True], [True, True]] 120 | discretize_type=[[True, True]] 121 | print(learner_args['num_intermediate_variables'], learner_args['feature_learning_depth'], learner_args['decision_depth']) 122 | 123 | for dis_type in discretize_type: 124 | print(dis_type) 125 | for i in range(4,7): 126 | learner_args['model_path'] = './model/cdt_'+str(learner_args['feature_learning_depth'])+'_'\ 127 | +str(learner_args['decision_depth'])+'_var'+str(learner_args['num_intermediate_variables'])+'_id'+str(i) 128 | 129 | learner_args['cuda'] = False # cpu 130 | 131 | tree = Cascade_DDT(learner_args) 132 | tree.load_model(learner_args['model_path']) 133 | 134 | discretized_tree = discretize_tree(tree, FL=dis_type[0], DC=dis_type[1]) 135 | discretization_evaluation(tree, discretized_tree) 136 | 137 | discretized_tree.save_model(model_path = learner_args['model_path']+'_discretized') 138 | -------------------------------------------------------------------------------- /src/cdt/deprecated/cdt_il_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | sys.path.append("..") 7 | from utils.dataset import Dataset 8 | import numpy as np 9 | from torch.utils.tensorboard import SummaryWriter 10 | from utils.heuristic_evaluation import difference_metric, intermediate_features_in_heuristic_tree 11 | from CDT import CDT 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description='parse') 15 | parser.add_argument('--depth1', dest='feature_learning_depth', default=False) 16 | parser.add_argument('--depth2', dest='decision_depth', default=False) 17 | parser.add_argument('--vars', dest='num_intermediate_variables', default=False) 18 | parser.add_argument('--id', dest='id', default=False) 19 | args = parser.parse_args() 20 | 21 | def onehot_coding(target, device, output_dim): 22 | target_onehot = torch.FloatTensor(target.size()[0], output_dim).to(device) 23 | target_onehot.data.zero_() 24 | target_onehot.scatter_(1, target.view(-1, 1), 1.) 
25 | return target_onehot 26 | 27 | use_cuda = True 28 | learner_args = { 29 | 'num_intermediate_variables': int(args.num_intermediate_variables), 30 | 'feature_learning_depth': int(args.feature_learning_depth), 31 | 'decision_depth': int(args.decision_depth), 32 | 'input_dim': 8, 33 | 'output_dim': 4, 34 | 'lr': 1e-3, 35 | 'weight_decay': 0., # 5e-4 36 | 'batch_size': 1280, 37 | 'exp_scheduler_gamma': 1., 38 | 'cuda': use_cuda, 39 | 'epochs': 40, 40 | 'log_interval': 100, 41 | 'greatest_path_probability': True, 42 | 'beta_fl' : False, # temperature for feature learning 43 | 'beta_dc' : False, # temperature for decision making 44 | } 45 | learner_args['model_path'] = './model/cdt/'+str(learner_args['feature_learning_depth'])+'_'+str(learner_args['decision_depth'])+'_var'+str(learner_args['num_intermediate_variables'])+'_id'+str(args.id) 46 | 47 | 48 | device = torch.device('cuda' if use_cuda else 'cpu') 49 | 50 | def train_tree(tree): 51 | writer = SummaryWriter(log_dir='runs/cdt_'+str(learner_args['feature_learning_depth'])+'_'+str(learner_args['decision_depth'])+'_var'+str(learner_args['num_intermediate_variables'])+'_id'+str(args.id)) 52 | # criterion = nn.CrossEntropyLoss() # torch CrossEntropyLoss = LogSoftmax + NLLLoss 53 | criterion = nn.NLLLoss() # since we already have log probability, simply using Negative Log-likelihood loss can provide cross-entropy loss 54 | 55 | # Load data 56 | data_dir = '../data/discrete_' 57 | data_path = data_dir+'state.npy' 58 | label_path = data_dir+'action.npy' 59 | train_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='train'), 60 | batch_size=learner_args['batch_size'], 61 | shuffle=True) 62 | 63 | test_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='test'), 64 | batch_size=learner_args['batch_size'], 65 | shuffle=True) 66 | 67 | 68 | # Utility variables 69 | best_testing_acc = 0. 
70 | testing_acc_list = [] 71 | 72 | for epoch in range(1, learner_args['epochs']+1): 73 | epoch_training_loss_list = [] 74 | epoch_feature_difference_list = [] 75 | 76 | # Training stage 77 | tree.train() 78 | for batch_idx, (data, target) in enumerate(train_loader): 79 | data, target = data.to(device), target.to(device) 80 | target_onehot = onehot_coding(target, device, learner_args['output_dim']) 81 | prediction, output, penalty = tree.forward(data) 82 | 83 | difference=0 84 | intermediate_features = tree.fl_leaf_weights.detach().cpu().numpy() 85 | difference = difference_metric(intermediate_features, list2=np.array(intermediate_features_in_heuristic_tree)[:, 1:]) # remove the constants for intermediate feature in heuristic 86 | epoch_feature_difference_list.append(difference) 87 | 88 | tree.optimizer.zero_grad() 89 | loss = criterion(output, target.view(-1)) 90 | loss += penalty 91 | loss.backward() 92 | tree.optimizer.step() 93 | 94 | # Print intermediate training status 95 | if batch_idx % learner_args['log_interval'] == 0: 96 | with torch.no_grad(): 97 | pred = prediction.data.max(1)[1] 98 | correct = pred.eq(target.view(-1).data).sum() 99 | loss = criterion(output, target.view(-1)) 100 | epoch_training_loss_list.append(loss.detach().cpu().data.numpy()) 101 | print('Epoch: {:02d} | Batch: {:03d} | CrossEntropy-loss: {:.5f} | Correct: {}/{} | Difference: {}'.format( 102 | epoch, batch_idx, loss.data, correct, output.size()[0], difference)) 103 | 104 | tree.save_model(model_path = learner_args['model_path']) 105 | writer.add_scalar('Training Loss', np.mean(epoch_training_loss_list), epoch) 106 | writer.add_scalar('Training Feature Difference', np.mean(epoch_feature_difference_list), epoch) 107 | 108 | # Testing stage 109 | tree.eval() 110 | correct = 0. 111 | for batch_idx, (data, target) in enumerate(test_loader): 112 | data, target = data.to(device), target.to(device) 113 | batch_size = data.size()[0] 114 | prediction, _, _ = tree.forward(data) 115 | pred = prediction.data.max(1)[1] 116 | correct += pred.eq(target.view(-1).data).sum() 117 | accuracy = 100. 
* float(correct) / len(test_loader.dataset) 118 | if accuracy > best_testing_acc: 119 | best_testing_acc = accuracy 120 | testing_acc_list.append(accuracy) 121 | writer.add_scalar('Testing Accuracy', accuracy, epoch) 122 | print('\nEpoch: {:02d} | Testing Accuracy: {}/{} ({:.3f}%) | Historical Best: {:.3f}% \n'.format(epoch, correct, len(test_loader.dataset), accuracy, best_testing_acc)) 123 | 124 | 125 | if __name__ == '__main__': 126 | tree = CDT(learner_args).to(device) 127 | train_tree(tree) 128 | -------------------------------------------------------------------------------- /src/cdt/deprecated/cdt_rl_train.py: -------------------------------------------------------------------------------- 1 | """ PPO with cascading decision tree (CDT) as policy function approximator """ 2 | import gym 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.distributions import Categorical 8 | import argparse 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from CDT import CDT 12 | import sys 13 | sys.path.append("..") 14 | from rl import StateNormWrapper 15 | 16 | #Hyperparameters 17 | learning_rate = 0.0005 18 | gamma = 0.98 19 | lmbda = 0.95 20 | eps_clip = 0.1 21 | K_epoch = 3 22 | Episodes = 5000 23 | EnvName = 'CartPole-v1' # LunarLander-v2 24 | # EnvName = 'LunarLander-v2' 25 | # EnvName = 'FetchReach-v1' 26 | 27 | class PPO(nn.Module): 28 | def __init__(self, state_dim, action_dim, learner_args): 29 | super(PPO, self).__init__() 30 | self.data = [] 31 | self.model_path = learner_args['model_path'] 32 | self.device = learner_args['device'] 33 | hidden_dim=128 34 | self.fc1 = nn.Linear(state_dim,hidden_dim) 35 | # self.fc_pi = nn.Linear(hidden_dim,action_dim) 36 | self.fc_v = nn.Linear(hidden_dim,1) 37 | 38 | self.cdt = CDT(learner_args).to(self.device) 39 | self.pi = lambda x: self.cdt.forward(x, LogProb=False)[1] 40 | 41 | self.optimizer = optim.Adam(list(self.parameters())+list(self.cdt.parameters()), lr=learning_rate) 42 | 43 | def v(self, x): 44 | if isinstance(x, (np.ndarray, np.generic) ): 45 | x = torch.tensor(x) 46 | x = F.relu(self.fc1(x)) 47 | v = self.fc_v(x) 48 | return v 49 | 50 | def put_data(self, transition): 51 | self.data.append(transition) 52 | 53 | def make_batch(self): 54 | s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], [] 55 | for transition in self.data: 56 | s, a, r, s_prime, prob_a, done = transition 57 | 58 | s_lst.append(s) 59 | a_lst.append([a]) 60 | r_lst.append([r]) 61 | s_prime_lst.append(s_prime) 62 | prob_a_lst.append([prob_a]) 63 | done_mask = 0 if done else 1 64 | done_lst.append([done_mask]) 65 | 66 | s,a,r,s_prime,done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float).to(self.device), torch.tensor(a_lst).to(self.device), \ 67 | torch.tensor(r_lst).to(self.device), torch.tensor(s_prime_lst, dtype=torch.float).to(self.device), \ 68 | torch.tensor(done_lst, dtype=torch.float).to(self.device), torch.tensor(prob_a_lst).to(self.device) 69 | self.data = [] 70 | return s, a, r, s_prime, done_mask, prob_a 71 | 72 | def train_net(self): 73 | s, a, r, s_prime, done_mask, prob_a = self.make_batch() 74 | 75 | for i in range(K_epoch): 76 | td_target = r + gamma * self.v(s_prime) * done_mask 77 | delta = td_target - self.v(s) 78 | delta = delta.detach() 79 | 80 | advantage_lst = [] 81 | advantage = 0.0 82 | for delta_t in torch.flip(delta, [0]): 83 | advantage = gamma * lmbda * advantage + delta_t[0] 84 | advantage_lst.append([advantage]) 85 | 
advantage_lst.reverse() 86 | advantage = torch.tensor(advantage_lst, dtype=torch.float).to(self.device) 87 | 88 | pi = self.pi(s) 89 | pi_a = pi.gather(1,a) 90 | ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a)) # a/b == exp(log(a)-log(b)) 91 | surr1 = ratio * advantage 92 | surr2 = torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * advantage 93 | loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach()) 94 | 95 | self.optimizer.zero_grad() 96 | loss.mean().backward() 97 | self.optimizer.step() 98 | 99 | def choose_action(self, s, Greedy=False): 100 | prob = self.pi(torch.from_numpy(s).unsqueeze(0).float().to(self.device)).squeeze() # make sure input state shape is correct 101 | if Greedy: 102 | a = torch.argmax(prob, dim=-1).item() 103 | return a 104 | else: 105 | m = Categorical(prob) 106 | a = m.sample().item() 107 | return a, prob 108 | 109 | def load_model(self, ): 110 | self.load_state_dict(torch.load(self.model_path)) 111 | 112 | def run(EnvName, learner_args, train=False, test=False): 113 | # env = StateNormWrapper(gym.make(EnvName), file_name="../rl/rl_train.json") 114 | env = gym.make(EnvName) 115 | state_dim = env.observation_space.shape[0] 116 | action_dim = env.action_space.n # discrete 117 | print(state_dim, action_dim) 118 | model = PPO(state_dim, action_dim, learner_args).to(learner_args['device']) 119 | print_interval = 20 120 | if test: 121 | model.load_model() 122 | rewards_list=[] 123 | for n_epi in range(Episodes): 124 | s = env.reset() 125 | done = False 126 | reward = 0.0 127 | step=0 128 | while not done and step < 1000: 129 | a, prob = model.choose_action(s) 130 | s_prime, r, done, info = env.step(a) 131 | if test: 132 | env.render() 133 | model.put_data((s, a, r/100.0, s_prime, prob[a].item(), done)) 134 | # model.put_data((s, a, r, s_prime, prob[a].item(), done)) 135 | 136 | s = s_prime 137 | 138 | reward += r 139 | step+=1 140 | if done: 141 | break 142 | if train: 143 | model.train_net() 144 | rewards_list.append(reward) 145 | if train: 146 | if n_epi%print_interval==0 and n_epi!=0: 147 | # plot(rewards_list) 148 | np.save(learner_args['log_path'], rewards_list) 149 | torch.save(model.state_dict(), learner_args['model_path']) 150 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 151 | else: 152 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 153 | 154 | env.close() 155 | 156 | if __name__ == '__main__': 157 | 158 | parser = argparse.ArgumentParser(description='Train or test neural net motor controller.') 159 | parser.add_argument('--depth1', dest='feature_learning_depth', default=False) 160 | parser.add_argument('--depth2', dest='decision_depth', default=False) 161 | parser.add_argument('--train', dest='train', action='store_true', default=False) 162 | parser.add_argument('--test', dest='test', action='store_true', default=False) 163 | parser.add_argument('--id', dest='id', default=False) 164 | 165 | args = parser.parse_args() 166 | 167 | env = gym.make(EnvName) 168 | print('Env info: State space is {}, Action space is {}'.format(env.observation_space, env.action_space)) 169 | state_dim = env.observation_space.shape[0] 170 | action_dim = env.action_space.n # discrete 171 | env.close() 172 | 173 | learner_args = { 174 | 'num_intermediate_variables': 2, 175 | 'feature_learning_depth': int(args.feature_learning_depth), 176 | 'decision_depth': int(args.decision_depth), 177 | 'input_dim': 4, 178 | 'output_dim': 2, 179 | 'lr': 1e-3, 180 | 'weight_decay': 
0., # 5e-4 181 | 'batch_size': 1280, 182 | 'exp_scheduler_gamma': 1., 183 | 'cuda': False, 184 | 'episodes': 40, 185 | 'log_interval': 100, 186 | 'greatest_path_probability': True, 187 | 'beta_fl' : False, # temperature for feature learning 188 | 'beta_dc' : False, # temperature for decision making 189 | } 190 | 191 | file_name=EnvName+'_depth_'+args.feature_learning_depth+'_'+args.decision_depth+'_id'+str(args.id) 192 | learner_args['model_path'] = '../../data/cdt/model/rl_'+file_name 193 | learner_args['log_path'] = '../../data/cdt/log/rl_'+file_name 194 | learner_args['device'] = torch.device('cuda' if learner_args['cuda'] else 'cpu') 195 | 196 | if args.train: 197 | run(EnvName, learner_args, train=True, test=False) 198 | if args.test: 199 | run(EnvName, learner_args, train=False, test=True) 200 | -------------------------------------------------------------------------------- /src/hdt/HDT_lunarlander.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | '''' 3 | A decision tree of heuristic agent for LunarLander-v2 environment 4 | The heuristic agent is provided at: 5 | https://github.com/openai/gym/blob/master/gym/envs/box2d/lunar_lander.py 6 | for both continuous and discrete cases, but in this script it's only for discrete. 7 | The decision tree here consists of a main tree and a sub tree with hierarchical 8 | structure and intermediate features. 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import numpy as np 13 | import gym 14 | 15 | # hyperparameters 16 | EnvName = 'LunarLander-v2' 17 | 18 | 19 | def dict_to_vector(dict, dim=8): 20 | v = np.zeros(dim+1) # the last dim is bias 21 | for key, value in dict.items(): 22 | v[key]=value 23 | return v 24 | 25 | 26 | ############ The main tree ######### 27 | 28 | # In main tree: a list contains weights of all nodes: decision nodes (dict) + leaf nodes (list) 29 | node_list = [ 30 | {6:1, 7:1}, # key: index of non-zero weight; value: value of correspongding weight 31 | {0:0.5, 2:1, 8:-0.4}, # last dim (dim 8) is bias 32 | {0:-0.5, 2:-1, 8:-0.4}, 33 | {0:1}, 34 | {0:1}, 35 | {0:1}, 36 | [9*[0], [0,0,0,-0.5, 0,0,0,0,0]], # [at, ht] 37 | 38 | [[0,0,0,0,-0.5,-1,0,0,0.2], [0.275, -0.5, 0,-0.5,0,0,0,0,0]], 39 | [[0,0,0,0,-0.5,-1,0,0,0.2], [-0.275, -0.5, 0,-0.5,0,0,0,0,0]], 40 | [[0,0,0,0,-0.5,-1,0,0,-0.2], [0.275, -0.5, 0,-0.5,0,0,0,0,0]], 41 | [[0,0,0,0,-0.5,-1,0,0,-0.2], [-0.275, -0.5, 0,-0.5,0,0,0,0,0]], 42 | [[0.25, 0,0.5,0,-0.5, -1,0,0,0], [0.275, -0.5, 0,-0.5,0,0,0,0,0]], 43 | [[0.25, 0,0.5,0,-0.5, -1,0,0,0], [-0.275, -0.5, 0,-0.5,0,0,0,0,0]], 44 | ] 45 | 46 | # In main tree: a list indicating indices of left and right child for each node 47 | child_list = [ 48 | [6,1], 49 | [3,2], 50 | [4,5], 51 | [7,8], 52 | [9,10], 53 | [11,12] 54 | ] 55 | 56 | 57 | class Node(object): 58 | """ Node in main tree """ 59 | def __init__(self, id, weights, left_child_id, right_child_id): 60 | super(Node, self).__init__() 61 | self.id = id 62 | self.weights = weights 63 | self.left_child_id = left_child_id 64 | self.right_child_id = right_child_id 65 | 66 | def decide(self, aug_x): 67 | prod = np.sum(self.weights*aug_x) # weights include bias 68 | if prod>0: 69 | return [self.left_child_id] 70 | else: 71 | return [self.right_child_id] 72 | 73 | class HeuristicTree(object): 74 | def __init__(self, node_list, child_list): 75 | super(HeuristicTree, self).__init__() 76 | self.node_list=[] 77 | for i, node in enumerate(node_list): 78 | if isinstance(node, dict): # inner node 79 | w = 
dict_to_vector(node, dim=8) 80 | self.node_list.append(Node(i, w, child_list[i][0], child_list[i][1])) 81 | else: # leaf 82 | self.node_list.append(SubTree(subtree_node_list, subtree_child_list, node[0], node[1], i)) 83 | 84 | def forward(self, x, Info=True): 85 | aug_x = np.concatenate((np.array(x), [1])) 86 | child = self.node_list[0].decide(aug_x)[0] 87 | decision_path=[self.node_list[0].id] 88 | decision_weights=[self.node_list[0].weights] 89 | while True: 90 | last_child = child 91 | info = self.node_list[child].decide(aug_x) 92 | child=info[0] 93 | if isinstance(self.node_list[last_child], SubTree): 94 | break 95 | decision_path.append(last_child) 96 | decision_weights.append(self.node_list[last_child].weights) 97 | 98 | sub_tree_path = info[1] 99 | sub_tree_weights = info[2] 100 | decision_path+=sub_tree_path 101 | decision_weights+=sub_tree_weights 102 | 103 | if Info: 104 | return [child, decision_path, decision_weights] 105 | 106 | else: 107 | return child 108 | 109 | 110 | ############ The sub tree ########### 111 | 112 | # In SubTree: a list contains weights of all nodes: decision nodes (dict) + leaf nodes (int) 113 | subtree_node_list=[ 114 | {0:1}, # key: index of non-zero weight; value: value of correspongding weight 115 | {0:-1, 1:1}, # 0: at, 1: ht 116 | {0:1, 1:1}, 117 | {1:1, 2:-0.05}, # last dim (dim 2) is bias 118 | {0:1, 2:-0.05}, 119 | {1:1, 2:-0.05}, 120 | {0:-1, 2:-0.05}, 121 | {0:1, 2:-0.05}, 122 | {0:-1, 2:-0.05}, 123 | 2, 124 | 1, 125 | 0, 126 | 1, 127 | 0, 128 | 2, 129 | 3, 130 | 0, 131 | 3, 132 | 0 133 | ] 134 | 135 | # In SubTree: a list indicating indices of left and right child for each node 136 | subtree_child_list=[ 137 | [1,2], 138 | [3,4], 139 | [5,6], 140 | [9,7], 141 | [12,13], 142 | [14,8], 143 | [17,18], 144 | [10,11], 145 | [15,16] 146 | ] 147 | 148 | 149 | class SubNode(object): 150 | """ Node in tree (SubTree) """ 151 | def __init__(self, id, weights, at, ht, left_child_id, right_child_id): 152 | super(SubNode, self).__init__() 153 | self.id = id 154 | self.weights = weights[0]*np.array(at)+weights[1]*np.array(ht) 155 | self.weights[-1]+=weights[-1] 156 | self.left_child_id = left_child_id 157 | self.right_child_id = right_child_id 158 | 159 | def decide(self, aug_x): 160 | prod = np.sum(self.weights*aug_x) # weights include bias 161 | if prod>0: 162 | return self.left_child_id 163 | else: 164 | return self.right_child_id 165 | 166 | class Leaf(object): 167 | """ Leaf in tree (SubTree) """ 168 | def __init__(self, id, value): 169 | super(Leaf, self).__init__() 170 | self.id = id 171 | self.value = value 172 | 173 | class SubTree(object): 174 | def __init__(self, node_list, child_list, at, ht, tree_id): 175 | super(SubTree, self).__init__() 176 | self.node_list=[] 177 | for i, node in enumerate(node_list): 178 | sub_id = 'sub_'+str(i) # idex of node on subtree 179 | if isinstance(node, dict): # inner node 180 | w = dict_to_vector(node, dim=2) 181 | self.node_list.append(SubNode(sub_id, w, at, ht, child_list[i][0], child_list[i][1])) 182 | else: # leaf 183 | self.node_list.append(Leaf(sub_id, node)) 184 | 185 | def decide(self, aug_x, Path=False): 186 | child = self.node_list[0].decide(aug_x) 187 | decision_path=[self.node_list[0].id] 188 | weights_list=[self.node_list[0].weights] 189 | while isinstance(self.node_list[child], SubNode): 190 | weights_list.append(self.node_list[child].weights) 191 | decision_path.append(self.node_list[child].id) 192 | child = self.node_list[child].decide(aug_x) 193 | decision_path.append(self.node_list[child].id) # 
add leaf 194 | return [self.node_list[child].value, decision_path, weights_list] 195 | 196 | 197 | ############ Test ########### 198 | 199 | 200 | def run(model, episodes=1, seed=None): 201 | env = gym.make(EnvName) 202 | import time 203 | if seed: 204 | env.seed(seed) 205 | for n_epi in range(episodes): 206 | s = env.reset() 207 | done = False 208 | reward = 0.0 209 | step=0 210 | while not done: 211 | a = model(s) 212 | print(a) 213 | s_prime, r, done, info = env.step(a) 214 | env.render() 215 | s = s_prime 216 | time.sleep(0.1) 217 | reward += r 218 | step+=1 219 | if done: 220 | break 221 | 222 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 223 | 224 | 225 | def evaluate(model, episodes=1, frameskip=1, seed=None): 226 | from heuristic_evaluation import normalize 227 | 228 | env = gym.make(EnvName) 229 | if seed: 230 | torch.manual_seed(seed) 231 | np.random.seed(seed) 232 | env.seed(seed) 233 | state_dim = env.observation_space.shape[0] 234 | action_dim = env.action_space.n # discrete 235 | average_weight_list = [] 236 | 237 | for n_epi in range(episodes): 238 | print('Episode: ', n_epi) 239 | average_weight_list_epi = [] 240 | s = env.reset() 241 | done = False 242 | reward = 0.0 243 | step=0 244 | while not done: 245 | info = model(s) 246 | a=info[0] 247 | if step%frameskip==0: 248 | average_weight = np.mean(np.abs(normalize(np.array(info[2])[:, :-1])), axis=0) # take absolute to prevent that positive and negative will counteract 249 | average_weight_list_epi.append(average_weight) 250 | 251 | s_prime, r, done, _ = env.step(a) 252 | # env.render() 253 | s = s_prime 254 | 255 | reward += r 256 | step+=1 257 | if done: 258 | break 259 | 260 | average_weight_list.append(average_weight_list_epi) 261 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 262 | np.save('data/heuristic_tree_importance.npy', average_weight_list) 263 | 264 | env.close() 265 | 266 | if __name__ == '__main__': 267 | tree = HeuristicTree(node_list, child_list) 268 | ## RL test 269 | # model = lambda x: tree.forward(x, Info=False) 270 | # run(model, episodes=100) 271 | 272 | # tree evaluation 273 | model = lambda x: tree.forward(x, Info=True) 274 | evaluate(model, episodes=1, seed=10) 275 | from sdt_evaluation import plot_importance_single_episode 276 | plot_importance_single_episode(data_path='data/heuristic_tree_importance.npy', save_path='./img/heuristic_tree_importance.png', ) -------------------------------------------------------------------------------- /src/hdt/__init__.py: -------------------------------------------------------------------------------- 1 | from .heuristic_lunarlander import HeuristicAgentLunarLander 2 | 3 | -------------------------------------------------------------------------------- /src/hdt/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/hdt/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/hdt/__pycache__/heuristic_agents.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/hdt/__pycache__/heuristic_agents.cpython-36.pyc 
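A quick note on the node encoding used by `HDT_lunarlander.py` above: decision nodes in `node_list` are stored as sparse dicts mapping a state index to its weight (index 8 being the bias), and `dict_to_vector` expands them into dense vectors before `Node.decide` takes an inner product with the bias-augmented state. A minimal sketch of that mechanism (the state values below are made up for illustration and are not from the repo):

```python
import numpy as np

def dict_to_vector(d, dim=8):
    # same sparse-to-dense expansion as in HDT_lunarlander.py: last entry is the bias
    v = np.zeros(dim + 1)
    for key, value in d.items():
        v[key] = value
    return v

# root node of the main tree, {6: 1, 7: 1}: positive when either leg-contact flag is set
w = dict_to_vector({6: 1, 7: 1}, dim=8)

s = np.zeros(8)                       # made-up LunarLander state, both legs off the ground
aug_s = np.concatenate((s, [1.0]))    # append 1 so the bias weight multiplies through
print(np.sum(w * aug_s) > 0)          # False -> Node.decide routes to the right child (node 1)
```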
-------------------------------------------------------------------------------- /src/hdt/__pycache__/heuristic_lunarlander.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/hdt/__pycache__/heuristic_lunarlander.cpython-36.pyc -------------------------------------------------------------------------------- /src/hdt/heuristic_agents.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | 4 | import gym 5 | from gym import spaces 6 | 7 | class HeuristicAgent(): 8 | def __init__(self,): 9 | pass 10 | def choose_action(self,): 11 | pass 12 | 13 | class HeuristicAgentLunarLander(HeuristicAgent): 14 | """ 15 | Heuristic agent for LunarLander environment created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym. 16 | """ 17 | 18 | def __init__(self, Continuous=False): 19 | super(HeuristicAgent, self).__init__() 20 | self.continuous = Continuous 21 | 22 | def choose_action(self, s, DIST = False): 23 | # Heuristic for: 24 | # 1. Testing. 25 | # 2. Demonstration rollout. 26 | angle_targ = s[0]*0.5 + s[2]*1.0 # angle should point towards center (s[0] is horizontal coordinate, s[2] hor speed) 27 | if angle_targ > 0.4: angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad 28 | if angle_targ < -0.4: angle_targ = -0.4 29 | hover_targ = 0.55*np.abs(s[0]) # target y should be proporional to horizontal offset 30 | 31 | # PID controller: s[4] angle, s[5] angularSpeed 32 | angle_todo = (angle_targ - s[4])*0.5 - (s[5])*1.0 33 | #print("angle_targ=%0.2f, angle_todo=%0.2f" % (angle_targ, angle_todo)) 34 | 35 | # PID controller: s[1] vertical coordinate s[3] vertical speed 36 | hover_todo = (hover_targ - s[1])*0.5 - (s[3])*0.5 37 | #print("hover_targ=%0.2f, hover_todo=%0.2f" % (hover_targ, hover_todo)) 38 | 39 | if s[6] or s[7]: # legs have contact 40 | angle_todo = 0 41 | hover_todo = -(s[3])*0.5 # override to reduce fall speed, that's all we need after contact 42 | 43 | if self.continuous: 44 | a = np.array( [hover_todo*20 - 1, -angle_todo*20] ) 45 | a = np.clip(a, -1, +1) 46 | else: 47 | a = 0 # do nothing 48 | if hover_todo > np.abs(angle_todo) and hover_todo > 0.05: a = 2 # fire main 49 | elif angle_todo < -0.05: a = 3 # fire right 50 | elif angle_todo > +0.05: a = 1 # fire left 51 | return a, None 52 | 53 | 54 | 55 | class HeuristicAgentCartPole(HeuristicAgent): 56 | """ 57 | Heuristic agent for CartPole environment. 58 | Ref: https://github.com/ZhiqingXiao/OpenAIGymSolution/blob/master/CartPole-v0 59 | """ 60 | def __init__(self,): 61 | super(HeuristicAgent, self).__init__() 62 | 63 | def choose_action(self, s): 64 | position, velocity, angle, angle_velocity = s 65 | action = int(3. * angle + angle_velocity > 0.) 66 | return action, None 67 | 68 | 69 | 70 | def run(env, agent, episodes=10, render=False, verbose=False): 71 | for _ in range(episodes): 72 | step=0 73 | observation = env.reset() 74 | episode_reward = 0. 
75 | while True: 76 | if render: 77 | env.render() 78 | action, _ = agent.choose_action(observation) 79 | observation, reward, done, _ = env.step(action) 80 | episode_reward += reward 81 | step+=1 82 | if done: 83 | break 84 | if verbose: 85 | print('get {} rewards in {} steps'.format( 86 | episode_reward, step)) 87 | 88 | env.close() 89 | return episode_reward 90 | 91 | if __name__ == '__main__': 92 | # EnvName = 'CartPole-v1' 93 | EnvName = 'LunarLander-v2' 94 | 95 | np.random.seed(0) 96 | env = gym.make(EnvName) 97 | env.seed(0) 98 | agent = eval('HeuristicAgent'+EnvName.split('-')[0])() 99 | run(env, agent, render=True, verbose=True) -------------------------------------------------------------------------------- /src/hdt/heuristic_lunarlander.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | 4 | import gym 5 | from gym import spaces 6 | 7 | class HeuristicAgentLunarLander(): 8 | """ 9 | Heuristic agent for LunarLander environment created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym. 10 | """ 11 | 12 | def __init__(self, env, Continuous): 13 | super(HeuristicAgentLunarLander, self).__init__() 14 | self.continuous = Continuous 15 | 16 | def choose_action(self, s, DIST = False): 17 | # Heuristic for: 18 | # 1. Testing. 19 | # 2. Demonstration rollout. 20 | angle_targ = s[0]*0.5 + s[2]*1.0 # angle should point towards center (s[0] is horizontal coordinate, s[2] hor speed) 21 | if angle_targ > 0.4: angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad 22 | if angle_targ < -0.4: angle_targ = -0.4 23 | hover_targ = 0.55*np.abs(s[0]) # target y should be proporional to horizontal offset 24 | 25 | # PID controller: s[4] angle, s[5] angularSpeed 26 | angle_todo = (angle_targ - s[4])*0.5 - (s[5])*1.0 27 | #print("angle_targ=%0.2f, angle_todo=%0.2f" % (angle_targ, angle_todo)) 28 | 29 | # PID controller: s[1] vertical coordinate s[3] vertical speed 30 | hover_todo = (hover_targ - s[1])*0.5 - (s[3])*0.5 31 | #print("hover_targ=%0.2f, hover_todo=%0.2f" % (hover_targ, hover_todo)) 32 | 33 | if s[6] or s[7]: # legs have contact 34 | angle_todo = 0 35 | hover_todo = -(s[3])*0.5 # override to reduce fall speed, that's all we need after contact 36 | 37 | if self.continuous: 38 | a = np.array( [hover_todo*20 - 1, -angle_todo*20] ) 39 | a = np.clip(a, -1, +1) 40 | else: 41 | a = 0 # do nothing 42 | if hover_todo > np.abs(angle_todo) and hover_todo > 0.05: a = 2 # fire main 43 | elif angle_todo < -0.05: a = 3 # fire right 44 | elif angle_todo > +0.05: a = 1 # fire left 45 | return a, None 46 | 47 | -------------------------------------------------------------------------------- /src/il/il.json: -------------------------------------------------------------------------------- 1 | { "data_collect_confs": { 2 | "episodes" : 10000, 3 | "t_horizon" : 1000, 4 | "data_path" : "../data/il/samples/", 5 | "img_path" : "../data/il/imgs/" 6 | } 7 | } 8 | 9 | -------------------------------------------------------------------------------- /src/il_data_collect.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | from rl import PPO 7 | import json 8 | import pickle 9 | from hdt import HeuristicAgentLunarLander 10 | 11 | filename = "./il/il.json" 12 | with open(filename, "r") as read_file: 13 | il_confs = json.load(read_file) # hyperparameters for rl training 14 | 15 
| def collect_demo(env, agent, seed=None, render=False, collect_data=False): 16 | """ 17 | Collect demonstrations. 18 | """ 19 | env.seed(seed) 20 | total_reward_list=[] 21 | a_list=[] 22 | s_list=[] 23 | for i in range(il_confs["data_collect_confs"]["episodes"]): 24 | print('Episode: ', i) 25 | total_reward = 0 26 | steps = 0 27 | s = env.reset() 28 | while steps < il_confs["data_collect_confs"]["t_horizon"]: 29 | a, _ = agent.choose_action(s) 30 | s_list.append(s) 31 | a_list.append([a]) 32 | s, r, done, info = env.step(a) 33 | total_reward += r 34 | 35 | if render and not collect_data : 36 | still_open = env.render() 37 | if still_open == False: break 38 | 39 | steps += 1 40 | if done: break 41 | 42 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(i, total_reward, steps)) 43 | 44 | total_reward_list.append(total_reward) 45 | print('Average reward: {}'.format(np.mean(total_reward_list))) 46 | np.save(il_confs["data_collect_confs"]["data_path"]+env.spec.id.split("-")[0].lower()+'/state', s_list) 47 | np.save(il_confs["data_collect_confs"]["data_path"]+env.spec.id.split("-")[0].lower()+'/action', a_list) 48 | return total_reward 49 | 50 | def norm_state(env): 51 | ''' normalize data ''' 52 | file_name="./rl/rl.json" 53 | with open(file_name, "r") as read_file: 54 | general_rl_confs = json.load(read_file) # hyperparameters for rl training 55 | print(env.spec.id) 56 | data_path_prefix = general_rl_confs["data_collect_confs"]["data_path"]+env.spec.id.split("-")[0].lower() 57 | with open(data_path_prefix+'/state_info.pkl', 'rb') as f: 58 | state_stats=pickle.load(f) 59 | 60 | states_data_path = il_confs["data_collect_confs"]["data_path"]+env.spec.id.split("-")[0].lower()+'/state' 61 | states = np.load(states_data_path+'.npy') 62 | mean = state_stats['mean'] 63 | std = state_stats['std'] 64 | states = (states-mean)/std 65 | 66 | np.save(states_data_path+'_norm', states) 67 | 68 | 69 | if __name__ == '__main__': 70 | EnvName = 'CartPole-v1' 71 | # EnvName = 'LunarLander-v2' 72 | 73 | env = gym.make(EnvName) 74 | if EnvName == 'LunarLander-v2': # the heuristic agent exists for LunarLander 75 | agent = HeuristicAgentLunarLander(env, Continuous=False) 76 | elif EnvName == 'CartPole-v1': # no heuristic agent for CartPole, so use a well-trained RL agent 77 | filename = "./mlp/mlp_rl_train.json" 78 | with open(filename, "r") as read_file: 79 | rl_confs = json.load(read_file) # hyperparameters for rl training 80 | state_dim = env.observation_space.shape[0] 81 | action_dim = env.action_space.n # discrete 82 | agent = PPO(state_dim, action_dim, 'MLP', rl_confs[EnvName]["learner_args"], \ 83 | **rl_confs[EnvName]["alg_confs"]).to(torch.device(rl_confs[EnvName]["learner_args"]["device"])) 84 | agent.load_model(rl_confs[EnvName]["train_confs"]["model_path"]) 85 | 86 | # collect_demo(env, agent, render=False, collect_data = False) 87 | norm_state(env) 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/il_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Discretize the (soft) differentiable tree into normal decision tree according to DDT paper""" 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | sys.path.append("..") 7 | from utils.dataset import Dataset 8 | import numpy as np 9 | import copy 10 | from utils.common_func import onehot_coding 11 | from cdt import discretize_cdt, CDT 12 | from sdt import discretize_sdt, SDT 13 | import json 14 | import 
argparse 15 | 16 | 17 | def discretization_evaluation(tree, device, discretized_tree, data_path): 18 | # Load data 19 | input_path = data_path+'state.npy' 20 | label_path = data_path+'action.npy' 21 | 22 | # a data loader with all data in dataset 23 | test_loader = torch.utils.data.DataLoader(Dataset(input_path, label_path, partition='test', ToTensor=True), 24 | batch_size=1280, 25 | shuffle=True) 26 | accuracy_list=[] 27 | accuracy_list_=[] 28 | correct=0. 29 | correct_=0. 30 | for batch_idx, (data, target) in enumerate(test_loader): 31 | data, target = data.to(device), target.to(device) 32 | target_onehot = onehot_coding(target, device, tree.args['output_dim']) 33 | prediction, _, _ = tree.forward(data) 34 | prediction_, _, _ = discretized_tree.forward(data) 35 | with torch.no_grad(): 36 | pred = prediction.data.max(1)[1] 37 | correct += pred.eq(target.view(-1).data).sum() 38 | pred_ = prediction_.data.max(1)[1] 39 | correct_ += pred_.eq(target.view(-1).data).sum() 40 | accuracy = 100. * float(correct) / len(test_loader.dataset) 41 | accuracy_ = 100. * float(correct_) / len(test_loader.dataset) 42 | print('Original Tree Accuracy: {:.4f} | Discretized Tree Accuracy: {:.4f}'.format(accuracy, accuracy_)) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser( 47 | description='Imitation learning evaluation.') 48 | 49 | parser.add_argument('--env', 50 | dest='EnvName', 51 | action='store', 52 | default=None) 53 | 54 | parser.add_argument('--method', 55 | dest='METHOD', 56 | action='store', 57 | default=None) 58 | 59 | args = parser.parse_args() 60 | 61 | METHOD = args.METHOD # one of: 'cdt', 'sdt' 62 | 63 | if METHOD == 'cdt': 64 | filename = "./cdt/cdt_il_train.json" 65 | elif METHOD == 'sdt': 66 | filename = "./sdt/sdt_il_train.json" 67 | else: 68 | raise NotImplementedError 69 | 70 | EnvName = args.EnvName 71 | 72 | with open(filename, "r") as read_file: 73 | il_confs = json.load(read_file) # hyperparameters for rl training 74 | 75 | general_filename = "./il/il.json" 76 | with open(general_filename, "r") as read_file: 77 | general_il_confs = json.load(read_file) # hyperparameters for rl training 78 | 79 | discretize_type=[True, True] 80 | device = torch.device('cuda') 81 | 82 | for idx in range(1,6): 83 | # add id 84 | model_path = il_confs[EnvName]["learner_args"]["model_path"]+str(idx) 85 | log_path = il_confs[EnvName]["learner_args"]["log_path"]+str(idx) 86 | 87 | if METHOD == 'cdt': 88 | tree = CDT(il_confs[EnvName]["learner_args"]).to(device) 89 | tree.load_model(model_path) 90 | discretized_tree = discretize_cdt(tree, FL=discretize_type[0], DC=discretize_type[1]).to(device) 91 | elif METHOD == 'sdt': 92 | tree = SDT(il_confs[EnvName]["learner_args"]).to(device) 93 | tree.load_model(model_path) 94 | discretized_tree = discretize_sdt(tree).to(device) 95 | else: 96 | raise NotImplementedError 97 | 98 | data_path = general_il_confs["data_collect_confs"]["data_path"]+EnvName.split("-")[0].lower()+'/' 99 | discretization_evaluation(tree, device, discretized_tree, data_path) 100 | 101 | discretized_tree.save_model(model_path = model_path+'_discretized') 102 | 103 | -------------------------------------------------------------------------------- /src/il_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import sys 5 | sys.path.append("..") 6 | from utils.dataset import Dataset 7 | import numpy as np 8 | from torch.utils.tensorboard import SummaryWriter 
9 | from utils.heuristic_evaluation import difference_metric, intermediate_features_in_heuristic_tree 10 | from utils.common_func import onehot_coding 11 | from cdt import CDT 12 | from sdt import SDT 13 | import argparse 14 | import json 15 | 16 | 17 | def train_tree(tree, device, data_path, learner_args): 18 | criterion = nn.NLLLoss( 19 | ) # since we already have log probability, simply using Negative Log-likelihood loss can provide cross-entropy loss 20 | 21 | # Load data 22 | input_path = data_path + 'state_norm.npy' 23 | label_path = data_path + 'action.npy' 24 | train_loader = torch.utils.data.DataLoader( 25 | Dataset(input_path, label_path, partition='train'), 26 | batch_size=learner_args['batch_size'], 27 | shuffle=True, 28 | ) 29 | 30 | test_loader = torch.utils.data.DataLoader( 31 | Dataset(input_path, label_path, partition='test'), 32 | batch_size=learner_args['batch_size'], 33 | shuffle=True, 34 | ) 35 | 36 | # Utility variables 37 | best_testing_acc = 0. 38 | testing_acc_list = [] 39 | 40 | for epoch in range(1, learner_args['epochs'] + 1): 41 | epoch_training_loss_list = [] 42 | epoch_feature_difference_list = [] 43 | 44 | # Training stage 45 | tree.train() 46 | for batch_idx, (data, target) in enumerate(train_loader): 47 | data, target = data.to(device), target.to(device) 48 | target_onehot = onehot_coding(target, device, 49 | learner_args['output_dim']) 50 | prediction, output, penalty = tree.forward(data) 51 | 52 | difference = 0 53 | 54 | tree.optimizer.zero_grad() 55 | loss = criterion(output, target.view(-1)) 56 | loss += penalty 57 | loss.backward() 58 | tree.optimizer.step() 59 | 60 | # Print intermediate training status 61 | if batch_idx % learner_args['log_interval'] == 0: 62 | with torch.no_grad(): 63 | pred = prediction.data.max(1)[1] 64 | correct = pred.eq(target.view(-1).data).sum() 65 | loss = criterion(output, target.view(-1)) 66 | epoch_training_loss_list.append( 67 | loss.detach().cpu().data.numpy()) 68 | print( 69 | 'Epoch: {:02d} | Batch: {:03d} | CrossEntropy-loss: {:.5f} | Correct: {}/{} | Difference: {}' 70 | .format(epoch, batch_idx, loss.data, correct, 71 | output.size()[0], difference)) 72 | 73 | tree.save_model(model_path=learner_args['model_path']) 74 | 75 | # intermediate_features = tree.fl_leaf_weights.detach().cpu().numpy() 76 | # difference = difference_metric(intermediate_features, list2=np.array(intermediate_features_in_heuristic_tree)[:, 1:]) # remove the constants for intermediate feature in heuristic 77 | # epoch_feature_difference_list.append(difference) 78 | 79 | # Testing stage 80 | tree.eval() 81 | correct = 0. 82 | for batch_idx, (data, target) in enumerate(test_loader): 83 | data, target = data.to(device), target.to(device) 84 | batch_size = data.size()[0] 85 | prediction, _, _ = tree.forward(data) 86 | pred = prediction.data.max(1)[1] 87 | correct += pred.eq(target.view(-1).data).sum() 88 | accuracy = 100. 
* float(correct) / len(test_loader.dataset) 89 | if accuracy > best_testing_acc: 90 | best_testing_acc = accuracy 91 | testing_acc_list.append(accuracy) 92 | print( 93 | '\nEpoch: {:02d} | Testing Accuracy: {}/{} ({:.3f}%) | Historical Best: {:.3f}% \n' 94 | .format(epoch, correct, len(test_loader.dataset), accuracy, 95 | best_testing_acc)) 96 | 97 | # log data 98 | # np.save(learner_args['log_path']+'_diff', epoch_feature_difference_list) 99 | np.save(learner_args['log_path'] + '_acc', testing_acc_list) 100 | print('Best Testing Accuracy in Training: {}'.format(best_testing_acc)) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser( 105 | description='Imitation learning training.') 106 | parser.add_argument('--env', 107 | dest='EnvName', 108 | action='store', 109 | default=None) 110 | parser.add_argument('--method', 111 | dest='METHOD', 112 | action='store', 113 | default=None) 114 | 115 | parser.add_argument('--id', dest='id', action='store', default=0) 116 | 117 | args = parser.parse_args() 118 | 119 | METHOD = args.METHOD # one of: 'cdt', 'sdt' 120 | 121 | if METHOD == 'cdt': 122 | filename = "./cdt/cdt_il_train.json" 123 | elif METHOD == 'sdt': 124 | filename = "./sdt/sdt_il_train.json" 125 | else: 126 | raise NotImplementedError 127 | 128 | EnvName = args.EnvName 129 | 130 | with open(filename, "r") as read_file: 131 | il_confs = json.load(read_file) # hyperparameters for rl training 132 | 133 | general_filename = "./il/il.json" 134 | with open(general_filename, "r") as read_file: 135 | general_il_confs = json.load( 136 | read_file) # hyperparameters for rl training 137 | 138 | device = torch.device(il_confs[EnvName]["learner_args"]["device"]) 139 | 140 | # add id 141 | il_confs[EnvName]["learner_args"]["model_path"] = il_confs[EnvName][ 142 | "learner_args"]["model_path"] + args.id 143 | il_confs[EnvName]["learner_args"][ 144 | "log_path"] = il_confs[EnvName]["learner_args"]["log_path"] + args.id 145 | 146 | if METHOD == 'cdt': 147 | tree = CDT(il_confs[EnvName]["learner_args"]).to(device) 148 | elif METHOD == 'sdt': 149 | tree = SDT(il_confs[EnvName]["learner_args"]).to(device) 150 | else: 151 | raise NotImplementedError 152 | data_path = general_il_confs["data_collect_confs"]["data_path"] + EnvName.split( 153 | "-")[0].lower() + '/' 154 | train_tree(tree, device, data_path, il_confs[EnvName]["learner_args"]) 155 | -------------------------------------------------------------------------------- /src/il_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #--mem=20G 4 | #--gres=gpu:0 5 | 6 | hostname 7 | echo $CUDA_VISIBLE_DEVICES 8 | 9 | min=1 10 | max=5 11 | inter=1 12 | # declare an array to loop through 13 | declare -a methods=("sdt" "cdt") 14 | declare -a envs=("CartPole-v1" "LunarLander-v2" "MountainCar-v0") 15 | 16 | 17 | ## now loop through the above array 18 | for env in "${envs[@]}"; 19 | do 20 | for method in "${methods[@]}"; 21 | do 22 | for ((i=min; i <= max; i+=inter)); 23 | do 24 | echo python3 il_train.py --env="$env" --method="$method" --id="$i" 25 | python3 il_train.py --env="$env" --method="$method" --id="$i" 26 | done 27 | done 28 | done -------------------------------------------------------------------------------- /src/mlp/mlp_rl_train.json: -------------------------------------------------------------------------------- 1 | { "General": { 2 | "policy_approx" : "MLP" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "device": "cuda" 7 | }, 8 | "alg_confs": { 9 | 
"learning_rate" : 0.0005, 10 | "gamma" : 0.98, 11 | "lmbda" : 0.95, 12 | "eps_clip" : 0.1, 13 | "K_epoch" : 3, 14 | "hidden_dim" : 128 15 | }, 16 | "train_confs": { 17 | "episodes" : 3000, 18 | "t_horizon" : 1000, 19 | "model_path" : "../data/mlp/model/cartpole/ppo", 20 | "log_path" : "../data/mlp/log/cartpole/reward" 21 | } 22 | }, 23 | 24 | "LunarLander-v2": { 25 | "learner_args": { 26 | "device": "cuda" 27 | }, 28 | "alg_confs": { 29 | "learning_rate" : 0.0005, 30 | "gamma" : 0.98, 31 | "lmbda" : 0.95, 32 | "eps_clip" : 0.1, 33 | "K_epoch" : 3, 34 | "hidden_dim" : 128 35 | }, 36 | "train_confs": { 37 | "episodes" : 5000, 38 | "t_horizon" : 1000, 39 | "model_path" : "../data/mlp/model/lunarlander/ppo", 40 | "log_path" : "../data/mlp/log/lunarlander/reward" 41 | } 42 | }, 43 | 44 | "MountainCar-v0": { 45 | "learner_args": { 46 | "device": "cuda" 47 | }, 48 | "alg_confs": { 49 | "learning_rate" : 0.005, 50 | "gamma" : 0.999, 51 | "lmbda" : 0.98, 52 | "eps_clip" : 0.1, 53 | "K_epoch" : 10, 54 | "hidden_dim" : 32 55 | }, 56 | "train_confs": { 57 | "episodes" : 5000, 58 | "t_horizon" : 1000, 59 | "model_path" : "../data/mlp/model/mountaincar/ppo", 60 | "log_path" : "../data/mlp/log/mountaincar/reward" 61 | } 62 | }, 63 | 64 | "Acrobot-v1": { 65 | "learner_args": { 66 | "device": "cuda" 67 | }, 68 | "alg_confs": { 69 | "learning_rate" : 0.0005, 70 | "gamma" : 0.98, 71 | "lmbda" : 0.95, 72 | "eps_clip" : 0.1, 73 | "K_epoch" : 3, 74 | "hidden_dim" : 128 75 | }, 76 | "train_confs": { 77 | "episodes" : 7000, 78 | "t_horizon" : 1000, 79 | "model_path" : "../data_no_norm/mlp/model/acrobot/ppo", 80 | "log_path" : "../data_no_norm/mlp/log/acrobot/reward" 81 | } 82 | } 83 | 84 | 85 | 86 | } 87 | -------------------------------------------------------------------------------- /src/rl/.ipynb_checkpoints/state_statistics-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "def plot(rewards):\n", 10 | " # clear_output(True)\n", 11 | " plt.figure(figsize=(10,5))\n", 12 | " plt.plot(rewards)\n", 13 | " plt.savefig('ppo_discrete_lunarlandar.png')\n", 14 | " # plt.show()\n", 15 | " plt.clf() \n", 16 | " plt.close()" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import json\n", 26 | "filename = \"./rl_train.json\"\n", 27 | "with open(filename, \"r\") as read_file:\n", 28 | " rl_confs = json.load(read_file) # hyperparameters for rl training\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 11, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "state data shape: (1450500, 4)\n", 41 | "state data mean: [2.49940274e-01 2.92654985e-02 9.50181107e-05 1.42541021e-04] and std: [0.15910504 0.1479451 0.01019282 0.18987567]\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "import numpy as np\n", 47 | "from collections import namedtuple\n", 48 | "\n", 49 | "EnvName = 'CartPole-v1' \n", 50 | "# EnvName = 'LunarLander-v2'\n", 51 | "\n", 52 | "data_path_prefix = rl_confs[EnvName][\"data_collect_confs\"][\"data_path\"]\n", 53 | "states = np.load(data_path_prefix+'greedy_state.npy')\n", 54 | "print('state data shape: ', states.shape)\n", 55 | "mean = np.mean(states, axis=0)\n", 56 | "std = np.std(states, axis=0)\n", 57 | "print('state data mean: {} and std: 
{}'.format(mean, std))\n", 58 | "state_info = {'env_name': EnvName,\n", 59 | " 'mean': mean,\n", 60 | " 'std': std}\n", 61 | "\n", 62 | "np.save(data_path_prefix+'state_info', state_info)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [] 71 | } 72 | ], 73 | "metadata": { 74 | "kernelspec": { 75 | "display_name": "Python 3", 76 | "language": "python", 77 | "name": "python3" 78 | }, 79 | "language_info": { 80 | "codemirror_mode": { 81 | "name": "ipython", 82 | "version": 3 83 | }, 84 | "file_extension": ".py", 85 | "mimetype": "text/x-python", 86 | "name": "python", 87 | "nbconvert_exporter": "python", 88 | "pygments_lexer": "ipython3", 89 | "version": "3.6.8" 90 | } 91 | }, 92 | "nbformat": 4, 93 | "nbformat_minor": 2 94 | } 95 | -------------------------------------------------------------------------------- /src/rl/.ipynb_checkpoints/test-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Test json" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 5, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "data": { 17 | "text/plain": [ 18 | "{'learning_rate': 0.0005,\n", 19 | " 'gamma': 0.98,\n", 20 | " 'lmbda': 0.95,\n", 21 | " 'eps_clip': 0.1,\n", 22 | " 'K_epoch': 3,\n", 23 | " 'Episodes': 1000}" 24 | ] 25 | }, 26 | "execution_count": 5, 27 | "metadata": {}, 28 | "output_type": "execute_result" 29 | } 30 | ], 31 | "source": [ 32 | "import json\n", 33 | "def json_read_file(filename):\n", 34 | " '''\n", 35 | " @brief:\n", 36 | " read data from json file\n", 37 | " @params:\n", 38 | " filename\n", 39 | " @return:\n", 40 | " (dict) parsed json file\n", 41 | " '''\n", 42 | " with open(filename, \"r\") as read_file:\n", 43 | " return json.load(read_file)\n", 44 | " \n", 45 | "CONF_FILE = json_read_file(\"./rl_train.json\")\n", 46 | "CONF_FILE[\"CartPole-v1\"]" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "def plot(rewards):\n", 56 | " # clear_output(True)\n", 57 | " plt.figure(figsize=(10,5))\n", 58 | " plt.plot(rewards)\n", 59 | " plt.savefig('ppo_discrete_lunarlandar.png')\n", 60 | " # plt.show()\n", 61 | " plt.clf() \n", 62 | " plt.close()" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "Python 3", 69 | "language": "python", 70 | "name": "python3" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 3 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython3", 82 | "version": "3.6.8" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 2 87 | } 88 | -------------------------------------------------------------------------------- /src/rl/PPO.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.distributions import Categorical 7 | import argparse 8 | import numpy as np 9 | import sys 10 | sys.path.append("..") 11 | from cdt import CDT 12 | from sdt import SDT 13 | 14 | class PolicyMLP(nn.Module): 15 | def __init__(self, state_dim, action_dim, hidden_dim): 16 | super(PolicyMLP, 
self).__init__() 17 | self.fc1 = nn.Linear(state_dim, hidden_dim) 18 | self.fc2 = nn.Linear(hidden_dim, action_dim) 19 | 20 | def forward(self, x, softmax_dim = -1): 21 | x = F.relu(self.fc1(x)) 22 | x = self.fc2(x) 23 | prob = F.softmax(x, dim=softmax_dim) 24 | return prob 25 | 26 | class PPO(nn.Module): 27 | def __init__(self, state_dim, action_dim, policy_approx = None, learner_args={}, **kwargs): 28 | super(PPO, self).__init__() 29 | self.learning_rate = kwargs['learning_rate'] 30 | self.gamma = kwargs['gamma'] 31 | self.lmbda = kwargs['lmbda'] 32 | self.eps_clip = kwargs['eps_clip'] 33 | self.K_epoch = kwargs['K_epoch'] 34 | self.device = torch.device(learner_args['device']) 35 | 36 | hidden_dim = kwargs['hidden_dim'] 37 | 38 | self.data = [] 39 | if policy_approx == 'MLP': 40 | self.policy = PolicyMLP(state_dim, action_dim, hidden_dim).to(self.device) 41 | self.pi = lambda x: self.policy.forward(x, softmax_dim=-1) 42 | elif policy_approx == 'SDT': 43 | self.policy = SDT(learner_args).to(self.device) 44 | self.pi = lambda x: self.policy.forward(x, LogProb=False)[1] 45 | elif policy_approx == 'CDT': 46 | self.policy = CDT(learner_args).to(self.device) 47 | self.pi = lambda x: self.policy.forward(x, LogProb=False)[1] 48 | else: 49 | raise NotImplementedError 50 | 51 | self.fc1 = nn.Linear(state_dim,hidden_dim) 52 | self.fc_v = nn.Linear(hidden_dim,1) 53 | 54 | self.optimizer = optim.Adam(list(self.parameters())+list(self.policy.parameters()), lr=self.learning_rate) 55 | 56 | def v(self, x): 57 | if isinstance(x, (np.ndarray, np.generic) ): 58 | x = torch.tensor(x) 59 | x = F.relu(self.fc1(x)) 60 | v = self.fc_v(x) 61 | return v 62 | 63 | def put_data(self, transition): 64 | self.data.append(transition) 65 | 66 | def make_batch(self): 67 | s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], [] 68 | for transition in self.data: 69 | s, a, r, s_prime, prob_a, done = transition 70 | 71 | s_lst.append(s) 72 | a_lst.append([a]) 73 | r_lst.append([r]) 74 | s_prime_lst.append(s_prime) 75 | prob_a_lst.append([prob_a]) 76 | done_mask = 0 if done else 1 77 | done_lst.append([done_mask]) 78 | 79 | s,a,r,s_prime,done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float).to(self.device), torch.tensor(a_lst).to(self.device), \ 80 | torch.tensor(r_lst).to(self.device), torch.tensor(s_prime_lst, dtype=torch.float).to(self.device), \ 81 | torch.tensor(done_lst, dtype=torch.float).to(self.device), torch.tensor(prob_a_lst).to(self.device) 82 | self.data = [] 83 | return s, a, r, s_prime, done_mask, prob_a 84 | 85 | def train_net(self): 86 | s, a, r, s_prime, done_mask, prob_a = self.make_batch() 87 | 88 | for i in range(self.K_epoch): 89 | td_target = r + self.gamma * self.v(s_prime) * done_mask 90 | delta = td_target - self.v(s) 91 | delta = delta.detach() 92 | 93 | advantage_lst = [] 94 | advantage = 0.0 95 | for delta_t in torch.flip(delta, [0]): 96 | advantage = self.gamma * self.lmbda * advantage + delta_t[0] 97 | advantage_lst.append([advantage]) 98 | advantage_lst.reverse() 99 | advantage = torch.tensor(advantage_lst, dtype=torch.float).to(self.device) 100 | 101 | pi = self.pi(s) 102 | pi_a = pi.gather(1,a) 103 | ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a)) # a/b == exp(log(a)-log(b)) 104 | 105 | surr1 = ratio * advantage 106 | surr2 = torch.clamp(ratio, 1-self.eps_clip, 1+self.eps_clip) * advantage 107 | loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach()) 108 | 109 | self.optimizer.zero_grad() 110 | loss.mean().backward() 111 | 
self.optimizer.step() 112 | 113 | def choose_action(self, s, Greedy=False): 114 | prob = self.pi(torch.from_numpy(s).unsqueeze(0).float().to(self.device)).squeeze() # make sure input state shape is correct 115 | if Greedy: 116 | a = torch.argmax(prob, dim=-1).item() 117 | return a 118 | else: 119 | m = Categorical(prob) 120 | a = m.sample().item() 121 | return a, prob 122 | 123 | def load_model(self, path=None): 124 | self.load_state_dict(torch.load(path)) 125 | 126 | -------------------------------------------------------------------------------- /src/rl/PPO.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/rl/PPO.pyc -------------------------------------------------------------------------------- /src/rl/__init__.py: -------------------------------------------------------------------------------- 1 | from .PPO import PPO 2 | from .env_wrapper import StateNormWrapper -------------------------------------------------------------------------------- /src/rl/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/rl/__init__.pyc -------------------------------------------------------------------------------- /src/rl/env_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import spaces 3 | import gym 4 | import json 5 | import pickle 6 | 7 | class StateNormWrapper(gym.Wrapper): 8 | """ 9 | Normalize state value for environments. 10 | """ 11 | def __init__(self, env, file_name): 12 | super(StateNormWrapper, self).__init__(env) 13 | with open(file_name, "r") as read_file: 14 | rl_confs = json.load(read_file) # hyperparameters for rl training 15 | print(env.spec.id) 16 | data_path_prefix = rl_confs["data_collect_confs"]["data_path"]+env.spec.id.split("-")[0].lower()+'/' 17 | with open(data_path_prefix+'state_info.pkl', 'rb') as f: 18 | self.state_stats=pickle.load(f) 19 | 20 | def norm(self, s): 21 | mean = self.state_stats['mean'] 22 | std = self.state_stats['std'] 23 | s = (s-mean)/std 24 | return s 25 | 26 | 27 | def step(self, a): 28 | observation, reward, done, info = self.env.step(a) 29 | return self.norm(observation), reward, done, info 30 | 31 | def reset(self, **kwargs): 32 | observation = self.env.reset(**kwargs) 33 | return self.norm(observation) 34 | 35 | def render(self, **kwargs): 36 | pass 37 | 38 | 39 | if __name__ == '__main__': 40 | import matplotlib.pyplot as plt 41 | 42 | # test 43 | # EnvName = 'CartPole-v1' 44 | EnvName = 'LunarLander-v2' 45 | 46 | env = StateNormWrapper(gym.make(EnvName), file_name="rl_train.json") 47 | 48 | for _ in range(10): 49 | env.reset() 50 | for _ in range(1000): 51 | # env.render() 52 | a = env.action_space.sample() 53 | s, r, d, _ = env.step(a) # take a random action 54 | if d: 55 | break 56 | print(s) 57 | # print(s.shape) 58 | env.close() 59 | 60 | 61 | -------------------------------------------------------------------------------- /src/rl/rl.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_collect_confs": { 3 | "episodes" : 3000, 4 | "data_path" : "../data/rl/samples/" 5 | } 6 | } -------------------------------------------------------------------------------- /src/rl_data_collect.py: 
-------------------------------------------------------------------------------- 1 | import gym 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from rl.PPO import PPO 6 | import json 7 | 8 | filename = "./mlp/mlp_rl_train.json" 9 | with open(filename, "r") as read_file: 10 | rl_confs = json.load(read_file) 11 | 12 | filename = "./rl/rl.json" 13 | with open(filename, "r") as read_file: 14 | general_rl_confs = json.load(read_file) # hyperparameters for rl training 15 | 16 | def collect_data(EnvName, learner_args, epi): 17 | env = gym.make(EnvName) 18 | print('Env: ', env.spec.id) 19 | state_dim = env.observation_space.shape[0] 20 | action_dim = env.action_space.n # discrete 21 | model = PPO(state_dim, action_dim, policy_approx='MLP', learner_args=learner_args, **rl_confs[EnvName]["alg_confs"]) 22 | model.load_model(rl_confs[EnvName]["train_confs"]["model_path"]) 23 | data_path_prefix = general_rl_confs["data_collect_confs"]["data_path"]+env.spec.id.split("-")[0].lower() 24 | a_list=[] 25 | s_list=[] 26 | prob_list=[] 27 | for n_epi in range(epi): 28 | print('Episode: ', n_epi) 29 | s = env.reset() 30 | done = False 31 | reward = 0.0 32 | step=0 33 | while step<1000: 34 | # a, prob=model.choose_action(s) # uncomment these if wanna collect output probability rather than action only 35 | a = model.choose_action(s, Greedy=True) 36 | s_list.append(s) 37 | a_list.append([a]) 38 | # prob_list.append(prob.detach().cpu().numpy()) 39 | s, r, done, info = env.step(a) 40 | step+=1 41 | if done: 42 | break 43 | if n_epi % 100 == 0: 44 | np.save(data_path_prefix+'/greedy_state', s_list) 45 | np.save(data_path_prefix+'/greedy_action', a_list) 46 | # np.save(env.spec.id+'_ppo_prob', prob_list) 47 | env.close() 48 | 49 | 50 | if __name__ == '__main__': 51 | # EnvName = 'CartPole-v1' 52 | # EnvName = 'LunarLander-v2' 53 | EnvName = 'MountainCar-v0' 54 | 55 | learner_args = {'device': 'cpu'} 56 | 57 | collect_data(EnvName, learner_args, epi=3000) -------------------------------------------------------------------------------- /src/rl_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Discretize the (soft) differentiable tree into normal decision tree according to DDT paper""" 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | from utils.dataset import Dataset 7 | import numpy as np 8 | import copy 9 | from utils.common_func import onehot_coding 10 | from cdt import discretize_cdt, CDT 11 | from sdt import discretize_sdt, SDT 12 | from cdt.cdt_plot import draw_tree, get_path 13 | from rl import PPO 14 | import json 15 | import argparse 16 | import gym 17 | 18 | 19 | def evaluation(tree, 20 | device, 21 | episodes=100, 22 | frameskip=1, 23 | seed=None, 24 | DrawTree=None, 25 | img_path=None, 26 | log_path=None): 27 | model = lambda x: tree.forward(x)[0].data.max(1)[1].squeeze().detach().cpu().numpy() 28 | env = gym.make(EnvName) 29 | if seed: 30 | env.seed(seed) 31 | state_dim = env.observation_space.shape[0] 32 | action_dim = env.action_space.n # discrete 33 | if not os.path.exists(img_path): 34 | os.makedirs(img_path) 35 | average_weight_list = [] 36 | reward_list = [] 37 | 38 | # show values on tree nodes 39 | # print(tree.state_dict()) 40 | # show probs on tree leaves 41 | # softmax = nn.Softmax(dim=-1) 42 | # print(softmax(tree.state_dict()['dc_leaves']).detach().cpu().numpy()) 43 | 44 | for n_epi in range(episodes): 45 | print('Episode: ', n_epi) 46 | average_weight_list_epi = [] 47 | s = 
env.reset() 48 | done = False 49 | reward = 0.0 50 | step = 0 51 | while not done: 52 | a = model(torch.Tensor([s]).to(device)) 53 | if step % frameskip == 0: 54 | if DrawTree is not None: 55 | draw_tree(tree, 56 | input_img=s, 57 | DrawTree=DrawTree, 58 | savepath=img_path + '_' + DrawTree + 59 | '/{:04}.png'.format(step)) 60 | 61 | s_prime, r, done, info = env.step(a) 62 | # env.render() 63 | s = s_prime 64 | 65 | reward += r 66 | step += 1 67 | if done: 68 | break 69 | reward_list.append(reward) 70 | 71 | average_weight_list.append(average_weight_list_epi) 72 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format( 73 | n_epi, reward, step)) 74 | 75 | np.save(log_path, reward_list) 76 | 77 | env.close() 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser( 82 | description='Reinforcement learning evaluation.') 83 | 84 | parser.add_argument('--env', 85 | dest='EnvName', 86 | action='store', 87 | default=None) 88 | 89 | parser.add_argument('--method', 90 | dest='METHOD', 91 | action='store', 92 | default=None) 93 | 94 | args = parser.parse_args() 95 | 96 | METHOD = args.METHOD # one of: 'cdt', 'sdt' 97 | 98 | if METHOD == 'cdt': 99 | filename = "./cdt/cdt_rl_train.json" 100 | elif METHOD == 'sdt': 101 | filename = "./sdt/sdt_rl_train.json" 102 | else: 103 | raise NotImplementedError 104 | 105 | EnvName = args.EnvName 106 | 107 | with open(filename, "r") as read_file: 108 | rl_confs = json.load(read_file) # hyperparameters for rl training 109 | 110 | with open('./rl/rl.json', "r") as read_file: 111 | general_rl_confs = json.load(read_file) # hyperparameters for rl training 112 | 113 | discretize_type = [True, True] # for feature learning tree and decision making tree respectively, True means discretization 114 | device = torch.device('cuda') 115 | 116 | env = gym.make(EnvName) 117 | state_dim = env.observation_space.shape[0] 118 | action_dim = env.action_space.n # discrete 119 | 120 | for idx in range(1, 6): 121 | # add id 122 | model_path = rl_confs[EnvName]["train_confs"]["model_path"] + str(idx) 123 | log_path = rl_confs[EnvName]["train_confs"]["log_path"] + '_eval' + str(idx) 124 | discretized_log_path = rl_confs[EnvName]["train_confs"]["log_path"] + '_eval_discretized' + str(idx) 125 | 126 | model = PPO(state_dim, action_dim, rl_confs["General"]["policy_approx"], rl_confs[EnvName]["learner_args"], \ 127 | **rl_confs[EnvName]["alg_confs"]).to(device) 128 | model.load_model(model_path) 129 | tree = model.policy 130 | 131 | if METHOD == 'cdt': 132 | discretized_tree = discretize_cdt(tree, 133 | FL=discretize_type[0], 134 | DC=discretize_type[1]).to(device) 135 | elif METHOD == 'sdt': 136 | discretized_tree = discretize_sdt(tree).to(device) 137 | else: 138 | raise NotImplementedError 139 | 140 | evaluation(tree, device, log_path = log_path, img_path=general_rl_confs["data_collect_confs"]["data_path"]+EnvName.split("-")[0].lower()+'/imgs/') 141 | evaluation(discretized_tree, device, log_path = discretized_log_path, img_path=general_rl_confs["data_collect_confs"]["data_path"]+EnvName.split("-")[0].lower()+'/imgs/discretized_') 142 | 143 | discretized_tree.save_model(rl_confs[EnvName]["train_confs"]["model_path"] + '_discretized' + str(idx)) 144 | -------------------------------------------------------------------------------- /src/rl_train.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import json 7 
| from rl import PPO 8 | from rl import StateNormWrapper 9 | 10 | def run(EnvName, 11 | rl_confs, 12 | mode=None, 13 | episodes=1000, 14 | t_horizon=1000, 15 | model_path=None, 16 | log_path=None): 17 | env = StateNormWrapper(gym.make(EnvName), file_name="./rl/rl.json") # for state normalization 18 | env = gym.make(EnvName) 19 | state_dim = env.observation_space.shape[0] 20 | action_dim = env.action_space.n # discrete 21 | model = PPO(state_dim, action_dim, rl_confs["General"]["policy_approx"], rl_confs[EnvName]["learner_args"], \ 22 | **rl_confs[EnvName]["alg_confs"]).to(torch.device(rl_confs[EnvName]["learner_args"]["device"])) 23 | print_interval = 20 24 | if mode == 'test': 25 | model.load_model(model_path) 26 | rewards_list = [] 27 | for n_epi in range(episodes): 28 | s = env.reset() 29 | done = False 30 | reward = 0.0 31 | step = 0 32 | while not done and step < t_horizon: 33 | if mode == 'train': 34 | a, prob = model.choose_action(s) 35 | else: 36 | a = model.choose_action(s, Greedy=True) 37 | # a, prob=model.choose_action(s) 38 | 39 | s_prime, r, done, info = env.step(a) 40 | 41 | if mode == 'test': 42 | env.render() 43 | else: 44 | model.put_data( 45 | (s, a, r / 100.0, s_prime, prob[a].item(), done)) 46 | # model.put_data((s, a, r, s_prime, prob[a].item(), done)) 47 | 48 | s = s_prime 49 | 50 | reward += r 51 | step += 1 52 | if done: 53 | break 54 | if mode == 'train': 55 | model.train_net() 56 | if n_epi % print_interval == 0 and n_epi != 0: 57 | # plot(rewards_list) 58 | np.save(log_path, rewards_list) 59 | torch.save(model.state_dict(), model_path) 60 | print("# of episode :{}, reward : {:.1f}, episode length: {}". 61 | format(n_epi, reward, step)) 62 | else: 63 | print( 64 | "# of episode :{}, reward : {:.1f}, episode length: {}".format( 65 | n_epi, reward, step)) 66 | rewards_list.append(reward) 67 | env.close() 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser( 72 | description='Reinforcement learning training.') 73 | parser.add_argument('--train', 74 | dest='train', 75 | action='store_true', 76 | default=False) 77 | parser.add_argument('--test', 78 | dest='test', 79 | action='store_true', 80 | default=False) 81 | 82 | parser.add_argument('--env', 83 | dest='EnvName', 84 | action='store', 85 | default=None) 86 | 87 | parser.add_argument('--method', 88 | dest='METHOD', 89 | action='store', 90 | default=None) 91 | 92 | parser.add_argument('--id', 93 | dest='id', 94 | action='store', 95 | default=0) 96 | 97 | args = parser.parse_args() 98 | 99 | METHOD = args.METHOD # one of: 'mlp', 'cdt', 'sdt' 100 | 101 | if METHOD == 'mlp': 102 | filename = "./mlp/mlp_rl_train.json" 103 | elif METHOD == 'cdt': 104 | filename = "./cdt/cdt_rl_train.json" 105 | elif METHOD == 'sdt': 106 | filename = "./sdt/sdt_rl_train.json" 107 | else: 108 | raise NotImplementedError 109 | 110 | with open(filename, "r") as read_file: 111 | rl_confs = json.load(read_file) # hyperparameters for rl training 112 | 113 | EnvName = args.EnvName 114 | 115 | # add id 116 | rl_confs[EnvName]["train_confs"]["model_path"] = rl_confs[EnvName]["train_confs"]["model_path"]+args.id 117 | rl_confs[EnvName]["train_confs"]["log_path"] = rl_confs[EnvName]["train_confs"]["log_path"]+args.id 118 | 119 | if args.train: 120 | run(EnvName, 121 | rl_confs, 122 | mode='train', 123 | **rl_confs[EnvName]["train_confs"]) 124 | if args.test: 125 | run(EnvName, 126 | rl_confs, 127 | mode='test', 128 | **rl_confs[EnvName]["train_confs"]) 129 | 
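Note on the training loop above: every `print_interval` episodes, `run()` checkpoints the policy with `torch.save` and dumps the running reward list with `np.save(log_path, rewards_list)`, so a learning curve can be inspected without opening the plotting notebook. A minimal sketch, assuming the default MLP/CartPole paths from `mlp_rl_train.json` with `--id=0` and paths relative to `src/` (adjust to whichever env/method was actually trained):

```python
import numpy as np
import matplotlib.pyplot as plt

# rl_train.py appends the run id to log_path; np.save adds the .npy extension
rewards = np.load('../data/mlp/log/cartpole/reward0.npy')
plt.plot(rewards)
plt.xlabel('episode')
plt.ylabel('episodic reward')
plt.savefig('cartpole_mlp_training_curve.png')
```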
-------------------------------------------------------------------------------- /src/rl_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #--mem=20G 4 | #--gres=gpu:0 5 | 6 | hostname 7 | echo $CUDA_VISIBLE_DEVICES 8 | 9 | min=1 10 | max=3 11 | inter=1 12 | # declare an array to loop through 13 | declare -a methods=("sdt" "cdt" "mlp") 14 | declare -a envs=("CartPole-v1" "LunarLander-v2" "MountainCar-v0") 15 | 16 | ## now loop through the above array 17 | for env in "${envs[@]}"; 18 | do 19 | for method in "${methods[@]}"; 20 | do 21 | for ((i=min; i <= max; i+=inter)); 22 | do 23 | echo python3 rl_train.py --train --env="$env" --method="$method" --id="$i" 24 | python3 rl_train.py --train --env="$env" --method="$method" --id="$i" 25 | done 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /src/rl_train_compare_cdt.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import json 7 | from rl import PPO 8 | from rl import StateNormWrapper 9 | 10 | def run(EnvName, 11 | rl_confs, 12 | mode=None, 13 | episodes=1000, 14 | t_horizon=1000, 15 | model_path=None, 16 | log_path=None): 17 | env = StateNormWrapper(gym.make(EnvName), file_name="./rl/rl.json") # for state normalization 18 | env = gym.make(EnvName) 19 | state_dim = env.observation_space.shape[0] 20 | action_dim = env.action_space.n # discrete 21 | model = PPO(state_dim, action_dim, rl_confs["General"]["policy_approx"], rl_confs[EnvName]["learner_args"], \ 22 | **rl_confs[EnvName]["alg_confs"]).to(torch.device(rl_confs[EnvName]["learner_args"]["device"])) 23 | print_interval = 20 24 | if mode == 'test': 25 | model.load_model(model_path) 26 | rewards_list = [] 27 | for n_epi in range(episodes): 28 | s = env.reset() 29 | done = False 30 | reward = 0.0 31 | step = 0 32 | while not done and step < t_horizon: 33 | if mode == 'train': 34 | a, prob = model.choose_action(s) 35 | else: 36 | a = model.choose_action(s, Greedy=True) 37 | # a, prob=model.choose_action(s) 38 | 39 | s_prime, r, done, info = env.step(a) 40 | 41 | if mode == 'test': 42 | env.render() 43 | else: 44 | model.put_data( 45 | (s, a, r / 100.0, s_prime, prob[a].item(), done)) 46 | # model.put_data((s, a, r, s_prime, prob[a].item(), done)) 47 | 48 | s = s_prime 49 | 50 | reward += r 51 | step += 1 52 | if done: 53 | break 54 | if mode == 'train': 55 | model.train_net() 56 | if n_epi % print_interval == 0 and n_epi != 0: 57 | # plot(rewards_list) 58 | np.save(log_path, rewards_list) 59 | torch.save(model.state_dict(), model_path) 60 | print("# of episode :{}, reward : {:.1f}, episode length: {}". 
61 | format(n_epi, reward, step)) 62 | else: 63 | print( 64 | "# of episode :{}, reward : {:.1f}, episode length: {}".format( 65 | n_epi, reward, step)) 66 | rewards_list.append(reward) 67 | env.close() 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser( 72 | description='Reinforcement learning training.') 73 | parser.add_argument('--train', 74 | dest='train', 75 | action='store_true', 76 | default=False) 77 | parser.add_argument('--test', 78 | dest='test', 79 | action='store_true', 80 | default=False) 81 | 82 | parser.add_argument('--env', 83 | dest='EnvName', 84 | action='store', 85 | default=None) 86 | 87 | parser.add_argument('--method', 88 | dest='METHOD', 89 | action='store', 90 | default=None) 91 | 92 | parser.add_argument('--id', 93 | dest='id', 94 | action='store', 95 | default=0) 96 | 97 | parser.add_argument('--fl_depth', 98 | dest='feature_learning_depth', 99 | action='store', 100 | default=2) 101 | 102 | parser.add_argument('--dm_depth', 103 | dest='decision_depth', 104 | action='store', 105 | default=2) 106 | 107 | args = parser.parse_args() 108 | 109 | METHOD = args.METHOD # one of: 'cdt', 'sdt' 110 | 111 | filename = "./cdt/cdt_rl_train_compare.json" 112 | 113 | with open(filename, "r") as read_file: 114 | rl_confs = json.load(read_file) # hyperparameters for rl training 115 | 116 | EnvName = args.EnvName 117 | 118 | rl_confs[EnvName]["learner_args"]["feature_learning_depth"]=int(args.feature_learning_depth) 119 | rl_confs[EnvName]["learner_args"]["decision_depth"]=int(args.decision_depth) 120 | 121 | # add id 122 | rl_confs[EnvName]["train_confs"]["model_path"] = rl_confs[EnvName]["train_confs"]["model_path"]+'_'+args.feature_learning_depth+args.decision_depth+'_'+args.id 123 | rl_confs[EnvName]["train_confs"]["log_path"] = rl_confs[EnvName]["train_confs"]["log_path"]+'_'+args.feature_learning_depth+args.decision_depth+'_'+args.id 124 | 125 | if args.train: 126 | run(EnvName, 127 | rl_confs, 128 | mode='train', 129 | **rl_confs[EnvName]["train_confs"]) 130 | if args.test: 131 | run(EnvName, 132 | rl_confs, 133 | mode='test', 134 | **rl_confs[EnvName]["train_confs"]) 135 | -------------------------------------------------------------------------------- /src/rl_train_compare_cdt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #--mem=20G 4 | #--gres=gpu:0 5 | 6 | hostname 7 | echo $CUDA_VISIBLE_DEVICES 8 | 9 | min=1 10 | max=3 11 | inter=1 12 | 13 | ## now loop through the above array 14 | for ((m=1; m <= 3; m+=1)); 15 | do 16 | for ((n=2; n <= 2; n+=1)); 17 | do 18 | for ((i=min; i <= max; i+=inter)); 19 | do 20 | echo python3 rl_train_compare_cdt.py --train --env='CartPole-v1' --method="cdt" --id="$i" --fl_depth="$m" --dm_depth="$n" 21 | python3 rl_train_compare_cdt.py --train --env='CartPole-v1' --method="cdt" --id="$i" --fl_depth="$m" --dm_depth="$n" 22 | done 23 | done 24 | done 25 | -------------------------------------------------------------------------------- /src/rl_train_compare_sdt.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import json 7 | from rl import PPO 8 | from rl import StateNormWrapper 9 | 10 | def run(EnvName, 11 | rl_confs, 12 | mode=None, 13 | episodes=1000, 14 | t_horizon=1000, 15 | model_path=None, 16 | log_path=None): 17 | env = StateNormWrapper(gym.make(EnvName), file_name="./rl/rl.json") # for state 
normalization 18 | env = gym.make(EnvName) 19 | state_dim = env.observation_space.shape[0] 20 | action_dim = env.action_space.n # discrete 21 | model = PPO(state_dim, action_dim, rl_confs["General"]["policy_approx"], rl_confs[EnvName]["learner_args"], \ 22 | **rl_confs[EnvName]["alg_confs"]).to(torch.device(rl_confs[EnvName]["learner_args"]["device"])) 23 | print_interval = 20 24 | if mode == 'test': 25 | model.load_model(model_path) 26 | rewards_list = [] 27 | for n_epi in range(episodes): 28 | s = env.reset() 29 | done = False 30 | reward = 0.0 31 | step = 0 32 | while not done and step < t_horizon: 33 | if mode == 'train': 34 | a, prob = model.choose_action(s) 35 | else: 36 | a = model.choose_action(s, Greedy=True) 37 | # a, prob=model.choose_action(s) 38 | 39 | s_prime, r, done, info = env.step(a) 40 | 41 | if mode == 'test': 42 | env.render() 43 | else: 44 | model.put_data( 45 | (s, a, r / 100.0, s_prime, prob[a].item(), done)) 46 | # model.put_data((s, a, r, s_prime, prob[a].item(), done)) 47 | 48 | s = s_prime 49 | 50 | reward += r 51 | step += 1 52 | if done: 53 | break 54 | if mode == 'train': 55 | model.train_net() 56 | if n_epi % print_interval == 0 and n_epi != 0: 57 | # plot(rewards_list) 58 | np.save(log_path, rewards_list) 59 | torch.save(model.state_dict(), model_path) 60 | print("# of episode :{}, reward : {:.1f}, episode length: {}". 61 | format(n_epi, reward, step)) 62 | else: 63 | print( 64 | "# of episode :{}, reward : {:.1f}, episode length: {}".format( 65 | n_epi, reward, step)) 66 | rewards_list.append(reward) 67 | env.close() 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser( 72 | description='Reinforcement learning training.') 73 | parser.add_argument('--train', 74 | dest='train', 75 | action='store_true', 76 | default=False) 77 | parser.add_argument('--test', 78 | dest='test', 79 | action='store_true', 80 | default=False) 81 | 82 | parser.add_argument('--env', 83 | dest='EnvName', 84 | action='store', 85 | default=None) 86 | 87 | parser.add_argument('--method', 88 | dest='METHOD', 89 | action='store', 90 | default=None) 91 | 92 | parser.add_argument('--id', 93 | dest='id', 94 | action='store', 95 | default=0) 96 | 97 | parser.add_argument('--depth', 98 | dest='depth', 99 | action='store', 100 | default=2) 101 | 102 | args = parser.parse_args() 103 | 104 | METHOD = args.METHOD # one of: 'mlp', 'cdt', 'sdt' 105 | 106 | filename = "./sdt/sdt_rl_train_compare.json" 107 | 108 | with open(filename, "r") as read_file: 109 | rl_confs = json.load(read_file) # hyperparameters for rl training 110 | 111 | EnvName = args.EnvName 112 | 113 | rl_confs[EnvName]["learner_args"]["depth"]=int(args.depth) 114 | 115 | # add id 116 | rl_confs[EnvName]["train_confs"]["model_path"] = rl_confs[EnvName]["train_confs"]["model_path"]+'_'+args.depth+'_'+args.id 117 | rl_confs[EnvName]["train_confs"]["log_path"] = rl_confs[EnvName]["train_confs"]["log_path"]+'_'+args.depth+'_'+args.id 118 | 119 | if args.train: 120 | run(EnvName, 121 | rl_confs, 122 | mode='train', 123 | **rl_confs[EnvName]["train_confs"]) 124 | if args.test: 125 | run(EnvName, 126 | rl_confs, 127 | mode='test', 128 | **rl_confs[EnvName]["train_confs"]) 129 | -------------------------------------------------------------------------------- /src/rl_train_compare_sdt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #--mem=20G 4 | #--gres=gpu:0 5 | 6 | hostname 7 | echo $CUDA_VISIBLE_DEVICES 8 | 9 | min=1 10 | max=3 11 | inter=1 12 | 13 | ## 
now loop through the above array 14 | for ((m=2; m <= 4; m+=1)); 15 | do 16 | for ((i=min; i <= max; i+=inter)); 17 | do 18 | echo python3 rl_train_compare_sdt.py --train --env='CartPole-v1' --method="sdt" --id="$i" --depth="$m" 19 | python3 rl_train_compare_sdt.py --train --env='CartPole-v1' --method="sdt" --id="$i" --depth="$m" 20 | done 21 | done 22 | -------------------------------------------------------------------------------- /src/sdt/SDT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | '''' Soft Decision Tree ''' 3 | import torch 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | 7 | 8 | class SDT(nn.Module): 9 | """ Soft Desicion Tree """ 10 | def __init__(self, args): 11 | super(SDT, self).__init__() 12 | self.args = args 13 | print('SDT parameters: ', args) 14 | self.device = torch.device(self.args['device']) 15 | self.inner_node_num = 2 ** self.args['depth'] - 1 16 | self.leaf_num = 2 ** self.args['depth'] 17 | self.max_depth = self.args['depth'] 18 | self.max_leaf_idx=None # the leaf index with maximal path probability 19 | 20 | # Different penalty coefficients for nodes in different layer 21 | self.penalty_list = [args['lamda'] * (2 ** (-depth)) for depth in range(0, self.args['depth'])] 22 | 23 | # inner nodes operation 24 | # Initialize inner nodes and leaf nodes (input dimension on innner nodes is added by 1, serving as bias) 25 | self.linear = nn.Linear(self.args['input_dim']+1, self.inner_node_num, bias=False) 26 | self.sigmoid = nn.Sigmoid() 27 | # temperature term 28 | if self.args['beta']: 29 | beta = torch.randn(self.inner_node_num) # use different beta for each node 30 | # beta = torch.randn(1) # or use one beta across all nodes 31 | self.beta = nn.Parameter(beta) 32 | else: 33 | self.beta = torch.ones(1).to(self.device) # or use one beta across all nodes 34 | 35 | # leaf nodes operation 36 | # p*softmax(Q) instead of softmax(p*Q) 37 | param = torch.randn(self.leaf_num, self.args['output_dim']) 38 | self.param = nn.Parameter(param) 39 | self.softmax = nn.Softmax(dim=1) 40 | 41 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.args['lr'], weight_decay=self.args['weight_decay']) 42 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.args['exp_scheduler_gamma']) 43 | 44 | def leaf_nodes(self, p): 45 | distribution_per_leaf = self.softmax(self.param) 46 | average_distribution = torch.mm(p, distribution_per_leaf) 47 | return average_distribution 48 | 49 | def inner_nodes(self, x): 50 | self.inner_probs = self.sigmoid(self.beta*self.linear(x)) 51 | return self.inner_probs 52 | 53 | def get_tree_weights(self, Bias=False): 54 | """Return tree weights as a list""" 55 | if Bias: 56 | return self.state_dict()['linear.weight'].detach().cpu().numpy() 57 | else: # no bias 58 | return self.state_dict()['linear.weight'][:, 1:].detach().cpu().numpy() 59 | 60 | 61 | def forward(self, data, LogProb=True, Alpha=False, Weights=False): 62 | _mu, _penalty, _alpha = self._forward(data) 63 | output = self.leaf_nodes(_mu) # average over leaves 64 | 65 | if self.args['greatest_path_probability']: 66 | one_hot_path_probability = torch.zeros(_mu.shape).to(self.device) 67 | vs, ids = torch.max(_mu, 1) # ids is the leaf index with maximal path probability 68 | one_hot_path_probability.scatter_(1, ids.view(-1,1), 1.) 
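# the one-hot vector keeps only the leaf reached with maximal path probability, so leaf_nodes() below returns that single leaf's action distribution as the prediction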
69 | 70 | prediction = self.leaf_nodes(one_hot_path_probability) 71 | self.max_leaf_idx = ids 72 | 73 | else: # prediction value equals to the average distribution 74 | prediction = output 75 | 76 | if LogProb: 77 | output = torch.log(output) 78 | prediction = torch.log(prediction) 79 | 80 | if Weights: 81 | weights = self.get_tree_weights(Bias=True) 82 | 83 | # L1 regularization for feature sparsity on nodes 84 | if self.args['l1_regularization']: 85 | L1_reg = torch.tensor(0., requires_grad=True).to(self.device) 86 | for name, param in self.named_parameters(): 87 | if name == 'linear.weight': 88 | L1_reg = L1_reg + torch.norm(param[:, 1:], 1).to(self.device) # ignore the bias term; L1 norm 89 | 90 | _penalty+=5e-3*L1_reg 91 | 92 | outputs = (prediction, output, _penalty, ) 93 | if Weights: 94 | outputs = outputs + (weights, ) 95 | if Alpha: 96 | outputs = outputs + (_alpha, ) 97 | return outputs 98 | 99 | """ Core implementation on data forwarding in SDT """ 100 | def _forward(self, data): 101 | batch_size = data.size()[0] 102 | data = self._data_augment_(data) 103 | path_prob = self.inner_nodes(data) 104 | path_prob = torch.unsqueeze(path_prob, dim=2) 105 | path_prob = torch.cat((path_prob, 1-path_prob), dim=2) 106 | _mu = data.data.new(batch_size,1,1).fill_(1.) 107 | _penalty = torch.tensor(0.).to(self.device) 108 | 109 | begin_idx = 0 110 | end_idx = 1 111 | _alpha_list=[] 112 | 113 | for layer_idx in range(0, self.args['depth']): 114 | _path_prob = path_prob[:, begin_idx:end_idx, :] 115 | penalty, alpha_list = self._cal_penalty(layer_idx, _mu, _path_prob) 116 | _penalty += penalty # extract inner nodes in current layer to calculate regularization term 117 | _alpha_list = _alpha_list + alpha_list 118 | _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2) 119 | _mu = _mu * _path_prob 120 | begin_idx = end_idx # index for each layer 121 | end_idx = begin_idx + 2 ** (layer_idx+1) 122 | mu = _mu.view(batch_size, self.leaf_num) 123 | 124 | # mean value of alpha where it's larger than 0.5, which can describe how unbalance are the decision nodes 125 | half_alpha_list = [i for i in _alpha_list if i > 0.5] 126 | return mu, _penalty, torch.mean(torch.stack(half_alpha_list)).detach().cpu().numpy() # mu contains the path probability for each leaf 127 | 128 | """ Calculate penalty term for inner-nodes in different layer """ 129 | def _cal_penalty(self, layer_idx, _mu, _path_prob): 130 | penalty = torch.tensor(0.).to(self.device) 131 | batch_size = _mu.size()[0] 132 | _mu = _mu.view(batch_size, 2**layer_idx) 133 | _path_prob = _path_prob.view(batch_size, 2**(layer_idx+1)) 134 | alpha_list=[] 135 | for node in range(0, 2**(layer_idx+1)): 136 | numerical_bound = 1e-7 # prevent numerical issue 137 | alpha = torch.sum(_path_prob[:, node]*_mu[:,node//2], dim=0) / (torch.sum(_mu[:,node//2], dim=0) + numerical_bound) # not dividing 0. 
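# alpha is the fraction of probability mass at the parent that is routed to this child; values near 0 or 1 indicate an unbalanced split, which the penalty accumulated below discourages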
138 | origin_alpha=alpha 139 | # if alpha ==1 or alpha == 0, log will cause numerical problem, so alpha should be bounded 140 | alpha = torch.clamp(alpha, numerical_bound, 1-numerical_bound) # no log(negative value) 141 | alpha_list.append(alpha) 142 | if torch.isnan(torch.tensor(alpha_list)).any(): 143 | print(origin_alpha, alpha) 144 | penalty -= self.penalty_list[layer_idx] * 0.5 * (torch.log(alpha) + torch.log(1-alpha)) 145 | return penalty, alpha_list 146 | 147 | """ Add constant 1 onto the front of each instance, serving as the bias """ 148 | def _data_augment_(self, input): 149 | batch_size = input.size()[0] 150 | input = input.view(batch_size, -1) 151 | bias = torch.ones(batch_size, 1).to(self.device) 152 | input = torch.cat((bias, input), 1) 153 | return input 154 | 155 | def save_model(self, model_path, id=''): 156 | torch.save(self.state_dict(), model_path+id) 157 | 158 | def load_model(self, model_path, id=''): 159 | self.load_state_dict(torch.load(model_path+id, map_location='cpu')) 160 | self.eval() 161 | 162 | -------------------------------------------------------------------------------- /src/sdt/__init__.py: -------------------------------------------------------------------------------- 1 | from .SDT import SDT 2 | from .sdt_discretization import discretize_sdt -------------------------------------------------------------------------------- /src/sdt/__pycache__/SDT.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/sdt/__pycache__/SDT.cpython-36.pyc -------------------------------------------------------------------------------- /src/sdt/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/sdt/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/sdt/__pycache__/sdt_discretization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/sdt/__pycache__/sdt_discretization.cpython-36.pyc -------------------------------------------------------------------------------- /src/sdt/deprecated/sdt_discretization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Discretize the (soft) differentiable tree into normal decision tree according to DDT paper""" 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | sys.path.append("..") 7 | from utils.dataset import Dataset 8 | import numpy as np 9 | import copy 10 | 11 | def discretize_tree(original_tree): 12 | tree = copy.deepcopy(original_tree) 13 | for name, parameter in tree.named_parameters(): 14 | # print(name) 15 | if name == 'beta': 16 | setattr(tree, name, nn.Parameter(100*torch.ones(parameter.shape))) 17 | 18 | elif name == 'linear.weight': 19 | parameters=[] 20 | # print(parameter) 21 | for weights in parameter: 22 | bias = weights[0] 23 | max_id = np.argmax(np.abs(weights[1:].detach()))+1 24 | max_v = weights[max_id].detach() 25 | new_weights = torch.zeros(weights.shape) 26 | if max_v>0: 27 | new_weights[max_id] = torch.tensor(1) 28 | else: 29 | new_weights[max_id] = torch.tensor(-1) 30 | 
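# keep only the largest-magnitude feature weight per node (reduced to +/-1 by its sign) and rescale the bias by the same factor, so the split threshold along that feature is unchanged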
new_weights[0] = bias/np.abs(max_v) 31 | parameters.append(new_weights) 32 | tree.linear.weight = nn.Parameter(torch.stack(parameters)) 33 | # print(tree.linear.weight.data) 34 | return tree 35 | 36 | def onehot_coding(target, output_dim): 37 | target_onehot = torch.FloatTensor(target.size()[0], output_dim) 38 | target_onehot.data.zero_() 39 | target_onehot.scatter_(1, target.view(-1, 1), 1.) 40 | return target_onehot 41 | 42 | def discretization_evaluation(tree, discretized_tree): 43 | # Load data 44 | # data_dir = '../data/discrete_' 45 | data_dir = '../data/cartpole_greedy_ppo_' 46 | data_path = data_dir+'state.npy' 47 | label_path = data_dir+'action.npy' 48 | 49 | # a data loader with all data in dataset 50 | test_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='test', ToTensor=True), 51 | batch_size=int(1e4), 52 | shuffle=True) 53 | accuracy_list=[] 54 | accuracy_list_=[] 55 | correct=0. 56 | correct_=0. 57 | for batch_idx, (data, target) in enumerate(test_loader): 58 | # data, target = data.to(device), target.to(device) 59 | target_onehot = onehot_coding(target, tree.args['output_dim']) 60 | prediction, _, _, _ = tree.forward(data) 61 | prediction_, _, _, _ = discretized_tree.forward(data) 62 | with torch.no_grad(): 63 | pred = prediction.data.max(1)[1] 64 | correct += pred.eq(target.view(-1).data).sum() 65 | pred_ = prediction_.data.max(1)[1] 66 | correct_ += pred_.eq(target.view(-1).data).sum() 67 | accuracy = 100. * float(correct) / len(test_loader.dataset) 68 | accuracy_ = 100. * float(correct_) / len(test_loader.dataset) 69 | print('Original Tree Accuracy: {:.4f} | Discretized Tree Accuracy: {:.4f}'.format(accuracy, accuracy_)) 70 | 71 | if __name__ == '__main__': 72 | from sdt_train_cartpole import learner_args 73 | from SDT import SDT 74 | 75 | learner_args['cuda'] = False # cpu 76 | learner_args['depth'] = 4 77 | for i in range(4,7): 78 | learner_args['model_path'] = './model/sdt/'+str(learner_args['depth'])+'_id'+str(i) 79 | 80 | tree = SDT(learner_args) 81 | tree.load_model(learner_args['model_path']) 82 | 83 | discretized_tree = discretize_tree(tree) 84 | discretization_evaluation(tree, discretized_tree) 85 | 86 | discretized_tree.save_model(model_path = learner_args['model_path']+'_discretized') -------------------------------------------------------------------------------- /src/sdt/deprecated/sdt_il_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | sys.path.append("..") 7 | from SDT import SDT 8 | from utils.dataset import Dataset 9 | import numpy as np 10 | from torch.utils.tensorboard import SummaryWriter 11 | from heuristic_evaluation import difference_metric 12 | import argparse 13 | 14 | __all__ = ["learner_args"] 15 | 16 | parser = argparse.ArgumentParser(description='parse') 17 | parser.add_argument('--depth', dest='depth', default=False) 18 | parser.add_argument('--id', dest='id', default=False) 19 | args = parser.parse_args() 20 | 21 | def onehot_coding(target, device, output_dim): 22 | target_onehot = torch.FloatTensor(target.size()[0], output_dim).to(device) 23 | target_onehot.data.zero_() 24 | target_onehot.scatter_(1, target.view(-1, 1), 1.) 
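# scatter_ writes a 1 at each sample's target class index, turning integer labels into one-hot vectors on the chosen device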
25 | return target_onehot 26 | use_cuda = True 27 | learner_args = {'input_dim': 8, 28 | 'output_dim': 4, 29 | 'depth': int(args.depth), 30 | 'lamda': 1e-3, # 1e-3 31 | 'lr': 1e-3, 32 | 'weight_decay': 0., # 5e-4 33 | 'batch_size': 1280, 34 | 'epochs': 80, 35 | 'cuda': use_cuda, 36 | 'log_interval': 100, 37 | 'exp_scheduler_gamma': 1., 38 | 'beta' : True, # temperature 39 | 'l1_regularization': False, # for feature sparsity on nodes 40 | 'greatest_path_probability': True # when forwarding the SDT, \ 41 | # choose the leaf with greatest path probability or average over distributions of all leaves; \ 42 | # the former one has better explainability while the latter one achieves higher accuracy 43 | } 44 | learner_args['model_path'] = './model/sdt/'+str(learner_args['depth'])+'_id'+str(args.id) 45 | 46 | device = torch.device('cuda' if use_cuda else 'cpu') 47 | 48 | def train_tree(tree): 49 | writer = SummaryWriter(log_dir='runs/sdt_'+str(learner_args['depth'])+'_id'+str(args.id)) 50 | # criterion = nn.CrossEntropyLoss() # torch CrossEntropyLoss = LogSoftmax + NLLLoss 51 | criterion = nn.NLLLoss() # since we already have log probability, simply using Negative Log-likelihood loss can provide cross-entropy loss 52 | 53 | # Load data 54 | data_dir = '../data/discrete_' 55 | data_path = data_dir+'state.npy' 56 | label_path = data_dir+'action.npy' 57 | train_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='train'), 58 | batch_size=learner_args['batch_size'], 59 | shuffle=True) 60 | 61 | test_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='test'), 62 | batch_size=learner_args['batch_size'], 63 | shuffle=True) 64 | # Utility variables 65 | best_testing_acc = 0. 66 | testing_acc_list = [] 67 | 68 | for epoch in range(1, learner_args['epochs']+1): 69 | epoch_training_loss_list = [] 70 | epoch_weight_difference_list = [] 71 | 72 | # Training stage 73 | tree.train() 74 | for batch_idx, (data, target) in enumerate(train_loader): 75 | data, target = data.to(device), target.to(device) 76 | target_onehot = onehot_coding(target, device, learner_args['output_dim']) 77 | prediction, output, penalty, weights = tree.forward(data) 78 | difference = difference_metric(weights) 79 | epoch_weight_difference_list.append(difference) 80 | 81 | tree.optimizer.zero_grad() 82 | loss = criterion(output, target.view(-1)) 83 | loss += penalty 84 | loss.backward() 85 | tree.optimizer.step() 86 | 87 | # Print intermediate training status 88 | if batch_idx % learner_args['log_interval'] == 0: 89 | with torch.no_grad(): 90 | pred = prediction.data.max(1)[1] 91 | correct = pred.eq(target.view(-1).data).sum() 92 | loss = criterion(output, target.view(-1)) 93 | epoch_training_loss_list.append(loss.detach().cpu().data.numpy()) 94 | print('Epoch: {:02d} | Batch: {:03d} | CrossEntropy-loss: {:.5f} | Correct: {}/{} | Difference: {}'.format( 95 | epoch, batch_idx, loss.data, correct, output.size()[0], difference)) 96 | 97 | tree.save_model(model_path = learner_args['model_path']) 98 | writer.add_scalar('Training Loss', np.mean(epoch_training_loss_list), epoch) 99 | writer.add_scalar('Training Weight Difference', np.mean(epoch_weight_difference_list), epoch) 100 | 101 | # Testing stage 102 | tree.eval() 103 | correct = 0. 
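# testing stage: accumulate greedy-prediction accuracy and per-batch alpha (decision-node balance) statistics over the test set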
104 | alpha_list=[] 105 | for batch_idx, (data, target) in enumerate(test_loader): 106 | data, target = data.to(device), target.to(device) 107 | batch_size = data.size()[0] 108 | prediction, _, _,_, alpha = tree.forward(data, Alpha=True) 109 | alpha_list.append(alpha) 110 | pred = prediction.data.max(1)[1] 111 | correct += pred.eq(target.view(-1).data).sum() 112 | accuracy = 100. * float(correct) / len(test_loader.dataset) 113 | if accuracy > best_testing_acc: 114 | best_testing_acc = accuracy 115 | testing_acc_list.append(accuracy) 116 | writer.add_scalar('Testing Accuracy', accuracy, epoch) 117 | writer.add_scalar('Testing Alpha', np.mean(alpha_list), epoch) 118 | print('\nEpoch: {:02d} | Testing Accuracy: {}/{} ({:.3f}%) | Historical Best: {:.3f}% \n'.format(epoch, correct, len(test_loader.dataset), accuracy, best_testing_acc)) 119 | 120 | 121 | def test_tree(tree, epochs=10): 122 | criterion = nn.CrossEntropyLoss() 123 | 124 | # Utility variables 125 | best_testing_acc = 0. 126 | testing_acc_list = [] 127 | 128 | # Load data 129 | data_dir = '../data/discrete_' 130 | data_path = data_dir+'state.npy' 131 | label_path = data_dir+'action.npy' 132 | test_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='test'), 133 | batch_size=learner_args['batch_size'], 134 | shuffle=True) 135 | 136 | for epoch in range(epochs): 137 | # Testing stage 138 | tree.eval() 139 | correct = 0. 140 | for batch_idx, (data, target) in enumerate(test_loader): 141 | data, target = data.to(device), target.to(device) 142 | batch_size = data.size()[0] 143 | prediction, _, _, _ = tree.forward(data) 144 | pred = prediction.data.max(1)[1] 145 | correct += pred.eq(target.view(-1).data).sum() 146 | accuracy = 100. * float(correct) / len(test_loader.dataset) 147 | if accuracy > best_testing_acc: 148 | best_testing_acc = accuracy 149 | testing_acc_list.append(accuracy) 150 | print('\nEpoch: {:02d} | Testing Accuracy: {}/{} ({:.3f}%) | Historical Best: {:.3f}%\n'.format(epoch, correct, len(test_loader.dataset), accuracy, best_testing_acc)) 151 | 152 | 153 | if __name__ == '__main__': 154 | 155 | tree = SDT(learner_args).to(device) 156 | train_tree(tree) 157 | -------------------------------------------------------------------------------- /src/sdt/deprecated/sdt_rl_train.py: -------------------------------------------------------------------------------- 1 | """ PPO with soft decision tree (SDT) as policy function approximator """ 2 | import gym 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.distributions import Categorical 8 | import argparse 9 | import numpy as np 10 | from SDT import SDT 11 | 12 | parser = argparse.ArgumentParser(description='Train or test neural net motor controller.') 13 | parser.add_argument('--depth', dest='depth', default=False) 14 | parser.add_argument('--train', dest='train', action='store_true', default=False) 15 | parser.add_argument('--test', dest='test', action='store_true', default=False) 16 | parser.add_argument('--id', dest='id', default=False) 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | #Hyperparameters 22 | learning_rate = 0.0005 23 | gamma = 0.98 24 | lmbda = 0.95 25 | eps_clip = 0.1 26 | K_epoch = 3 27 | Episodes = 5000 # 3000 for CartPole, 5000 for LunarLander 28 | # T_horizon = 20 29 | # EnvName = 'CartPole-v1' # LunarLander-v2 30 | EnvName = 'LunarLander-v2' 31 | path=EnvName+'depth_'+args.depth+'_id'+str(args.id) 32 | model_path = './model/sdt_ppo/'+path 33 | env = 
gym.make(EnvName) 34 | state_dim = env.observation_space.shape[0] 35 | action_dim = env.action_space.n # discrete 36 | env.close() 37 | 38 | learner_args = {'input_dim': state_dim, 39 | 'output_dim': action_dim, 40 | 'depth': int(args.depth), 41 | 'lamda': 1e-3, # 1e-3 42 | 'lr': 1e-3, 43 | 'weight_decay': 0., # 5e-4 44 | 'batch_size': 1280, 45 | 'epochs': 40, 46 | 'cuda': True, 47 | 'log_interval': 100, 48 | 'exp_scheduler_gamma': 1., 49 | 'beta' : False, # temperature 50 | 'l1_regularization': False, # for feature sparsity on nodes 51 | 'greatest_path_probability': True # when forwarding the SDT, \ 52 | # choose the leaf with greatest path probability or average over distributions of all leaves; \ 53 | # the former one has better explainability while the latter one achieves higher accuracy 54 | } 55 | 56 | device = torch.device('cuda' if learner_args['cuda'] else 'cpu') 57 | 58 | 59 | class PPO(nn.Module): 60 | def __init__(self, state_dim, action_dim): 61 | super(PPO, self).__init__() 62 | self.data = [] 63 | hidden_dim=128 64 | self.fc1 = nn.Linear(state_dim,hidden_dim) 65 | # self.fc_pi = nn.Linear(hidden_dim,action_dim) 66 | self.fc_v = nn.Linear(hidden_dim,1) 67 | 68 | self.sdt = SDT(learner_args).to(device) 69 | self.pi = lambda x: self.sdt.forward(x, LogProb=False)[1] 70 | 71 | self.optimizer = optim.Adam(list(self.parameters())+list(self.sdt.parameters()), lr=learning_rate) 72 | 73 | def v(self, x): 74 | if isinstance(x, (np.ndarray, np.generic) ): 75 | x = torch.tensor(x) 76 | x = F.relu(self.fc1(x)) 77 | v = self.fc_v(x) 78 | return v 79 | 80 | def put_data(self, transition): 81 | self.data.append(transition) 82 | 83 | def make_batch(self): 84 | s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], [] 85 | for transition in self.data: 86 | s, a, r, s_prime, prob_a, done = transition 87 | 88 | s_lst.append(s) 89 | a_lst.append([a]) 90 | r_lst.append([r]) 91 | s_prime_lst.append(s_prime) 92 | prob_a_lst.append([prob_a]) 93 | done_mask = 0 if done else 1 94 | done_lst.append([done_mask]) 95 | 96 | s,a,r,s_prime,done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float).to(device), torch.tensor(a_lst).to(device), \ 97 | torch.tensor(r_lst).to(device), torch.tensor(s_prime_lst, dtype=torch.float).to(device), \ 98 | torch.tensor(done_lst, dtype=torch.float).to(device), torch.tensor(prob_a_lst).to(device) 99 | self.data = [] 100 | return s, a, r, s_prime, done_mask, prob_a 101 | 102 | def train_net(self): 103 | s, a, r, s_prime, done_mask, prob_a = self.make_batch() 104 | 105 | for i in range(K_epoch): 106 | td_target = r + gamma * self.v(s_prime) * done_mask 107 | delta = td_target - self.v(s) 108 | delta = delta.detach() 109 | 110 | advantage_lst = [] 111 | advantage = 0.0 112 | for delta_t in torch.flip(delta, [0]): 113 | advantage = gamma * lmbda * advantage + delta_t[0] 114 | advantage_lst.append([advantage]) 115 | advantage_lst.reverse() 116 | advantage = torch.tensor(advantage_lst, dtype=torch.float).to(device) 117 | 118 | pi = self.pi(s) 119 | pi_a = pi.gather(1,a) 120 | ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a)) # a/b == exp(log(a)-log(b)) 121 | surr1 = ratio * advantage 122 | surr2 = torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * advantage 123 | loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach()) 124 | 125 | self.optimizer.zero_grad() 126 | loss.mean().backward() 127 | self.optimizer.step() 128 | 129 | def choose_action(self, s): 130 | prob = self.pi(torch.from_numpy(s).unsqueeze(0).float().to(device)).squeeze() 
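# the SDT's averaged leaf distribution is used directly as the categorical policy; an action is sampled from it below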
131 | m = Categorical(prob) 132 | a = m.sample().item() 133 | return a, prob 134 | 135 | def load_model(self, ): 136 | self.load_state_dict(torch.load(model_path)) 137 | 138 | 139 | def run(train=False, test=False): 140 | env = gym.make(EnvName) 141 | state_dim = env.observation_space.shape[0] 142 | action_dim = env.action_space.n # discrete 143 | print(state_dim, action_dim) 144 | model = PPO(state_dim, action_dim).to(device) 145 | print_interval = 20 146 | if test: 147 | model.load_model() 148 | rewards_list=[] 149 | for n_epi in range(Episodes): 150 | s = env.reset() 151 | done = False 152 | reward = 0.0 153 | step=0 154 | while not done: 155 | a, prob = model.choose_action(s) 156 | s_prime, r, done, info = env.step(a) 157 | if test: 158 | env.render() 159 | model.put_data((s, a, r/100.0, s_prime, prob[a].item(), done)) 160 | # model.put_data((s, a, r, s_prime, prob[a].item(), done)) 161 | 162 | s = s_prime 163 | 164 | reward += r 165 | step+=1 166 | if done: 167 | break 168 | if train: 169 | model.train_net() 170 | rewards_list.append(reward) 171 | if train: 172 | if n_epi%print_interval==0 and n_epi!=0: 173 | # plot(rewards_list) 174 | np.save('./log/'+path, rewards_list) 175 | torch.save(model.state_dict(), model_path) 176 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 177 | else: 178 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 179 | 180 | env.close() 181 | 182 | if __name__ == '__main__': 183 | if args.train: 184 | run(train=True, test=False) 185 | if args.test: 186 | run(train=False, test=True) 187 | -------------------------------------------------------------------------------- /src/sdt/sdt_discretization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Discretize the (soft) differentiable tree into normal decision tree according to DDT paper""" 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | import numpy as np 7 | import copy 8 | 9 | def discretize_sdt(original_tree): 10 | tree = copy.copy(original_tree) 11 | for name, parameter in tree.named_parameters(): 12 | if name == 'beta': 13 | setattr(tree, name, nn.Parameter(100*torch.ones(parameter.shape))) 14 | 15 | elif name == 'linear.weight': 16 | parameters=[] 17 | for weights in parameter: 18 | bias = weights[0] 19 | max_id = np.argmax(np.abs(weights[1:].detach().cpu().numpy()))+1 20 | max_v = weights[max_id].detach().cpu().numpy() 21 | new_weights = torch.zeros(weights.shape) 22 | if max_v>0: 23 | new_weights[max_id] = torch.tensor(1) 24 | else: 25 | new_weights[max_id] = torch.tensor(-1) 26 | new_weights[0] = bias/np.abs(max_v) 27 | parameters.append(new_weights) 28 | tree.linear.weight = nn.Parameter(torch.stack(parameters)) 29 | return tree 30 | -------------------------------------------------------------------------------- /src/sdt/sdt_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import gym 6 | from torch.distributions import Categorical 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from sdt_plot import draw_tree, get_path 11 | import sys 12 | sys.path.append("..") 13 | from heuristic_evaluation import normalize 14 | import os 15 | 16 | EnvName = 'CartPole-v1' # LunarLander-v2 17 | # EnvName = 'LunarLander-v2' 18 | 19 | 20 | def evaluate(model, tree, episodes=1, frameskip=1, 
seed=None, DrawTree=True, DrawImportance=True, WeightedImportance=True, img_path = 'img/eval_tree'): 21 | env = gym.make(EnvName) 22 | if seed: 23 | env.seed(seed) 24 | state_dim = env.observation_space.shape[0] 25 | action_dim = env.action_space.n # discrete 26 | if not os.path.exists(img_path): 27 | os.makedirs(img_path) 28 | tree_weights = tree.get_tree_weights() 29 | average_weight_list = [] 30 | 31 | # show values on tree nodes 32 | print(tree.get_tree_weights(Bias=True)) 33 | # show probs on tree leaves 34 | softmax = nn.Softmax(dim=-1) 35 | print(softmax(tree.state_dict()['param']).detach().cpu().numpy()) 36 | 37 | for n_epi in range(episodes): 38 | print('Episode: ', n_epi) 39 | average_weight_list_epi = [] 40 | s = env.reset() 41 | done = False 42 | reward = 0.0 43 | step=0 44 | while not done: 45 | a = model(torch.Tensor([s])) 46 | if step%frameskip==0: 47 | if DrawTree: 48 | draw_tree(tree, (tree.args['input_dim'],), input_img=s, savepath=img_path+'/{:04}.png'.format(step)) 49 | if DrawImportance: 50 | path_idx, inner_probs = get_path(tree, s, Probs=True) 51 | last_idx=0 52 | probs_on_path = [] 53 | for idx in path_idx[1:]: 54 | if idx == 2*last_idx+1: # parent node goes to left node 55 | probs_on_path.append(inner_probs[last_idx]) 56 | elif idx == 2*last_idx+2: # last index goes to right node, prob should be 1-prob 57 | probs_on_path.append(1-inner_probs[last_idx]) 58 | else: 59 | raise ValueError 60 | last_idx = idx 61 | 62 | weights_on_path = tree_weights[path_idx[:-1]] # remove leaf node, i.e. the last index 63 | weight_per_node = np.abs(normalize(weights_on_path)) 64 | if WeightedImportance: # average weights on path weighted by probabilities 65 | weight_per_node = [probs*weights for probs, weights in zip (probs_on_path, weight_per_node)] 66 | average_weight = np.mean(weight_per_node, axis=0) # take absolute to prevent that positive and negative will counteract 67 | average_weight_list_epi.append(average_weight) 68 | 69 | s_prime, r, done, info = env.step(a) 70 | # env.render() 71 | s = s_prime 72 | 73 | reward += r 74 | step+=1 75 | if done: 76 | break 77 | 78 | average_weight_list.append(average_weight_list_epi) 79 | print("# of episode :{}, reward : {:.1f}, episode length: {}".format(n_epi, reward, step)) 80 | path = 'data/sdt_importance_online.npy' 81 | np.save(path, average_weight_list) 82 | plot_importance_single_episode(data_path=path, save_path='./img/sdt_importance_online.png', epi_id=0) 83 | 84 | env.close() 85 | 86 | 87 | def evaluate_offline(model, tree, episodes=1, frameskip=1, seed=None, data_path='./data/evaluate_state.npy', DrawImportance=True, method='weight', WeightedImportance=False): 88 | states = np.load(data_path, allow_pickle=True) 89 | tree_weights = tree.get_tree_weights() 90 | average_weight_list=[] 91 | for n_epi in range(episodes): 92 | average_weight_list_epi = [] 93 | for i, s in enumerate(states[n_epi]): 94 | a = model(torch.Tensor([s])) 95 | if i%frameskip==0: 96 | if DrawImportance: 97 | if method == 'weight': 98 | path_idx, inner_probs = get_path(tree, s, Probs=True) 99 | 100 | # get probability on decision path (with greatest leaf probability) 101 | last_idx=0 102 | probs_on_path = [] 103 | for idx in path_idx[1:]: 104 | if idx == 2*last_idx+1: # parent node goes to left node 105 | probs_on_path.append(inner_probs[last_idx]) 106 | elif idx == 2*last_idx+2: # last index goes to right node, prob should be 1-prob 107 | probs_on_path.append(1-inner_probs[last_idx]) 108 | else: 109 | raise ValueError 110 | last_idx = idx 111 | 112 | 
weights_on_path = tree_weights[path_idx[:-1]] # remove leaf node, i.e. the last index 113 | weight_per_node = np.abs(normalize(weights_on_path)) 114 | if WeightedImportance: 115 | weight_per_node = [probs*weights for probs, weights in zip (probs_on_path, weight_per_node)] 116 | average_weight = np.mean(weight_per_node, axis=0) # take absolute to prevent that positive and negative will counteract 117 | average_weight_list_epi.append(average_weight) 118 | elif method == 'gradient': 119 | x = torch.Tensor([s]) 120 | x.requires_grad = True 121 | a = tree.forward(x)[1] # [1] is output, which requires gradient, but it's the expectation of leaves rather than the max-prob leaf 122 | gradient = torch.autograd.grad(outputs=a, inputs=x, grad_outputs=torch.ones_like(a), 123 | retain_graph=True, allow_unused=True) 124 | average_weight_list_epi.append(np.abs(gradient[0].squeeze().cpu().numpy())) 125 | 126 | average_weight_list.append(average_weight_list_epi) 127 | path = 'data/sdt_importance_offline.npy' 128 | np.save(path, average_weight_list) 129 | plot_importance_single_episode(data_path=path, save_path='./img/sdt_importance_offline.png', epi_id=0) 130 | 131 | def prediction_evaluation(tree, data_dir='../data/discrete_'): 132 | from utils.dataset import Dataset 133 | # Load data 134 | data_path = data_dir+'state.npy' 135 | label_path = data_dir+'action.npy' 136 | 137 | # a data loader with all data in dataset 138 | test_loader = torch.utils.data.DataLoader(Dataset(data_path, label_path, partition='test'), 139 | batch_size=int(1e4), 140 | shuffle=True) 141 | accuracy_list=[] 142 | correct=0. 143 | for batch_idx, (data, target) in enumerate(test_loader): 144 | # target_onehot = onehot_coding(target, tree.args['output_dim']) 145 | prediction, _, _, _ = tree.forward(data) 146 | with torch.no_grad(): 147 | pred = prediction.data.max(1)[1] 148 | correct += pred.eq(target.view(-1).data).sum() 149 | accuracy = 100. 
* float(correct) / len(test_loader.dataset) 150 | print('Tree Accuracy: {:.4f}'.format(accuracy)) 151 | 152 | 153 | def plot_importance_single_episode(data_path='data/sdt_importance.npy', save_path='./img/sdt_importance.png', epi_id=0): 154 | data = np.load(data_path, allow_pickle=True)[epi_id] 155 | markers=[".", "d", "o", "*", "^", "v", "p", "h"] 156 | for i, weights_per_feature in enumerate(np.array(data).T): 157 | plt.plot(weights_per_feature, label='Dim: {}'.format(i), marker=markers[i], markevery=8) 158 | plt.legend(loc=1) 159 | plt.xlabel('Step') 160 | plt.ylabel('Feature Importance') 161 | if save_path: 162 | plt.savefig(save_path) 163 | plt.close() 164 | else: 165 | plt.show() 166 | 167 | if __name__ == '__main__': 168 | # Cartpole 169 | from sdt_train_cartpole import learner_args # ignore this 170 | from SDT import SDT 171 | 172 | # for reproduciblility 173 | seed=3 174 | if seed: 175 | torch.manual_seed(seed) 176 | np.random.seed(seed) 177 | learner_args['cuda'] = False # cpu 178 | learner_args['depth'] = 2 179 | learner_args['model_path'] = './model/sdt/'+str(learner_args['depth'])+'_id'+str(4) 180 | 181 | tree = SDT(learner_args) 182 | Discretized=False # whether load the discretized tree 183 | if Discretized: 184 | tree.load_model(learner_args['model_path']+'_discretized') 185 | else: 186 | tree.load_model(learner_args['model_path']) 187 | 188 | num_params = 0 189 | for key, v in tree.state_dict().items(): 190 | print(key, v.shape) 191 | num_params+=v.reshape(-1).shape[0] 192 | print('Total number of parameters in model: ', num_params) 193 | 194 | 195 | model = lambda x: tree.forward(x)[0].data.max(1)[1].squeeze().detach().numpy() 196 | if Discretized: 197 | evaluate(model, tree, episodes=10, frameskip=1, seed=seed, DrawTree=False, DrawImportance=False, img_path='img/eval_tree{}_discretized'.format(tree.args['depth'])) 198 | else: 199 | evaluate(model, tree, episodes=10, frameskip=1, seed=seed, DrawTree=False, DrawImportance=False, img_path='img/eval_tree{}'.format(tree.args['depth'])) 200 | 201 | plot_importance_single_episode(epi_id=0) 202 | -------------------------------------------------------------------------------- /src/sdt/sdt_il_train.json: -------------------------------------------------------------------------------- 1 | { "General": { 2 | "policy_approx" : "SDT" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "input_dim": 4, 7 | "output_dim": 2, 8 | "depth": 3, 9 | "lamda": 1e-3, 10 | "lr": 1e-3, 11 | "weight_decay": 0.0, 12 | "batch_size": 1280, 13 | "epochs": 80, 14 | "device": "cuda", 15 | "log_interval": 100, 16 | "exp_scheduler_gamma": 1.0, 17 | "beta" : 0, 18 | "l1_regularization": 0, 19 | "greatest_path_probability": 1, 20 | "model_path" : "../data/sdt/model/cartpole/il_model", 21 | "log_path" : "../data/sdt/log/cartpole/il_log" 22 | } 23 | }, 24 | 25 | "LunarLander-v2": { 26 | "learner_args": { 27 | "input_dim": 8, 28 | "output_dim": 4, 29 | "depth": 4, 30 | "lamda": 1e-3, 31 | "lr": 1e-3, 32 | "weight_decay": 0.0, 33 | "batch_size": 1280, 34 | "epochs": 80, 35 | "device": "cuda", 36 | "log_interval": 100, 37 | "exp_scheduler_gamma": 1.0, 38 | "beta" : 0, 39 | "l1_regularization": 0, 40 | "greatest_path_probability": 1, 41 | "model_path" : "../data/sdt/model/lunarlander/il_model", 42 | "log_path" : "../data/sdt/log/lunarlander/il_log" 43 | } 44 | 45 | } 46 | } -------------------------------------------------------------------------------- /src/sdt/sdt_rl_train.json: -------------------------------------------------------------------------------- 1 
| { "General": { 2 | "policy_approx" : "SDT" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "input_dim": 4, 7 | "output_dim": 2, 8 | "depth": 3, 9 | "lamda": 1e-3, 10 | "lr": 1e-3, 11 | "weight_decay": 0.0, 12 | "batch_size": 1280, 13 | "epochs": 40, 14 | "device": "cuda", 15 | "log_interval": 100, 16 | "exp_scheduler_gamma": 1.0, 17 | "beta" : 0, 18 | "l1_regularization": 0, 19 | "greatest_path_probability": 1 20 | }, 21 | "alg_confs": { 22 | "learning_rate" : 0.0005, 23 | "gamma" : 0.98, 24 | "lmbda" : 0.95, 25 | "eps_clip" : 0.1, 26 | "K_epoch" : 3, 27 | "hidden_dim" : 128 28 | }, 29 | "train_confs": { 30 | "episodes" : 3000, 31 | "t_horizon" : 1000, 32 | "model_path" : "../data/sdt/model/cartpole/rl_ppo", 33 | "log_path" : "../data/sdt/log/cartpole/rl_reward" 34 | } 35 | }, 36 | 37 | "LunarLander-v2": { 38 | "learner_args": { 39 | "input_dim": 8, 40 | "output_dim": 4, 41 | "depth": 4, 42 | "lamda": 1e-3, 43 | "lr": 1e-3, 44 | "weight_decay": 0.0, 45 | "batch_size": 1280, 46 | "epochs": 40, 47 | "device": "cuda", 48 | "log_interval": 100, 49 | "exp_scheduler_gamma": 1.0, 50 | "beta" : 0, 51 | "l1_regularization": 0, 52 | "greatest_path_probability": 1 53 | }, 54 | "alg_confs": { 55 | "learning_rate" : 0.0005, 56 | "gamma" : 0.98, 57 | "lmbda" : 0.95, 58 | "eps_clip" : 0.1, 59 | "K_epoch" : 3, 60 | "hidden_dim" : 128 61 | }, 62 | "train_confs": { 63 | "episodes" : 5000, 64 | "t_horizon" : 1000, 65 | "model_path" : "../data/sdt/model/lunarlander/rl_ppo", 66 | "log_path" : "../data/sdt/log/lunarlander/rl_reward" 67 | } 68 | 69 | }, 70 | 71 | "MountainCar-v0": { 72 | "learner_args": { 73 | "input_dim": 2, 74 | "output_dim": 3, 75 | "depth": 3, 76 | "lamda": 1e-3, 77 | "lr": 1e-3, 78 | "weight_decay": 0.0, 79 | "batch_size": 128, 80 | "epochs": 40, 81 | "device": "cuda", 82 | "log_interval": 100, 83 | "exp_scheduler_gamma": 1.0, 84 | "beta" : 0, 85 | "l1_regularization": 0, 86 | "greatest_path_probability": 1 87 | }, 88 | "alg_confs": { 89 | "learning_rate" : 0.005, 90 | "gamma" : 0.999, 91 | "lmbda" : 0.98, 92 | "eps_clip" : 0.1, 93 | "K_epoch" : 10, 94 | "hidden_dim" : 32 95 | }, 96 | "train_confs": { 97 | "episodes" : 5000, 98 | "t_horizon" : 1000, 99 | "model_path" : "../data/sdt/model/mountaincar/rl_ppo", 100 | "log_path" : "../data/sdt/log/mountaincar/rl_reward" 101 | } 102 | 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/sdt/sdt_rl_train_compare.json: -------------------------------------------------------------------------------- 1 | { "General": { 2 | "policy_approx" : "SDT" 3 | }, 4 | "CartPole-v1": { 5 | "learner_args": { 6 | "input_dim": 4, 7 | "output_dim": 2, 8 | "depth": 3, 9 | "lamda": 1e-3, 10 | "lr": 1e-3, 11 | "weight_decay": 0.0, 12 | "batch_size": 1280, 13 | "epochs": 40, 14 | "device": "cuda", 15 | "log_interval": 100, 16 | "exp_scheduler_gamma": 1.0, 17 | "beta" : 0, 18 | "l1_regularization": 0, 19 | "greatest_path_probability": 1 20 | }, 21 | "alg_confs": { 22 | "learning_rate" : 0.0005, 23 | "gamma" : 0.98, 24 | "lmbda" : 0.95, 25 | "eps_clip" : 0.1, 26 | "K_epoch" : 3, 27 | "hidden_dim" : 128 28 | }, 29 | "train_confs": { 30 | "episodes" : 3000, 31 | "t_horizon" : 1000, 32 | "model_path" : "../data/sdt_compare_depth/model/cartpole/rl_ppo", 33 | "log_path" : "../data/sdt_compare_depth/log/cartpole/rl_reward" 34 | } 35 | }, 36 | 37 | "LunarLander-v2": { 38 | "learner_args": { 39 | "input_dim": 8, 40 | "output_dim": 4, 41 | "depth": 4, 42 | "lamda": 1e-3, 43 | "lr": 1e-3, 44 | "weight_decay": 0.0, 45 | 
"batch_size": 1280, 46 | "epochs": 40, 47 | "device": "cuda", 48 | "log_interval": 100, 49 | "exp_scheduler_gamma": 1.0, 50 | "beta" : 0, 51 | "l1_regularization": 0, 52 | "greatest_path_probability": 1 53 | }, 54 | "alg_confs": { 55 | "learning_rate" : 0.0005, 56 | "gamma" : 0.98, 57 | "lmbda" : 0.95, 58 | "eps_clip" : 0.1, 59 | "K_epoch" : 3, 60 | "hidden_dim" : 128 61 | }, 62 | "train_confs": { 63 | "episodes" : 5000, 64 | "t_horizon" : 1000, 65 | "model_path" : "../data/sdt_compare_depth/model/lunarlander/rl_ppo", 66 | "log_path" : "../data/sdt_compare_depth/log/lunarlander/rl_reward" 67 | } 68 | 69 | }, 70 | 71 | "MountainCar-v0": { 72 | "learner_args": { 73 | "input_dim": 2, 74 | "output_dim": 3, 75 | "depth": 2, 76 | "lamda": 1e-3, 77 | "lr": 1e-3, 78 | "weight_decay": 0.0, 79 | "batch_size": 1280, 80 | "epochs": 40, 81 | "device": "cuda", 82 | "log_interval": 100, 83 | "exp_scheduler_gamma": 1.0, 84 | "beta" : 0, 85 | "l1_regularization": 0, 86 | "greatest_path_probability": 1 87 | }, 88 | "alg_confs": { 89 | "learning_rate" : 0.0005, 90 | "gamma" : 0.98, 91 | "lmbda" : 0.95, 92 | "eps_clip" : 0.1, 93 | "K_epoch" : 3, 94 | "hidden_dim" : 32 95 | }, 96 | "train_confs": { 97 | "episodes" : 5000, 98 | "t_horizon" : 1000, 99 | "model_path" : "../data/sdt_compare_depth/model/mountaincar/rl_ppo", 100 | "log_path" : "../data/sdt_compare_depth/log/mountaincar/rl_reward" 101 | } 102 | 103 | } 104 | } -------------------------------------------------------------------------------- /src/utils/__pycache__/common_func.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/utils/__pycache__/common_func.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/utils/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/heuristic_evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/utils/__pycache__/heuristic_evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/common_func.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | 4 | def plot(rewards, name:str): 5 | # clear_output(True) 6 | plt.figure(figsize=(10,5)) 7 | plt.plot(rewards) 8 | plt.savefig(name) 9 | # plt.show() 10 | plt.clf() 11 | plt.close() 12 | 13 | 14 | import json 15 | def json_read_file(filename): 16 | ''' 17 | @brief: 18 | read data from json file 19 | @params: 20 | filename 21 | @return: 22 | (dict) parsed json file 23 | ''' 24 | with open(filename, "r") as read_file: 25 | return json.load(read_file) 26 | 27 | 28 | def onehot_coding(target, device, output_dim): 29 | target_onehot = torch.FloatTensor(target.size()[0], output_dim).to(device) 30 | target_onehot.data.zero_() 31 | target_onehot.scatter_(1, target.view(-1, 1), 1.) 
32 | return target_onehot -------------------------------------------------------------------------------- /src/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch.utils import data 4 | import numpy as np 5 | import json 6 | 7 | class Dataset(data.Dataset): 8 | ''' 9 | Characterizes a dataset for PyTorch 10 | ''' 11 | def __init__(self, 12 | data_path, 13 | label_path, 14 | partition='all', 15 | train_ratio=0.8, 16 | total_ratio=1., 17 | ToTensor=True, 18 | ): 19 | """ 20 | Initialization 21 | 22 | :param data_path: (str) 23 | :param label_path: (str) 24 | :param partition: (str), choose from all data ('all'), training data ('traing') or testing data ('test') 25 | :param train_ratio: (float) ratio of training data over all data 26 | 27 | """ 28 | self.ToTensor = ToTensor 29 | # load data 30 | self.x = np.load(data_path) 31 | self.y = np.load(label_path) 32 | 33 | total_size = np.array(self.x).shape[0] 34 | total_size = int(total_size * 35 | total_ratio) # if only use partial dataset 36 | if partition == 'train': 37 | self.list_IDs = np.arange(int(total_size * train_ratio)) 38 | elif partition == 'test': 39 | self.list_IDs = np.arange(int(total_size * train_ratio), 40 | total_size) 41 | elif partition == 'all': 42 | self.list_IDs = np.arange(total_size) 43 | else: 44 | raise NotImplementedError 45 | 46 | def __len__(self): 47 | 'Denotes the total number of samples' 48 | return len(self.list_IDs) 49 | 50 | def __getitem__(self, index): 51 | 'Generates one sample of data' 52 | # Select sample 53 | ID = self.list_IDs[index] 54 | 55 | # Load data and get label 56 | x = self.x[ID] 57 | y = self.y[ID] 58 | 59 | if self.ToTensor: 60 | x = torch.FloatTensor(x) 61 | # y = torch.FloatTensor(y) 62 | return x, y 63 | -------------------------------------------------------------------------------- /src/utils/heuristic_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ This script contains information for heuristic decision tree of LunarLander-v2 (discrete case) """ 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | # In non-hierarchical heuristic decision tree, weights (weights&bias) of all nodes are listed here 8 | nodes_in_heuristic_tree = [ 9 | [0, 0,0,0,0,0,0,1,1], # first dim is bias, the rest are weights 10 | 11 | [-0.4, 0.5, 0,1,0,0,0,0,0], 12 | [-0.4, -0.5, 0,-1,0,0,0,0,0], 13 | [0, 1,0,0,0,0,0,0,0], 14 | 15 | # at 16 | [0.2, 0,0,0,0,-0.5,-1,0,0], 17 | [0.15, 0,0,0,0,-0.5,-1,0,0], 18 | [-0.25, 0,0,0,0,0.5,1,0,0], 19 | 20 | [-0.2, 0,0,0,0,-0.5,-1,0,0], 21 | [-0.25, 0,0,0,0,-0.5,-1,0,0], 22 | [0.15, 0,0,0,0,0.5,1,0,0], 23 | 24 | 25 | [0, 0.25, 0, 0.5, 0, -0.5, -1, 0, 0 ], 26 | [-0.05, 0.25, 0, 0.5, 0, -0.5, -1, 0, 0 ], 27 | [-0.05, -0.25, 0, -0.5, 0, 0.5, 1, 0, 0 ], 28 | 29 | 30 | # ht cases 31 | [-0.05, 0.275, -0.5, 0, -0.5, 0,0,0,0], 32 | 33 | [-0.05, -0.275, -0.5, 0, -0.5, 0,0,0,0], 34 | 35 | [-0.05, 0, 0, 0, -0.5, 0, 0, 0, 0], 36 | 37 | # at, ht cases 38 | [-0.2, 0.275, -0.5, 0,-0.5, 0.5,1,0,0], 39 | [0.2, 0.275, -0.5, 0,-0.5, -0.5, -1, 0,0], 40 | 41 | [-0.2, -0.275, -0.5, 0,-0.5, 0.5,1,0,0], 42 | [0.2, -0.275, -0.5, 0,-0.5, -0.5, -1, 0,0], 43 | 44 | [0.2, 0.275, -0.5, 0,-0.5, 0.5,1,0,0], 45 | [-0.2, 0.275, -0.5, 0,-0.5, -0.5, -1, 0,0], 46 | 47 | [0.2, -0.275, -0.5, 0,-0.5, 0.5,1,0,0], 48 | [-0.2, -0.275, -0.5, 0,-0.5, -0.5, -1, 0,0], 49 | 50 | [0, 0.025, -0.5, -0.5, -0.5, 0.5, 1, 0, 
0], 51 | [0, 0.525, -0.5, 0.5, -0.5, -0.5, -1, 0, 0], 52 | 53 | [0, -0.525, -0.5, -0.5, -0.5, 0.5, 1, 0, 0], 54 | [0, -0.025, -0.5, 0.5, -0.5, -0.5, -1, 0, 0], 55 | 56 | ] 57 | 58 | # All intermediate feature vectors in heuristic decision tree 59 | intermediate_features_in_heuristic_tree = [ # first dim is constant, the rest are weights 60 | # at 61 | [0, 0,0,0,0,0,0,0,0], 62 | [0.2, 0,0,0,0,-0.5,-1,0,0], 63 | [-0.2, 0,0,0,0,-0.5,-1,0,0], 64 | [0, 0.25,0,0.5,0,-0.5,-1,0,0], 65 | 66 | # ht 67 | [0, 0,0,0,-0.5,0,0,0,0], 68 | [0, 0.275,-0.5,0,-0.5,0,0,0,0], 69 | [0, 0.275,-0.5,0,-0.5,0,0,0,0], 70 | ] 71 | 72 | def normalize(list_v): 73 | normalized_list = [] 74 | for v in list_v: 75 | if np.sum(np.abs(v)) == 0: 76 | continue 77 | else: 78 | v =np.array(v)/np.sum(np.abs(v)) 79 | normalized_list.append(v) 80 | return normalized_list 81 | 82 | def l1_norm(a,b): 83 | return np.linalg.norm(np.array(a)-np.array(b), ord=1) 84 | 85 | def difference_metric(list1, list2=nodes_in_heuristic_tree, norm=True): 86 | ''' 87 | Calculate minimal difference of list1 and list2 88 | ''' 89 | if norm: 90 | list1 = normalize(list1) 91 | list2 = normalize(list2) 92 | score = [] 93 | for v1 in list1: 94 | sim_list = [] 95 | for v2 in list2: 96 | sim = np.min([l1_norm(v1, v2), l1_norm(v1, -1.*np.array(v2))]) 97 | sim_list.append(sim) 98 | score.append(np.min(sim_list)) # should be changed to be mean rather than sum 99 | return np.mean(score) 100 | 101 | if __name__ == '__main__': 102 | a=np.ones((2,8)) 103 | print(difference_metric(a, nodes_in_heuristic_tree)) 104 | -------------------------------------------------------------------------------- /src/viper/.gitignore: -------------------------------------------------------------------------------- 1 | *.*~ 2 | __pycache__/ 3 | *.pkl 4 | data/ 5 | **/*.egg-info 6 | .python-version 7 | .idea/ 8 | .vscode/ 9 | .DS_Store 10 | _build/ 11 | **/.ipynb_checkpoints 12 | **/*.pt 13 | **/checkpoints 14 | **/wandb 15 | **/ild.out 16 | -------------------------------------------------------------------------------- /src/viper/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/viper/algorithms/__init__.py -------------------------------------------------------------------------------- /src/viper/algorithms/config/Dagger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ipdb as pdb 3 | import numpy as np 4 | from ..agents import DecisionTree 5 | from ..utils import Config 6 | import ray 7 | 8 | 9 | def getArgs(target_acc=None, max_depth=20, max_leaf_nodes=int(1e6), update_interval=1000, use_Q=False): 10 | # target acc overides max_depth and max_leaf_nodes 11 | 12 | p_args=None 13 | q_args=None 14 | pi_args=None 15 | 16 | agent_args=Config() 17 | agent_args.agent=DecisionTree 18 | agent_args.p_args = p_args 19 | agent_args.eps = 0 20 | agent_args.q_args = q_args 21 | agent_args.pi_args = pi_args 22 | agent_args.max_depth=max_depth 23 | agent_args.max_leaf_nodes=max_leaf_nodes 24 | agent_args.use_Q = use_Q 25 | agent_args.n_warmup = update_interval 26 | agent_args.update_interval = update_interval 27 | return agent_args 28 | 29 | -------------------------------------------------------------------------------- /src/viper/algorithms/config/QLearning_Atari.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import numpy as np 3 | from ..utils import Config, Logger 4 | from ..models import CNN 5 | from ..agents import QLearning 6 | from ..algorithm import RL 7 | from ..envs.Breakout import env_name, env_fn 8 | 9 | """ 10 | notice that 50M samples is typical for DQNs with visual input (refer to rainbow) 11 | 12 | the configs are the same as rainbow, 13 | batchsize *8, lr * 4, update frequency/ 8 14 | no noisy q and therefore eps of 3e-2 15 | 16 | """ 17 | algo_args = Config() 18 | 19 | algo_args.max_ep_len=2000 20 | algo_args.batch_size=256 21 | algo_args.n_warmup=int(2e5) 22 | algo_args.replay_size=int(1e6) 23 | # from rainbow 24 | algo_args.test_interval = int(3e4) 25 | algo_args.seed=0 26 | algo_args.save_interval=600 27 | algo_args.log_interval=int(2e2) 28 | algo_args.n_step=int(1e8) 29 | 30 | q_args=Config() 31 | q_args.network = CNN 32 | q_args.update_interval=32 33 | q_args.activation=torch.nn.ReLU 34 | q_args.lr=2e-4 35 | q_args.strides = [2]*6 36 | q_args.kernels = [3]*6 37 | q_args.paddings = [1]*6 38 | q_args.sizes = [4, 16, 32, 64, 128, 128, 5] # 4 actions, dueling q learning 39 | 40 | agent_args=Config() 41 | agent_args.agent=QLearning 42 | agent_args.eps=3e-2 43 | agent_args.gamma=0.99 44 | agent_args.target_sync_rate=q_args.update_interval/32000 45 | 46 | args = Config() 47 | args.env_name="Breakout-v0" 48 | args.name=f"{args.env_name}_{agent_args.agent}" 49 | device = 0 50 | 51 | q_args.env_fn = env_fn 52 | agent_args.env_fn = env_fn 53 | algo_args.env_fn = env_fn 54 | 55 | agent_args.p_args = None 56 | agent_args.q_args = q_args 57 | agent_args.pi_args = None 58 | algo_args.agent_args = agent_args 59 | args.algo_args = algo_args # do not call toDict() before config is set 60 | 61 | RL(logger = Logger(args), device=device, **algo_args._toDict()).run() -------------------------------------------------------------------------------- /src/viper/algorithms/config/SAC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ipdb as pdb 3 | import numpy as np 4 | from ..utils import Config, listStack 5 | from ..models import MLP 6 | from ..agents import SAC 7 | import ray 8 | 9 | def getArgs(observation_dim, n_action): 10 | """ 11 | hyperparameters refer to the original paper as well as https://stable-baselines3.readthedocs.io/en/master/modules/sac.html 12 | """ 13 | p_args=None 14 | 15 | q_args=Config() 16 | q_args.network = MLP 17 | q_args.activation=torch.nn.ReLU 18 | q_args.lr=3e-4 19 | q_args.sizes = [observation_dim, 16, 32, n_action+1] 20 | q_args.update_interval=10 21 | # MBPO used 1/40 for continous control tasks 22 | # 1/20 for invert pendulum 23 | q_args.n_embedding = 0 24 | 25 | pi_args=Config() 26 | pi_args.network = MLP 27 | pi_args.activation=torch.nn.ReLU 28 | pi_args.lr=3e-4 29 | pi_args.sizes = [observation_dim, 16, 32, n_action] 30 | pi_args.update_interval=10 31 | 32 | agent_args=Config() 33 | agent_args.agent=SAC 34 | agent_args.eps=0 35 | agent_args.n_warmup=int(1e3) 36 | agent_args.batch_size=256 # the same as MBPO 37 | """ 38 | rainbow said 2e5 samples or 5e4 updates is typical for Qlearning 39 | bs256lr3e-4, it takes 2e4updates 40 | for the model on CartPole to learn done... 
41 | 42 | Only 3e5 samples are needed for parameterized input continous motion control (refer to MBPO) 43 | 4e5 is needed fore model free CACC (refer to NeurComm) 44 | """ 45 | agent_args.gamma=0.99 46 | agent_args.alpha=0.2 47 | agent_args.target_entropy = 0.2 48 | # overrides alpha 49 | # 4 actions, 0.9 greedy = 0.6, 0.95 greedy= 0.37, 0.99 greedy 0.1 50 | agent_args.target_sync_rate=5e-3 51 | # called tau in MBPO 52 | # sync rate per update = update interval/target sync interval 53 | agent_args.p_args = p_args 54 | agent_args.q_args = q_args 55 | agent_args.pi_args = pi_args 56 | 57 | return agent_args 58 | 59 | -------------------------------------------------------------------------------- /src/viper/algorithms/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/viper/algorithms/config/__init__.py -------------------------------------------------------------------------------- /src/viper/algorithms/config/heuristic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..utils import Config 3 | from ..agents import BaseAgent 4 | 5 | def getArgs(env_name): 6 | agent_args=Config() 7 | agent_args.p_args=None 8 | agent_args.q_args=None 9 | agent_args.pi_args=None 10 | if env_name == 'CartPole-v1': 11 | agent_args.agent = lambda **kwargs: HeuristicAgent(cartPoleHeuristic, **kwargs) 12 | elif env_name == 'LunarLander-v2': 13 | agent_args.agent = lambda **kwargs: HeuristicAgent(lunarLanderHeuristic, **kwargs) 14 | return agent_args 15 | 16 | 17 | class HeuristicAgent(BaseAgent): 18 | def __init__(self, policy, **kwargs): 19 | super().__init__(**kwargs) 20 | self.policy = policy 21 | 22 | def act(self, s, deterministic=True): 23 | assert deterministic==True 24 | return self.policy(s) 25 | 26 | def cartPoleHeuristic(observation): 27 | position, velocity, angle, angle_velocity = observation[0] 28 | action = int(3. * angle + angle_velocity > 0.) 29 | return np.array([action]) 30 | 31 | def lunarLanderHeuristic(s): 32 | """ 33 | The heuristic for 34 | 1. Testing 35 | 2. Demonstration rollout. 36 | Args: 37 | env: The environment 38 | s (list): The state. Attributes: 39 | s[0] is the horizontal coordinate 40 | s[1] is the vertical coordinate 41 | s[2] is the horizontal speed 42 | s[3] is the vertical speed 43 | s[4] is the angle 44 | s[5] is the angular speed 45 | s[6] 1 if first leg has contact, else 0 46 | s[7] 1 if second leg has contact, else 0 47 | returns: 48 | a: The heuristic to be fed into the step function defined above to determine the next step and reward. 
49 | """ 50 | s = s[0] 51 | angle_targ = s[0] * 0.5 + s[2] * 1.0 # angle should point towards center 52 | if angle_targ > 0.4: 53 | angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad 54 | if angle_targ < -0.4: 55 | angle_targ = -0.4 56 | hover_targ = 0.55 * np.abs( 57 | s[0] 58 | ) # target y should be proportional to horizontal offset 59 | 60 | angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0 61 | hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5 62 | 63 | if s[6] or s[7]: # legs have contact 64 | angle_todo = 0 65 | hover_todo = ( 66 | -(s[3]) * 0.5 67 | ) # override to reduce fall speed, that's all we need after contact 68 | 69 | a = 0 70 | if hover_todo > np.abs(angle_todo) and hover_todo > 0.05: 71 | a = 2 72 | elif angle_todo < -0.05: 73 | a = 3 74 | elif angle_todo > +0.05: 75 | a = 1 76 | return np.array([a]) 77 | 78 | -------------------------------------------------------------------------------- /src/viper/algorithms/envs/Wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | class GymWrapper(gym.Wrapper): 5 | """ 6 | Basic wrapper that makes everything a numpy array 7 | """ 8 | def __init__(self, env, reward_mean=0, reward_std=1): 9 | gym.Wrapper.__init__(self, gym.make(env)) 10 | self.reward_mean = reward_mean 11 | self.reward_std = reward_std 12 | 13 | def reset(self): 14 | state = self.env.reset() 15 | self.state = np.array(state) 16 | return self.state 17 | 18 | def step(self, a): 19 | state, reward, done, info = self.env.step(a) 20 | reward = reward*self.reward_std + self.reward_mean 21 | self.state = np.array(state) 22 | return self.state, np.array(reward), np.array(done), None 23 | 24 | def rescaleReward(self, sum_reward, ep_len): 25 | return (sum_reward-ep_len*self.reward_mean)/self.reward_std -------------------------------------------------------------------------------- /src/viper/algorithms/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantumiracle/Cascading-Decision-Tree/8c8be18031511547d35e540a01262f1c006b9687/src/viper/algorithms/envs/__init__.py -------------------------------------------------------------------------------- /src/viper/algorithms/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ipdb as pdb 3 | import itertools 4 | from gym.spaces import Box, Discrete 5 | import random 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.distributions.normal import Normal 11 | from torch.distributions.categorical import Categorical 12 | from torch.optim import Adam 13 | 14 | 15 | def MLP(sizes, activation, output_activation=nn.Identity, **kwargs): 16 | layers = [] 17 | for j in range(len(sizes)-1): 18 | act = activation if j < len(sizes)-2 else output_activation 19 | layers += [nn.Linear(sizes[j], sizes[j+1]), act()] 20 | return nn.Sequential(*layers) 21 | 22 | def CNN(sizes, kernels, strides, paddings, activation, output_activation=nn.Identity, **kwargs): 23 | layers = [] 24 | for j in range(len(sizes)-1): 25 | act = activation if j < len(sizes)-2 else output_activation 26 | layers += [nn.Conv2d(sizes[j], sizes[j+1], kernels[j], strides[j], paddings[j]), act()] 27 | return nn.Sequential(*layers) 28 | 29 | class ParameterizedModel(nn.Module): 30 | """ 31 | assumes parameterized state representation 32 | we may use a gaussian prediciton, 33 | but it 
degenrates without a kl hyperparam 34 | unlike the critic and the actor class, 35 | the sizes argument does not include the dim of the state 36 | n_embedding is the number of embedding modules needed, = the number of discrete action spaces used as input 37 | """ 38 | def __init__(self, env, logger, n_embedding=1, to_predict="srd", **net_args): 39 | super().__init__() 40 | self.logger = logger.child("p") 41 | self.action_space=env.action_space 42 | self.observation_space = env.observation_space 43 | input_dim = net_args['sizes'][0] 44 | output_dim = net_args['sizes'][-1] 45 | self.n_embedding = n_embedding 46 | if isinstance(self.action_space, Discrete): 47 | self.action_embedding = nn.Embedding(self.action_space.n, input_dim//n_embedding) 48 | self.net = MLP(**net_args) 49 | self.state_head = nn.Linear(output_dim, self.observation_space.shape[0]) 50 | self.reward_head = nn.Linear(output_dim, 1) 51 | self.done_head = nn.Linear(output_dim, 1) 52 | self.MSE = nn.MSELoss(reduction='none') 53 | self.BCE = nn.BCEWithLogitsLoss(reduction='none') 54 | self.to_predict = to_predict 55 | 56 | def forward(self, s, a, r=None, s1=None, d=None): 57 | embedding = s 58 | if isinstance(self.action_space, Discrete): 59 | batch_size, _ = a.shape 60 | action_embedding = self.action_embedding(a).view(batch_size, -1) 61 | embedding = embedding + action_embedding 62 | embedding = self.net(embedding) 63 | state = self.state_head(embedding) 64 | reward = self.reward_head(embedding).squeeze(1) 65 | 66 | 67 | 68 | if r is None: #inference 69 | with torch.no_grad(): 70 | done = torch.sigmoid(self.done_head(embedding)) 71 | done = torch.cat([1-done, done], dim = 1) 72 | done = Categorical(done).sample() # [b] 73 | return reward, state, done 74 | 75 | else: # training 76 | done = self.done_head(embedding).squeeze(1) 77 | state_loss = self.MSE(state, s1) 78 | state_loss = state_loss.mean(dim=1) 79 | state_var = self.MSE(s1, s1.mean(dim = 0, keepdim=True).expand(*s1.shape)) 80 | # we assume the components of state are of similar magnitude 81 | rel_state_loss = state_loss.mean()/state_var.mean() 82 | self.logger.log(rel_state_loss=rel_state_loss) 83 | 84 | loss = state_loss 85 | if 'r' in self.to_predict: 86 | reward_loss = self.MSE(reward, r) 87 | loss = loss + reward_loss 88 | reward_var = self.MSE(reward, reward.mean(dim=0, keepdim=True).expand(*reward.shape)).mean() 89 | 90 | self.logger.log(reward_loss=reward_loss, 91 | reward_var=reward_var, 92 | reward = r) 93 | 94 | if 'd' in self.to_predict: 95 | done_loss = self.BCE(done, d) 96 | loss = loss + 10* done_loss 97 | done = done > 0 98 | done_true_positive = (done*d).mean() 99 | d = d.mean() 100 | self.logger.log(done_loss=done_loss,done_true_positive=done_true_positive, done=d, rolling=100) 101 | 102 | return (loss, state.detach()) 103 | 104 | class QCritic(nn.Module): 105 | """ 106 | Dueling Q, currently only implemented for discrete action space 107 | if n_embedding > 0, assumes the action space needs embedding 108 | Notice that the output shape should be 1+action_space.n for discrete dueling Q 109 | 110 | n_embedding is the number of embedding modules needed, = the number of discrete action spaces used as input 111 | only used for decentralized multiagent, assumes the first action is local (consistent with gather() in utils) 112 | """ 113 | def __init__(self, env, n_embedding=0, **q_args): 114 | super().__init__() 115 | q_net = q_args['network'] 116 | self.action_space=env.action_space 117 | self.q = q_net(**q_args) 118 | self.n_embedding = n_embedding 119 | 
input_dim = q_args['sizes'][0] 120 | self.state_per_agent = input_dim//(n_embedding+1) 121 | if n_embedding != 0: 122 | self.action_embedding = nn.Embedding(self.action_space.n, self.state_per_agent) 123 | 124 | def forward(self, state, output_distribution, action=None): 125 | """ 126 | action is only used for decentralized multiagent 127 | """ 128 | if isinstance(self.action_space, Box): 129 | q = self.q(torch.cat([state, action], dim=-1)) # continuous (Box) actions: Q(s, a) on the concatenated input 130 | else: 131 | if self.n_embedding > 0: 132 | # multiagent 133 | batch_size, _ = action.shape 134 | action_embedding = self.action_embedding(action).view(batch_size, -1) 135 | action_embedding[:, :self.state_per_agent] = 0 136 | state = state + action_embedding 137 | action = action[:, 0] 138 | q = self.q(state) 139 | while len(q.shape) > 2: 140 | q = q.squeeze(-1) # HW of size 1 if CNN 141 | # [b, a+1] 142 | v = q[:, -1:] 143 | q = q[:, :-1] 144 | q = q - q.mean(dim=1, keepdim=True) + v 145 | if output_distribution: 146 | # q for all actions 147 | return q 148 | else: 149 | # q for a particular action 150 | q = torch.gather(input=q,dim=1,index=action.unsqueeze(-1)) 151 | return q.squeeze(dim=1) 152 | 153 | class CategoricalActor(nn.Module): 154 | """ 155 | always returns a distribution 156 | """ 157 | def __init__(self, **net_args): 158 | super().__init__() 159 | self.softmax = nn.Softmax(dim=1) 160 | net_fn = net_args['network'] 161 | self.network = net_fn(**net_args) 162 | self.eps = 1e-5 163 | # if pi becomes truly deterministic (e.g. SAC alpha = 0) 164 | # q will become NaN, use eps to increase stability 165 | # and make SAC compatible with "Hard"ActorCritic 166 | 167 | def forward(self, obs): 168 | logit = self.network(obs) 169 | while len(logit.shape) > 2: 170 | logit = logit.squeeze(-1) # HW of size 1 if CNN 171 | probs = self.softmax(logit) 172 | probs = (probs + self.eps) 173 | probs = probs/probs.sum(dim=-1, keepdim=True) 174 | return probs 175 | 176 | class RegressionActor(nn.Module): 177 | """ 178 | deterministic actor, used in DDPG and TD3 179 | """ 180 | def __init__(self, **net_args): 181 | super().__init__() 182 | net_fn = net_args['network'] 183 | self.network = net_fn(**net_args) 184 | 185 | def forward(self, obs): 186 | out = self.network(obs) 187 | while len(out.shape) > 2: 188 | out = out.squeeze(-1) # HW of size 1 if CNN 189 | return out 190 | 191 | -------------------------------------------------------------------------------- /src/viper/dagger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import ray 5 | from algorithms.utils import Config, LogClient, LogServer 6 | from algorithms.algorithm import Dagger 7 | import pdb 8 | 9 | os.environ['RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE']='1' 10 | 11 | 12 | """ 13 | This section contains run args, separated from args for the RL algorithm and agents 14 | """ 15 | args = Config() 16 | #### computation 17 | os.environ['CUDA_VISIBLE_DEVICES']='1' 18 | args.n_thread = 1 19 | args.parallel = False 20 | args.device = 'cpu' 21 | args.n_cpu = 1 # per agent, used only if parallel = True 22 | args.n_gpu = 0 23 | args.n_run = 3 24 | 25 | #### general 26 | args.debug = False 27 | args.test = True # if no training, only test 28 | args.profiling = False 29 | backend = 'tensorboard' 30 | 31 | import gym 32 | from algorithms.envs.Wrapper import GymWrapper 33 | env_name = 'CartPole-v1' 34 | #env_name = 'LunarLander-v2' 35 | 36 | #args.name='Imitation-QDagger-LunarLander-depth9' 37 | 
args.name='Imitation-Dagger-CartPole-depth2' 38 | 39 | #### misc 40 | args.save_period=99999 # in seconds 41 | args.log_period=int(20) 42 | args.seed = None 43 | args.test_interval = int(3e4) 44 | args.n_test = 50 45 | 46 | def env_fn(): 47 | return GymWrapper(env_name, 0, 1) 48 | 49 | algo_args = Config() 50 | algo_args.replay_size=int(1e6) 51 | algo_args.max_ep_len=600 52 | algo_args.n_step=int(4) 53 | 54 | #### checkpoint 55 | #algo_args.expert_init_checkpoint = 'checkpoints/lunar-1_CartPole-v1_SAC_50644/808241_500.0.pt' 56 | #algo_args.expert_init_checkpoint='checkpoints/SAC-Lunar_LunarLander-v2_SAC_16686/21538671_259.63662453068207.pt' 57 | 58 | algo_args.start_step = 0 59 | 60 | #from algorithms.config.SAC import getArgs 61 | #agent_args = getArgs(4, 2) 62 | #agent_args = getArgs(8, 4) 63 | from algorithms.config.heuristic import getArgs 64 | agent_args = getArgs(env_name) 65 | algo_args.expert_args = agent_args 66 | p_args, q_args, pi_args = agent_args.p_args, agent_args.q_args, agent_args.pi_args 67 | agent_args.parallel = args.parallel 68 | 69 | from algorithms.config.Dagger import getArgs 70 | algo_args.agent_args = getArgs(update_interval=int(1e4), max_depth=2, use_Q=False) 71 | 72 | env = env_fn() 73 | print(f"observation: {env.env.observation_space}, action: {env.env.action_space}") 74 | del env 75 | algo_args.env_fn = env_fn 76 | args.env_fn = env_fn 77 | 78 | if args.seed is None: 79 | args.seed = int(time.time()*1000)%65536 80 | 81 | if not p_args is None: 82 | print(f"rollout reuse:{(p_args.refresh_interval/q_args.update_interval*algo_args.batch_size)/p_args.model_buffer_size}") 83 | # each generated data will be used so many times 84 | 85 | import torch 86 | torch.set_num_threads(args.n_thread) 87 | print(f"n_threads {torch.get_num_threads()}") 88 | print(f"n_gpus {torch.cuda.device_count()}") 89 | 90 | ray.init(ignore_reinit_error = True, num_gpus=len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))) 91 | 92 | 93 | reward = [] 94 | acc = [] 95 | short_name = args.name 96 | while args.n_run > 0: 97 | args.name = f'{short_name}_{env_name}_{agent_args.agent.__name__}_{args.seed}' 98 | logger = LogServer.remote({'run_args':args, 'algo_args':algo_args}, backend = '') 99 | logger = LogClient(logger) 100 | Dagger(logger = logger, run_args=args, **algo_args._toDict()).run() 101 | reward += [logger.buffer['test_episode_reward']] 102 | acc += [logger.buffer['imitation_learning_acc']] 103 | args.n_run -= 1 104 | args.seed += 1 105 | 106 | reward = np.stack(reward) 107 | acc = np.stack(acc) 108 | print(np.mean(reward), np.std(reward)) 109 | print(np.mean(acc), np.std(acc)) -------------------------------------------------------------------------------- /src/viper/lunar_lander.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import ray 5 | from algorithms.utils import Config, LogClient, LogServer 6 | from algorithms.algorithm import RL 7 | 8 | os.environ['RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE']='1' 9 | 10 | 11 | """ 12 | This section contains run args, separated from args for the RL algorithm and agents 13 | """ 14 | args = Config() 15 | #### computation 16 | os.environ['CUDA_VISIBLE_DEVICES']='1' 17 | args.n_thread = 1 18 | args.parallel = False 19 | args.device = 'cpu' 20 | args.n_cpu = 1 # per agent, used only if parallel = True 21 | args.n_gpu = 0 22 | 23 | #### general 24 | args.debug = False 25 | args.test = True # if no training, only test 26 | args.profiling = False 27 | backend = 
'tensorboard' 28 | 29 | #### algorithm and environment 30 | from algorithms.config.SAC import getArgs 31 | 32 | import gym 33 | from algorithms.envs.Wrapper import GymWrapper 34 | env_name = 'LunarLander-v2' 35 | 36 | args.name='SAC-Lunar-reproduce-high-gamma' 37 | 38 | #### misc 39 | args.save_period=900 # in seconds 40 | args.log_period=int(20) 41 | args.seed = None 42 | args.test_interval = int(3e4) 43 | args.n_test = 10 44 | 45 | def env_fn(): 46 | return GymWrapper(env_name, 0, 0.2) 47 | 48 | agent_args = getArgs(8, 4) 49 | 50 | algo_args = Config() 51 | algo_args.replay_size=int(1e6) 52 | algo_args.max_ep_len=1000 53 | algo_args.n_step=int(1e8) 54 | #### checkpoint 55 | algo_args.init_checkpoint = 'checkpoints/SAC-Lunar_LunarLander-v2_SAC_16686/21538671_259.63662453068207.pt' 56 | algo_args.start_step = 0 57 | 58 | agent_args.gamma=0.999 59 | 60 | ########################## 61 | 62 | env = env_fn() 63 | print(f"observation: {env.env.observation_space}, action: {env.env.action_space}") 64 | del env 65 | algo_args.env_fn = env_fn 66 | args.env_fn = env_fn 67 | 68 | algo_args.agent_args = agent_args 69 | p_args, q_args, pi_args = agent_args.p_args, agent_args.q_args, agent_args.pi_args 70 | if args.debug: 71 | pi_args.update_interval = 1 72 | q_args.update_interval = 1 73 | algo_args.batch_size = 4 74 | algo_args.max_ep_len=2 75 | algo_args.replay_size=1 76 | if not p_args is None: 77 | p_args.model_buffer_size = 4 78 | algo_args.n_warmup=1 79 | args.n_test=1 80 | if args.test: 81 | algo_args.n_warmup = 0 82 | args.n_test = 50 83 | algo_args.n_step = 1 84 | if args.profiling: 85 | algo_args.batch_size=128 86 | if algo_args.agent_args.p_args is None: 87 | algo_args.n_step = 50 88 | else: 89 | algo_args.n_step = algo_args.batch_size + 10 90 | algo_args.replay_size = 1000 91 | algo_args.n_warmup = algo_args.batch_size 92 | args.n_test = 1 93 | algo_args.max_ep_len = 20 94 | if args.seed is None: 95 | args.seed = int(time.time()*1000)%65536 96 | 97 | agent_args.parallel = args.parallel 98 | args.name = f'{args.name}_{env_name}_{agent_args.agent.__name__}_{args.seed}' 99 | 100 | 101 | if not p_args is None: 102 | print(f"rollout reuse:{(p_args.refresh_interval/q_args.update_interval*algo_args.batch_size)/p_args.model_buffer_size}") 103 | # each generated data will be used so many times 104 | 105 | import torch 106 | torch.set_num_threads(args.n_thread) 107 | print(f"n_threads {torch.get_num_threads()}") 108 | print(f"n_gpus {torch.cuda.device_count()}") 109 | 110 | ray.init(ignore_reinit_error = True, num_gpus=len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))) 111 | if args.test or args.debug or args.profiling: 112 | backend = '' 113 | logger = LogServer.remote({'run_args':args, 'algo_args':algo_args}, backend = backend) 114 | logger = LogClient(logger) 115 | if args.profiling: 116 | import cProfile 117 | cProfile.run("RL(logger = logger, run_args=args, **algo_args._toDict()).run()", 118 | filename=f'device{args.device}_parallel{args.parallel}.profile') 119 | else: 120 | RL(logger = logger, run_args=args, **algo_args._toDict()).run() 121 | -------------------------------------------------------------------------------- /src/viper/rl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import ray 5 | from algorithms.utils import Config, LogClient, LogServer 6 | from algorithms.algorithm import RL 7 | 8 | os.environ['RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE']='1' 9 | 10 | 11 | """ 12 | This section contains 
run args, separated from args for the RL algorithm and agents 13 | """ 14 | args = Config() 15 | #### computation 16 | os.environ['CUDA_VISIBLE_DEVICES']='1' 17 | args.n_thread = 1 18 | args.parallel = False 19 | args.device = 'cpu' 20 | args.n_cpu = 1 # per agent, used only if parallel = True 21 | args.n_gpu = 0 22 | 23 | #### general 24 | args.debug = False 25 | args.test = True # if no training, only test 26 | args.profiling = False 27 | backend = 'tensorboard' 28 | 29 | #### algorithm and environment 30 | from algorithms.config.SAC import getArgs 31 | 32 | import gym 33 | from algorithms.envs.Wrapper import GymWrapper 34 | env_name = 'CartPole-v1' 35 | 36 | args.name='SAC-CartPole' 37 | 38 | #### misc 39 | args.save_period=900 # in seconds 40 | args.log_period=int(20) 41 | args.seed = None 42 | args.test_interval = int(3e4) 43 | args.n_test = 10 44 | 45 | def env_fn(): 46 | return GymWrapper(env_name, 0, 1) 47 | 48 | agent_args = getArgs(4, 2) 49 | 50 | algo_args = Config() 51 | algo_args.replay_size=int(1e6) 52 | algo_args.max_ep_len=500 53 | algo_args.n_step=int(1e8) 54 | 55 | #### checkpoint 56 | algo_args.init_checkpoint = 'checkpoints/lunar-1_CartPole-v1_SAC_50644/808241_500.0.pt' 57 | algo_args.start_step = 0 58 | 59 | env = env_fn() 60 | print(f"observation: {env.env.observation_space}, action: {env.env.action_space}") 61 | del env 62 | algo_args.env_fn = env_fn 63 | args.env_fn = env_fn 64 | 65 | algo_args.agent_args = agent_args 66 | p_args, q_args, pi_args = agent_args.p_args, agent_args.q_args, agent_args.pi_args 67 | if args.debug: 68 | pi_args.update_interval = 1 69 | q_args.update_interval = 1 70 | algo_args.batch_size = 4 71 | algo_args.max_ep_len=2 72 | algo_args.replay_size=1 73 | if not p_args is None: 74 | p_args.model_buffer_size = 4 75 | algo_args.n_warmup=1 76 | args.n_test=1 77 | if args.test: 78 | algo_args.n_warmup = 0 79 | args.n_test = 50 80 | algo_args.n_step = 1 81 | if args.profiling: 82 | algo_args.batch_size=128 83 | if algo_args.agent_args.p_args is None: 84 | algo_args.n_step = 50 85 | else: 86 | algo_args.n_step = algo_args.batch_size + 10 87 | algo_args.replay_size = 1000 88 | algo_args.n_warmup = algo_args.batch_size 89 | args.n_test = 1 90 | algo_args.max_ep_len = 20 91 | if args.seed is None: 92 | args.seed = int(time.time()*1000)%65536 93 | 94 | agent_args.parallel = args.parallel 95 | args.name = f'{args.name}_{env_name}_{agent_args.agent.__name__}_{args.seed}' 96 | 97 | if not p_args is None: 98 | print(f"rollout reuse:{(p_args.refresh_interval/q_args.update_interval*algo_args.batch_size)/p_args.model_buffer_size}") 99 | # each generated data will be used so many times 100 | 101 | import torch 102 | torch.set_num_threads(args.n_thread) 103 | print(f"n_threads {torch.get_num_threads()}") 104 | print(f"n_gpus {torch.cuda.device_count()}") 105 | 106 | ray.init(ignore_reinit_error = True, num_gpus=len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))) 107 | if args.test or args.debug or args.profiling: 108 | backend = '' 109 | logger = LogServer.remote({'run_args':args, 'algo_args':algo_args}, backend = backend) 110 | logger = LogClient(logger) 111 | if args.profiling: 112 | import cProfile 113 | cProfile.run("RL(logger = logger, run_args=args, **algo_args._toDict()).run()", 114 | filename=f'device{args.device}_parallel{args.parallel}.profile') 115 | else: 116 | RL(logger = logger, run_args=args, **algo_args._toDict()).run() 117 | -------------------------------------------------------------------------------- 
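The three driver scripts above (`dagger.py`, `lunar_lander.py`, `rl.py`) take no command-line flags: each run is configured by editing the `Config` fields near the top of the file and executing the script directly. A minimal run sketch, assuming the commands are issued from `src/viper/` (so that the `algorithms` package and the relative `checkpoints/...` paths resolve) and that the SAC checkpoints referenced in `rl.py`/`lunar_lander.py` have been trained or downloaded separately, since they are not shipped with the repo:

```bash
cd ./src/viper
python dagger.py        # distill the CartPole heuristic expert into a depth-2 tree via DAgger
python rl.py            # evaluate a SAC checkpoint on CartPole-v1 (args.test = True)
python lunar_lander.py  # evaluate a SAC checkpoint on LunarLander-v2
```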
/visual/.ipynb_checkpoints/load_model_weights-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Load Tree weights" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "ename": "ModuleNotFoundError", 17 | "evalue": "No module named 'cdt'", 18 | "output_type": "error", 19 | "traceback": [ 20 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 21 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 22 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"..\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPPO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 23 | "\u001b[0;32m~/research/Explainability/XRL_BorealisAI/src/rl/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mPPO\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPPO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0menv_wrapper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mStateNormWrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 24 | "\u001b[0;32m~/research/Explainability/XRL_BorealisAI/src/rl/PPO.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"..\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcdt\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCDT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msdt\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSDT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 25 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cdt'" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "import numpy as np\n", 31 | "import sys\n", 32 | "import json\n", 33 | "import torch\n", 34 | "import gym\n", 35 | "sys.path.append(\"..\")\n", 36 | "from src.rl import PPO" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 13, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\u001b[0m\n" 49 | ] 50 | }, 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "/home/quantumiracle/.conda/envs/robo/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", 56 | " result = entry_point.load(False)\n" 57 | ] 58 | }, 59 | { 60 | "ename": "NameError", 61 | "evalue": "name 'PPO' is not defined", 62 | "output_type": "error", 63 | "traceback": [ 64 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 65 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 66 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mrl_confs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mread_file\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# hyperparameters for il training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m model = PPO(state_dim, action_dim, rl_confs[\"General\"][\"policy_approx\"], rl_confs[EnvName][\"learner_args\"],\\\n\u001b[0m\u001b[1;32m 13\u001b[0m **rl_confs[EnvName][\"alg_confs\"]).to(torch.device(rl_confs[EnvName][\"learner_args\"][\"device\"]))\n\u001b[1;32m 14\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 67 | "\u001b[0;31mNameError\u001b[0m: name 'PPO' is not defined" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "EnvName = 'MountainCar-v0'\n", 73 | "m = 'cdt'\n", 74 | "\n", 75 | "env = gym.make(EnvName).unwrapped\n", 76 | "state_dim = env.observation_space.shape[0]\n", 77 | "action_dim = env.action_space.n # discrete\n", 78 | "\n", 79 | "conf_path = '../src/'+m+'/'+m+'_rl_train.json'\n", 80 | "with open(conf_path, \"r\") as read_file:\n", 81 | " rl_confs = json.load(read_file) # hyperparameters for il training\n", 82 | "\n", 83 | "model = PPO(state_dim, action_dim, rl_confs[\"General\"][\"policy_approx\"], rl_confs[EnvName][\"learner_args\"],\\\n", 84 | " **rl_confs[EnvName][\"alg_confs\"]).to(torch.device(rl_confs[EnvName][\"learner_args\"][\"device\"]))\n", 85 | "i=0\n", 86 | "model_path = rl_confs[EnvName][\"train_confs\"][\"model_path\"]+str(i)\n", 87 | "model.load_model(model_path)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.6.8" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 2 126 | } 127 | --------------------------------------------------------------------------------
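The tracebacks preserved in this checkpoint notebook point to a path issue rather than missing code: the cell only appends the repo root (`..`) to `sys.path`, while `src/rl/PPO.py` imports `cdt` and `sdt` as top-level packages, which actually live under `src/`. A minimal sketch of an import preamble that would satisfy both, assuming the notebook is run from the `visual/` directory (the extra `../src` entry is an illustrative addition, not part of the original notebook):

```python
import sys

# repo root, so that `from src.rl import PPO` resolves
sys.path.append("..")
# src/, so that PPO.py's internal `from cdt import CDT` and `from sdt import SDT` resolve
sys.path.append("../src")

from src.rl import PPO
```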