├── .gitignore ├── DRNet.py ├── LICENSE ├── README.md ├── datasets ├── FeatureBinarizer.py ├── adult │ ├── adult.data │ ├── adult.names │ └── adult.test ├── dataset.py ├── heloc │ ├── heloc_data_dictionary-2.xlsx │ └── heloc_dataset_v1.csv ├── house │ └── house_16H.csv └── magic │ ├── magic04.data │ └── magic04.names ├── sparse_linear.py └── tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints -------------------------------------------------------------------------------- /DRNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import itertools 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import torchvision 13 | from tqdm import tqdm 14 | 15 | from sparse_linear import sparse_linear 16 | 17 | class RuleFunction(torch.autograd.Function): 18 | ''' 19 | The autograd function used in the Rules Layer. 20 | The forward function implements equation (1) in the paper. 21 | The backward function implements the gradient of the forward function. 22 | ''' 23 | @staticmethod 24 | def forward(ctx, input, weight, bias): 25 | ctx.save_for_backward(input, weight, bias) 26 | 27 | output = input.mm(weight.t()) 28 | output = output + bias.unsqueeze(0).expand_as(output) 29 | output = output - (weight * (weight > 0)).sum(-1).unsqueeze(0).expand_as(output) 30 | return output 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | input, weight, bias = ctx.saved_tensors 35 | 36 | grad_input = grad_output.mm(weight) 37 | grad_weight = grad_output.t().mm(input) - grad_output.sum(0).unsqueeze(1).expand_as(weight) * (weight > 0) 38 | grad_bias = grad_output.sum(0) 39 | grad_bias[(bias >= 1) * (grad_bias < 0)] = 0 40 | 41 | return grad_input, grad_weight, grad_bias 42 | 43 | class LabelFunction(torch.autograd.Function): 44 | ''' 45 | The autograd function used in the OR Layer. 46 | The forward function implements equations (4) and (5) in the paper. 47 | The backward function implements the standard STE estimator. 48 | ''' 49 | 50 | @staticmethod 51 | def forward(ctx, input, weight, bias): 52 | ctx.save_for_backward(input, weight, bias) 53 | 54 | output = input.mm((weight.t() > 0).float()) 55 | output += bias.unsqueeze(0).expand_as(output) 56 | 57 | return output 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output): 61 | input, weight, bias = ctx.saved_tensors 62 | 63 | grad_input = grad_output.mm(weight) 64 | grad_weight = grad_output.t().mm(input) 65 | grad_bias = grad_output.sum(0) 66 | 67 | return grad_input, grad_weight, grad_bias 68 | 69 | class Binarization(torch.autograd.Function): 70 | ''' 71 | The autograd function for the binarization activation in the Rules Layer. 72 | The forward function implements equation (2) in the paper. Note here 0.999999 is used instead of 1 to guard against floating-point rounding error. 73 | The backward function implements the STE estimator with equation (3) in the paper. 
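    As a hypothetical illustration (the numbers below are made up, not taken from the paper): suppose a rule neuron
    of the Rules Layer has learned weights w = [2, 1, -3] over three binary features, with its bias fixed at 1.
    Equation (1) then yields x·w + 1 - (2 + 1), which equals 1 for x = [1, 1, 0] and is at most 0 for every other
    binary input, so the binarization of equation (2) outputs 1 exactly when the rule
    "feature_1 AND feature_2 AND NOT feature_3" is satisfied; features with weight 0 act as don't-cares.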
74 | ''' 75 | 76 | @staticmethod 77 | def forward(ctx, input): 78 | ctx.save_for_backward(input) 79 | output = (input > 0.999999).float() 80 | 81 | return output 82 | 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | input, = ctx.saved_tensors 86 | grad_input = grad_output.clone() 87 | grad_input[(input < 0)] = 0 88 | grad_input[(input >= 1) * (grad_output < 0)] = 0 89 | 90 | return grad_input 91 | 92 | class DRNet(nn.Module): 93 | def __init__(self, in_features, num_rules, out_features): 94 | """ 95 | DR-Net: https://arxiv.org/pdf/2103.02826.pdf 96 | 97 | Args 98 | in_features (int): the input dimension. 99 | num_rules (int): number of hidden neurons, which is also the maximum number of rules. 100 | out_features (int): the output dimension; should always be 1. 101 | """ 102 | super(DRNet, self).__init__() 103 | 104 | self.linear = sparse_linear('l0') 105 | self.and_layer = self.linear(in_features, num_rules, linear=RuleFunction.apply) 106 | self.or_layer = self.linear(num_rules, out_features, linear=LabelFunction.apply) 107 | 108 | self.and_layer.bias.requires_grad = False 109 | self.and_layer.bias.data.fill_(1) 110 | self.or_layer.weight.requires_grad = False 111 | self.or_layer.weight.data.fill_(1) 112 | self.or_layer.bias.requires_grad = False 113 | self.or_layer.bias.data.fill_(-0.5) 114 | 115 | def forward(self, out): 116 | out = self.and_layer(out) 117 | out = Binarization.apply(out) 118 | out = self.or_layer(out) 119 | 120 | return out 121 | 122 | def regularization(self): 123 | """ 124 | Implements the Sparsity-Based Regularization (equation 7). 125 | 126 | Returns 127 | regularization (float): the regularization term. 128 | """ 129 | 130 | regularization = ((self.and_layer.regularization(axis=1) + 1) * self.or_layer.regularization(mean=False)).mean() 131 | 132 | return regularization 133 | 134 | def statistics(self): 135 | """ 136 | Return the statistics of the network. 137 | 138 | Returns 139 | sparsity (float): sparsity of the rule set. 140 | num_rules (int): number of unpruned rules. 141 | """ 142 | 143 | rule_indices = (self.or_layer.masked_weight() != 0).nonzero()[:, 1] 144 | sparsity = (self.and_layer.masked_weight()[rule_indices] == 0).float().mean().item() 145 | num_rules = rule_indices.size(0) 146 | return sparsity, num_rules 147 | 148 | def get_rules(self, header=None): 149 | """ 150 | Translate network into rules. 151 | 152 | Args 153 | header (list OR None): the description of each input feature. 154 | Returns 155 | rules (np.array OR list): contains a list of rules. 156 | If header is None (2-d np.array), each rule is represented by a list of numbers (1: positive feature, 0: negative feature, 0.5: don't care). 157 | If header is not None (list of lists): each rule is represented by a list of strings. 158 | """ 159 | 160 | self.eval() 161 | self.to('cpu') 162 | 163 | prune_weights = self.and_layer.masked_weight() 164 | valid_indices = self.or_layer.masked_weight().nonzero(as_tuple=True)[1] 165 | rules = np.sign(prune_weights[valid_indices].detach().numpy()) * 0.5 + 0.5 166 | 167 | if header is not None: 168 | rules_exp = [] 169 | for weight in prune_weights[valid_indices]: 170 | rule = [] 171 | for w, h in zip(weight, header): 172 | if w < 0: 173 | rule.append('NOT ' + h) 174 | elif w > 0: 175 | rule.append(h) 176 | rules_exp.append(rule) 177 | rules = rules_exp 178 | 179 | return rules 180 | 181 | def predict(self, X): 182 | """ 183 | Classify the labels of X using rules encoded by the network. 
184 | 185 | Args 186 | X (np.array): 2-d np.array of instances with binary features. 187 | Returns 188 | results (np.array): 1-d array of labels. 189 | """ 190 | 191 | rules = self.get_rules() 192 | 193 | results = [] 194 | for x in X: 195 | indices = np.where(np.absolute(x - rules).max(axis=1) < 1)[0] 196 | result = int(len(indices) != 0) 197 | results.append(result) 198 | return np.array(results) 199 | 200 | def save(self, path): 201 | state = { 202 | 'state_dict': self.state_dict(), 203 | 'parameters': { 204 | 'in_features': self.and_layer.weight.size(1), 205 | 'num_rules': self.and_layer.bias.size(0), 206 | 'out_features': self.or_layer.bias.size(0), 207 | # Note: the regularization strengths (and_lam, or_lam) are hyperparameters of train(), 208 | # not attributes of the network, so they are not stored here. 209 | } 210 | } 211 | 212 | dir_path = os.path.dirname(path) 213 | os.makedirs(dir_path, exist_ok=True) 214 | torch.save(state, path) 215 | 216 | @staticmethod 217 | def load(path): 218 | state = torch.load(path) 219 | model = DRNet(**state['parameters']) 220 | model.load_state_dict(state['state_dict']) 221 | 222 | return model 223 | 224 | def train(net, train_set, test_set, device="cuda", epochs=2000, batch_size=2000, lr=1e-2, 225 | and_lam=1e-2, or_lam=1e-5, num_alter=500): 226 | def score(out, y): 227 | y_labels = (out >= 0).float() 228 | y_corrs = (y_labels == y.reshape(y_labels.size())).float() 229 | 230 | return y_corrs 231 | 232 | reg_lams = [and_lam, or_lam] 233 | optimizers = [optim.Adam(net.and_layer.parameters(), lr=lr), optim.Adam(net.or_layer.parameters(), lr=lr)] 234 | 235 | criterion = nn.BCEWithLogitsLoss().to(device) 236 | 237 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, drop_last=True, shuffle=True) 238 | 239 | with tqdm(total=epochs, desc="Epoch", bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}") as t: 240 | for epoch in range(epochs): 241 | net.to(device) 242 | net.train() 243 | 244 | batch_losses = [] 245 | batch_corres = [] 246 | for index, (x_batch, y_batch) in enumerate(train_loader): 247 | x_batch = x_batch.to(device) 248 | y_batch = y_batch.to(device) 249 | 250 | out = net(x_batch) 251 | 252 | phase = int((epoch / num_alter) % 2) 253 | loss = criterion(out, y_batch.reshape(out.size())) + reg_lams[phase] * net.regularization() 254 | optimizers[phase].zero_grad() 255 | loss.backward() 256 | optimizers[phase].step() 257 | 258 | corr = score(out, y_batch).sum() 259 | 260 | batch_losses.append(loss.item()) 261 | batch_corres.append(corr.item()) 262 | epoch_loss = torch.Tensor(batch_losses).mean().item() 263 | epoch_accu = torch.Tensor(batch_corres).sum().item() / len(train_set) 264 | 265 | net.to('cpu') 266 | net.eval() 267 | with torch.no_grad(): 268 | test_accu = score(net(test_set[:][0]), test_set[:][1]).mean().item() 269 | sparsity, num_rules = net.statistics() 270 | 271 | t.update(1) 272 | t.set_postfix({ 273 | 'loss': epoch_loss, 274 | 'epoch accu': epoch_accu, 275 | 'test accu': test_accu, 276 | 'num rules': num_rules, 277 | 'sparsity': sparsity, 278 | }) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Litao Qiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, 
and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | A decision rules network is a special type of MLP that, once trained, can be directly mapped to a set of decision rules. 4 | 5 | This repository contains the source code of the paper ["Learning Accurate and Interpretable Decision Rule Sets from Neural Networks"](https://arxiv.org/pdf/2103.02826.pdf). 6 | 7 | Please see tutorial.ipynb to get started. 8 | 9 | # Dependencies 10 | 11 | 1. numpy 12 | 2. pandas 13 | 3. scikit-learn 14 | 4. pytorch -------------------------------------------------------------------------------- /datasets/FeatureBinarizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from numpy import ndarray 4 | from pandas import DataFrame, Series 5 | from sklearn.base import TransformerMixin 6 | from sklearn.preprocessing import OneHotEncoder, StandardScaler 7 | from sklearn.tree import DecisionTreeClassifier 8 | 9 | 10 | # noinspection PyPep8Naming 11 | class FeatureBinarizer(TransformerMixin): 12 | '''Transformer for binarizing categorical and ordinal features. 
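    A minimal usage sketch (the DataFrame df and its column names below are hypothetical, not part of this repository):

        fb = FeatureBinarizer(colCateg=['sex'], numThresh=9, negations=True)
        A = fb.fit_transform(df)   # df: pandas DataFrame of raw categorical/ordinal feature columns
        # A gets MultiIndex columns such as ('age', '<=', 37.0), ('age', '>', 37.0) or ('sex', '==', 'Male')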
13 | 14 | For use with BooleanRuleCG, LogisticRuleRegression and LinearRuleRegression 15 | ''' 16 | def __init__(self, colCateg=[], numThresh=9, negations=False, threshStr=False, returnOrd=False, **kwargs): 17 | """ 18 | Args: 19 | colCateg (list): Categorical features ('object' dtype automatically treated as categorical) 20 | numThresh (int): Number of quantile thresholds used to binarize ordinal variables 21 | negations (bool): Append negations 22 | threshStr (bool): Convert thresholds on ordinal features to strings 23 | returnOrd (bool): Also return standardized ordinal features 24 | """ 25 | # List of categorical columns 26 | if type(colCateg) is pd.Series: 27 | self.colCateg = colCateg.tolist() 28 | elif type(colCateg) is not list: 29 | self.colCateg = [colCateg] 30 | else: 31 | self.colCateg = colCateg 32 | # Number of quantile thresholds used to binarize ordinal features 33 | self.numThresh = numThresh 34 | self.thresh = {} 35 | # whether to append negations 36 | self.negations = negations 37 | # whether to convert thresholds on ordinal features to strings 38 | self.threshStr = threshStr 39 | # Also return standardized ordinal features 40 | self.returnOrd = returnOrd 41 | 42 | def fit(self, X): 43 | '''Fit FeatureBinarizer to data 44 | 45 | Args: 46 | X (DataFrame): Original features 47 | Returns: 48 | FeatureBinarizer: Self 49 | self.maps (dict): Mappings for unary/binary columns 50 | self.enc (dict): OneHotEncoders for categorical columns 51 | self.thresh (dict(array)): Thresholds for ordinal columns 52 | self.NaN (list): Ordinal columns containing NaN values 53 | self.ordinal (list): Ordinal columns 54 | self.scaler (StandardScaler): StandardScaler for ordinal columns 55 | ''' 56 | data = X 57 | # Quantile probabilities 58 | quantProb = np.linspace(1. 
/ (self.numThresh + 1.), self.numThresh / (self.numThresh + 1.), self.numThresh) 59 | # Initialize 60 | maps = {} 61 | enc = {} 62 | thresh = {} 63 | NaN = [] 64 | if self.returnOrd: 65 | ordinal = [] 66 | 67 | # Iterate over columns 68 | for c in data: 69 | # number of unique values 70 | valUniq = data[c].nunique() 71 | 72 | # Constant or binary column 73 | if valUniq <= 2: 74 | # Mapping to 0, 1 75 | maps[c] = pd.Series(range(valUniq), index=np.sort(data[c].dropna().unique())) 76 | 77 | # Categorical column 78 | elif (c in self.colCateg) or (data[c].dtype == 'object'): 79 | # OneHotEncoder object 80 | enc[c] = OneHotEncoder(sparse=False, dtype=int, handle_unknown='ignore') 81 | # Fit to observed categories 82 | enc[c].fit(data[[c]]) 83 | 84 | # Ordinal column 85 | elif np.issubdtype(data[c].dtype, np.integer) | np.issubdtype(data[c].dtype, np.floating): 86 | # Few unique values 87 | if valUniq <= self.numThresh + 1: 88 | # Thresholds are sorted unique values excluding maximum 89 | thresh[c] = np.sort(data[c].unique())[:-1] 90 | # Many unique values 91 | else: 92 | # Thresholds are quantiles excluding repetitions 93 | thresh[c] = data[c].quantile(q=quantProb).unique() 94 | if data[c].isnull().any(): 95 | # Contains NaN values 96 | NaN.append(c) 97 | if self.returnOrd: 98 | ordinal.append(c) 99 | 100 | else: 101 | print(("Skipping column '" + str(c) + "': data type cannot be handled")) 102 | continue 103 | 104 | self.maps = maps 105 | self.enc = enc 106 | self.thresh = thresh 107 | self.NaN = NaN 108 | if self.returnOrd: 109 | self.ordinal = ordinal 110 | # Fit StandardScaler to ordinal features 111 | self.scaler = StandardScaler().fit(data[ordinal]) 112 | return self 113 | 114 | def transform(self, X): 115 | '''Binarize features 116 | 117 | Args: 118 | X (DataFrame): Original features 119 | Returns: 120 | A (DataFrame): Binarized features with MultiIndex column labels 121 | Xstd (DataFrame, optional): Standardized ordinal features 122 | ''' 123 | data = X 124 | maps = self.maps 125 | enc = self.enc 126 | thresh = self.thresh 127 | NaN = self.NaN 128 | 129 | # Initialize dataframe 130 | A = pd.DataFrame(index=data.index, 131 | columns=pd.MultiIndex.from_arrays([[], [], []], names=['feature', 'operation', 'value'])) 132 | 133 | # Iterate over columns 134 | for c in data: 135 | # Constant or binary column 136 | if c in maps: 137 | # Rename values to 0, 1 138 | A[(str(c), '', '')] = data[c].map(maps[c]).astype(int) 139 | if self.negations: 140 | A[(str(c), 'not', '')] = 1 - A[(str(c), '', '')] 141 | 142 | # Categorical column 143 | elif c in enc: 144 | # Apply OneHotEncoder 145 | Anew = enc[c].transform(data[[c]]) 146 | Anew = pd.DataFrame(Anew, index=data.index, columns=enc[c].categories_[0].astype(str)) 147 | if self.negations: 148 | # Append negations 149 | Anew = pd.concat([Anew, 1 - Anew], axis=1, keys=[(str(c), '=='), (str(c), '!=')]) 150 | else: 151 | Anew.columns = pd.MultiIndex.from_product([[str(c)], ['=='], Anew.columns]) 152 | # Concatenate 153 | A = pd.concat([A, Anew], axis=1) 154 | 155 | # Ordinal column 156 | elif c in thresh: 157 | # Threshold values to produce binary arrays 158 | Anew = (data[c].values[:, np.newaxis] <= thresh[c]).astype(int) 159 | if self.negations: 160 | # Append negations 161 | Anew = np.concatenate((Anew, 1 - Anew), axis=1) 162 | ops = ['<=', '>'] 163 | else: 164 | ops = ['<='] 165 | # Convert to dataframe with column labels 166 | if self.threshStr: 167 | Anew = pd.DataFrame(Anew, index=data.index, 168 | columns=pd.MultiIndex.from_product([[str(c)], ops, 
thresh[c].astype(str)])) 169 | else: 170 | Anew = pd.DataFrame(Anew, index=data.index, 171 | columns=pd.MultiIndex.from_product([[str(c)], ops, thresh[c]])) 172 | if c in NaN: 173 | # Ensure that rows corresponding to NaN values are zeroed out 174 | indNull = data[c].isnull() 175 | Anew.loc[indNull] = 0 176 | # Add NaN indicator column 177 | Anew[(str(c), '==', 'NaN')] = indNull.astype(int) 178 | if self.negations: 179 | Anew[(str(c), '!=', 'NaN')] = (~indNull).astype(int) 180 | # Concatenate 181 | A = pd.concat([A, Anew], axis=1) 182 | 183 | else: 184 | print(("Skipping column '" + str(c) + "': data type cannot be handled")) 185 | continue 186 | 187 | if self.returnOrd: 188 | # Standardize ordinal features 189 | Xstd = self.scaler.transform(data[self.ordinal]) 190 | Xstd = pd.DataFrame(Xstd, index=data.index, columns=self.ordinal) 191 | # Fill NaN with mean (which is now zero) 192 | Xstd.fillna(0, inplace=True) 193 | return A, Xstd 194 | else: 195 | return A -------------------------------------------------------------------------------- /datasets/adult/adult.names: -------------------------------------------------------------------------------- 1 | | This data was extracted from the census bureau database found at 2 | | http://www.census.gov/ftp/pub/DES/www/welcome.html 3 | | Donor: Ronny Kohavi and Barry Becker, 4 | | Data Mining and Visualization 5 | | Silicon Graphics. 6 | | e-mail: ronnyk@sgi.com for questions. 7 | | Split into train-test using MLC++ GenCVFiles (2/3, 1/3 random). 8 | | 48842 instances, mix of continuous and discrete (train=32561, test=16281) 9 | | 45222 if instances with unknown values are removed (train=30162, test=15060) 10 | | Duplicate or conflicting instances : 6 11 | | Class probabilities for adult.all file 12 | | Probability for the label '>50K' : 23.93% / 24.78% (without unknowns) 13 | | Probability for the label '<=50K' : 76.07% / 75.22% (without unknowns) 14 | | 15 | | Extraction was done by Barry Becker from the 1994 Census database. A set of 16 | | reasonably clean records was extracted using the following conditions: 17 | | ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0)) 18 | | 19 | | Prediction task is to determine whether a person makes over 50K 20 | | a year. 21 | | 22 | | First cited in: 23 | | @inproceedings{kohavi-nbtree, 24 | | author={Ron Kohavi}, 25 | | title={Scaling Up the Accuracy of Naive-Bayes Classifiers: a 26 | | Decision-Tree Hybrid}, 27 | | booktitle={Proceedings of the Second International Conference on 28 | | Knowledge Discovery and Data Mining}, 29 | | year = 1996, 30 | | pages={to appear}} 31 | | 32 | | Error Accuracy reported as follows, after removal of unknowns from 33 | | train/test sets): 34 | | C4.5 : 84.46+-0.30 35 | | Naive-Bayes: 83.88+-0.30 36 | | NBTree : 85.90+-0.28 37 | | 38 | | 39 | | Following algorithms were later run with the following error rates, 40 | | all after removal of unknowns and using the original train/test split. 41 | | All these numbers are straight runs using MLC++ with default values. 42 | | 43 | | Algorithm Error 44 | | -- ---------------- ----- 45 | | 1 C4.5 15.54 46 | | 2 C4.5-auto 14.46 47 | | 3 C4.5 rules 14.94 48 | | 4 Voted ID3 (0.6) 15.64 49 | | 5 Voted ID3 (0.8) 16.47 50 | | 6 T2 16.84 51 | | 7 1R 19.54 52 | | 8 NBTree 14.10 53 | | 9 CN2 16.00 54 | | 10 HOODG 14.82 55 | | 11 FSS Naive Bayes 14.05 56 | | 12 IDTM (Decision table) 14.46 57 | | 13 Naive-Bayes 16.12 58 | | 14 Nearest-neighbor (1) 21.42 59 | | 15 Nearest-neighbor (3) 20.35 60 | | 16 OC1 15.04 61 | | 17 Pebls Crashed. 
Unknown why (bounds WERE increased) 62 | | 63 | | Conversion of original data as follows: 64 | | 1. Discretized agrossincome into two ranges with threshold 50,000. 65 | | 2. Convert U.S. to US to avoid periods. 66 | | 3. Convert Unknown to "?" 67 | | 4. Run MLC++ GenCVFiles to generate data,test. 68 | | 69 | | Description of fnlwgt (final weight) 70 | | 71 | | The weights on the CPS files are controlled to independent estimates of the 72 | | civilian noninstitutional population of the US. These are prepared monthly 73 | | for us by Population Division here at the Census Bureau. We use 3 sets of 74 | | controls. 75 | | These are: 76 | | 1. A single cell estimate of the population 16+ for each state. 77 | | 2. Controls for Hispanic Origin by age and sex. 78 | | 3. Controls by Race, age and sex. 79 | | 80 | | We use all three sets of controls in our weighting program and "rake" through 81 | | them 6 times so that by the end we come back to all the controls we used. 82 | | 83 | | The term estimate refers to population totals derived from CPS by creating 84 | | "weighted tallies" of any specified socio-economic characteristics of the 85 | | population. 86 | | 87 | | People with similar demographic characteristics should have 88 | | similar weights. There is one important caveat to remember 89 | | about this statement. That is that since the CPS sample is 90 | | actually a collection of 51 state samples, each with its own 91 | | probability of selection, the statement only applies within 92 | | state. 93 | 94 | 95 | >50K, <=50K. 96 | 97 | age: continuous. 98 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked. 99 | fnlwgt: continuous. 100 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool. 101 | education-num: continuous. 102 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse. 103 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces. 104 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried. 105 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black. 106 | sex: Female, Male. 107 | capital-gain: continuous. 108 | capital-loss: continuous. 109 | hours-per-week: continuous. 110 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands. 
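The attribute list above maps one-to-one onto the comma-separated fields of adult.data. As a standalone illustration (independent of datasets/dataset.py, which reads the repository's copy of the file with its own settings; the original UCI file has no header row, and '?' marks unknown values):

    import pandas as pd

    # Column names as documented in adult.names; the last field is the income label.
    columns = ["age", "workclass", "fnlwgt", "education", "education-num",
               "marital-status", "occupation", "relationship", "race", "sex",
               "capital-gain", "capital-loss", "hours-per-week", "native-country", "income"]

    adult = pd.read_csv("datasets/adult/adult.data", names=columns,
                        na_values="?", skipinitialspace=True).dropna()

    # Binary prediction target: does a person make over 50K a year?
    y = (adult["income"] == ">50K").astype(int)
    X = adult.drop(columns="income")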
111 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn.impute import SimpleImputer 7 | from sklearn.preprocessing import OrdinalEncoder, KBinsDiscretizer, LabelEncoder, LabelBinarizer 8 | from sklearn.model_selection import train_test_split, KFold, StratifiedKFold 9 | from .FeatureBinarizer import FeatureBinarizer 10 | 11 | def predefined_dataset(name, binary_y=False): 12 | """ 13 | Define how to read specific datasets and return structured X and Y data. 14 | 15 | Args 16 | name (str): the name of the dataset to read. 17 | binary_y (bool): if True, force the dataset to only have two classes. 18 | 19 | Returns 20 | table_X (DataFrame): instances, values can be strings or numbers. 21 | table_Y (DataFrame): labels, values can be strings or numbers. 22 | categorical_cols (list): A list of column names that are categorical data. 23 | numerical_cols (list): A list of column names that are numerical data. 24 | """ 25 | 26 | dir_path = os.path.dirname(os.path.realpath(__file__)) # .py 27 | 28 | ### UCI datasets 29 | if name == 'adult': 30 | # https://archive.ics.uci.edu/ml/datasets/adult 31 | # X dim: (30162, 14) 32 | # Y counts: {'<=50K': 22654, '>50K': 7508} 33 | table = pd.read_csv(dir_path + '/adult/adult.data', header=0, na_values='?', skipinitialspace=True).dropna() 34 | table_X = table.iloc[:, :-1].copy() 35 | table_Y = table.iloc[:, -1].copy() 36 | categorical_cols = None 37 | numerical_cols = None 38 | 39 | elif name == 'magic': 40 | # http://archive.ics.uci.edu/ml/datasets/MAGIC+GAMMA+Telescope 41 | # X dim: (19020, 10/90) 42 | # Y counts: {'g': 12332, 'h': 6688} 43 | table = pd.read_csv(dir_path + '/magic/magic04.data', header=0, na_values='?', skipinitialspace=True).dropna() 44 | table_X = table.iloc[:, :-1].copy() 45 | table_Y = table.iloc[:, -1].copy() 46 | categorical_cols = None 47 | numerical_cols = None 48 | 49 | ### OpenML datasets 50 | elif name == 'house': 51 | # https://www.openml.org/d/821 52 | # X dim: (22784, 16/132) 53 | # Y counts: {'N': 6744, 'P': 16040} 54 | table = pd.read_csv(dir_path + '/house/house_16H.csv', header=0, skipinitialspace=True) 55 | table_X = table.iloc[:, :-1].copy() 56 | table_Y = table.iloc[:, -1].copy() 57 | categorical_cols = None 58 | numerical_cols = None 59 | 60 | ### Others 61 | elif name == 'heloc': 62 | # https://community.fico.com/s/explainable-machine-learning-challenge?tabset-3158a=2&tabset-158d9=3 63 | # X dim: (2502, 23) 64 | # Y counts: {'Bad': 1560, 'Good': 942} 65 | table = pd.read_csv(dir_path + '/heloc/heloc_dataset_v1.csv', header=0, na_values=['-7', '-8', '-9'], skipinitialspace=True)#.dropna() 66 | table_X = table.iloc[:, 1:].copy() 67 | table_Y = table.iloc[:, 0].copy() 68 | categorical_cols = None 69 | numerical_cols = None 70 | 71 | else: 72 | raise NameError(f'The input dataset is not found: {name}.') 73 | 74 | return table_X, table_Y, categorical_cols, numerical_cols 75 | 76 | def transform_dataset(name, method='ordinal', negations=False, labels='ordinal'): 77 | """ 78 | Transform values in datasets (from predefined_dataset) into real numbers or binary numbers. 79 | 80 | Args 81 | name (str): the name of the dataset. 82 | method (str): specify how the instances are encoded: 83 | 'origin': encode categorical features as integers and leave the numerical features as they are (float). 
84 | 'ordinal': encode all features as integers; numerical features are discretized into intervals. 85 | 'onehot': one-hot encode the integer features transformed using 'ordinal' method. 86 | 'onehot-compare': one-hot encode the categorical features just like how they are done in 'onehot' method; 87 | one-hot encode numerical features by comparing them with different thresholds, encoding 1 if they are smaller than the thresholds. 88 | negations (bool): whether to append negated binary features; only valid when method is 'onehot' or 'onehot-compare'. 89 | labels (str): specify how the labels are transformed. 90 | 'ordinal': output Y is a 1d array of integer values ([0, 1, 2, ...]); each label is an integer value. 91 | 'binary': output Y is a 1d array of binary values ([0, 1, 0, ...]); each label is forced to be a binary value (see predefined_dataset). 92 | 'onehot': output Y is a 2d array of one-hot encoded values ([[0, 1, 0], [1, 0, 0], [0, 0, 1]]); each label is a one-hot encoded 1d array. 93 | 94 | Return 95 | X (DataFrame): 2d float array; transformed instances. 96 | Y (np.array): 1d or 2d (labels='onehot') integer array; transformed labels. 97 | X_headers (list|dict): if method='ordinal', a dict where keys are features and values are their categories; otherwise, a list of binarized features. 98 | Y_headers (list): the names of the labels, indexed by the values in Y. 99 | """ 100 | 101 | METHOD = ['origin', 'ordinal', 'onehot', 'onehot-compare'] 102 | LABELS = ['ordinal', 'binary', 'onehot'] 103 | if method not in METHOD: 104 | raise ValueError(f'method={method} is not a valid option. The options are {METHOD}') 105 | if labels not in LABELS: 106 | raise ValueError(f'labels={labels} is not a valid option. The options are {LABELS}') 107 | 108 | table_X, table_Y, categorical_cols, numerical_cols = predefined_dataset(name, binary_y=labels == 'binary') 109 | 110 | # By default, columns with object type are treated as categorical features and rest are numerical features 111 | # All numerical features that have fewer than 5 unique values should be considered as categorical features 112 | if categorical_cols is None: 113 | categorical_cols = list(table_X.columns[(table_X.dtypes == np.dtype('O')).to_numpy().nonzero()[0]]) 114 | if numerical_cols is None: 115 | numerical_cols = [col for col in table_X.columns if col not in categorical_cols and np.unique(table_X[col].to_numpy()).shape[0] > 5] 116 | categorical_cols = [col for col in table_X.columns if col not in numerical_cols] 117 | 118 | # Fill categorical nan values with most frequent value and numerical nan values with the mean value 119 | if len(categorical_cols) != 0: 120 | imp_cat = SimpleImputer(missing_values=np.nan, strategy='most_frequent') 121 | table_X[categorical_cols] = imp_cat.fit_transform(table_X[categorical_cols]) 122 | if len(numerical_cols) != 0: 123 | imp_num = SimpleImputer(missing_values=np.nan, strategy='mean') 124 | table_X[numerical_cols] = imp_num.fit_transform(table_X[numerical_cols]) 125 | 126 | if table_X.isnull().values.any() or table_Y.isnull().values.any(): 127 | raise ValueError('Dataset should not have nan value!') 128 | 129 | # Encode instances 130 | X = table_X.copy() 131 | 132 | col_categories = [] 133 | if method in ['origin', 'ordinal'] and len(categorical_cols) != 0: 134 | # Convert categorical strings to integers that represent different categories 135 | ord_enc = OrdinalEncoder() 136 | X[categorical_cols] = ord_enc.fit_transform(X[categorical_cols]) 137 | col_categories = {col: list(categories) for col, categories in 
zip(categorical_cols, ord_enc.categories_)} 138 | 139 | col_intervals = [] 140 | if method in ['ordinal', 'onehot'] and len(numerical_cols) != 0: 141 | # Discretize numerical values to integers that represent different intervals 142 | kbin_dis = KBinsDiscretizer(encode='ordinal', strategy='kmeans') 143 | X[numerical_cols] = kbin_dis.fit_transform(X[numerical_cols]) 144 | col_intervals = {col: [f'({intervals[i]:.2f}, {intervals[i+1]:.2f})' for i in range(len(intervals) - 1)] for col, intervals in zip(numerical_cols, kbin_dis.bin_edges_)} 145 | 146 | if method in ['onehot']: 147 | # Map numerical values to interval strings so that FeatureBinarizer can process them as categorical values 148 | for col in numerical_cols: 149 | X[col] = np.array(col_intervals[col]).astype('object')[X[col].astype(int)] 150 | 151 | if method in ['onehot', 'onehot-compare']: 152 | # One-hot encode categorical values and encode numerical values by comparing with thresholds 153 | fb = FeatureBinarizer(colCateg=categorical_cols, negations=negations) 154 | X = fb.fit_transform(X) 155 | 156 | if method in ['origin']: 157 | # X_headers is a list of features 158 | X_headers = [column for column in X.columns] 159 | elif method in ['ordinal']: 160 | # X_headers is a dict where keys are features and values are their categories 161 | X_headers = {col: col_categories[col] if col in col_categories else col_intervals[col] for col in table_X.columns} 162 | else: 163 | # X_headers is a list of binarized features 164 | X_headers = ["".join(map(str, column)) for column in X.columns] 165 | 166 | if method not in ['origin']: 167 | X = X.astype(int) 168 | 169 | # Encode labels 170 | le = LabelEncoder() 171 | Y = le.fit_transform(table_Y).astype(int) 172 | Y_headers = le.classes_ 173 | if labels == 'onehot': 174 | lb = LabelBinarizer() 175 | Y = lb.fit_transform(Y) 176 | 177 | return X, Y, X_headers, Y_headers 178 | 179 | def split_dataset(X, Y, test=0.2, shuffle=None): 180 | X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=test, random_state=shuffle) 181 | 182 | return X_train, X_test, Y_train, Y_test 183 | 184 | def kfold_dataset(X, Y, k=5, shuffle=None): 185 | kf = StratifiedKFold(n_splits=k, shuffle=bool(shuffle), random_state=shuffle) 186 | datasets = [(X.iloc[train_index], X.iloc[test_index], Y[train_index], Y[test_index]) 187 | for train_index, test_index in kf.split(X, Y if len(Y.shape) == 1 else Y.argmax(1))] 188 | 189 | return datasets 190 | 191 | def nested_kfold_dataset(X, Y, outer_k=5, inner_k=5, shuffle=None): 192 | inner_kf = StratifiedKFold(n_splits=inner_k, shuffle=bool(shuffle), random_state=shuffle) 193 | 194 | datasets = [] 195 | for dataset in kfold_dataset(X, Y, k=outer_k, shuffle=shuffle): 196 | X_train_valid, X_test, Y_train_valid, Y_test = dataset 197 | 198 | nested_datasets = [] 199 | for train_index, valid_index in inner_kf.split( 200 | X_train_valid, Y_train_valid if len(Y.shape) == 1 else Y_train_valid.argmax(1)): 201 | X_train = X_train_valid.iloc[train_index] 202 | X_valid = X_train_valid.iloc[valid_index] 203 | Y_train = Y_train_valid[train_index] 204 | Y_valid = Y_train_valid[valid_index] 205 | nested_datasets.append([X_train, X_valid, Y_train, Y_valid]) 206 | datasets.append([X_train_valid, X_test, Y_train_valid, Y_test, nested_datasets]) 207 | 208 | return datasets -------------------------------------------------------------------------------- /datasets/heloc/heloc_data_dictionary-2.xlsx: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Joeyonng/decision-rules-network/ac246983407fbb78b18369f6936ff0c1657accc3/datasets/heloc/heloc_data_dictionary-2.xlsx -------------------------------------------------------------------------------- /datasets/magic/magic04.names: -------------------------------------------------------------------------------- 1 | 1. Title of Database: MAGIC gamma telescope data 2004 2 | 3 | 2. Sources: 4 | 5 | (a) Original owner of the database: 6 | 7 | R. K. Bock 8 | Major Atmospheric Gamma Imaging Cherenkov Telescope project (MAGIC) 9 | http://wwwmagic.mppmu.mpg.de 10 | rkb@mail.cern.ch 11 | 12 | (b) Donor: 13 | 14 | P. Savicky 15 | Institute of Computer Science, AS of CR 16 | Czech Republic 17 | savicky@cs.cas.cz 18 | 19 | (c) Date received: May 2007 20 | 21 | 3. Past Usage: 22 | 23 | (a) Bock, R.K., Chilingarian, A., Gaug, M., Hakl, F., Hengstebeck, T., 24 | Jirina, M., Klaschka, J., Kotrc, E., Savicky, P., Towers, S., 25 | Vaicilius, A., Wittek W. (2004). 26 | Methods for multidimensional event classification: a case study 27 | using images from a Cherenkov gamma-ray telescope. 28 | Nucl.Instr.Meth. A, 516, pp. 511-528. 29 | 30 | (b) P. Savicky, E. Kotrc. 31 | Experimental Study of Leaf Confidences for Random Forest. 32 | Proceedings of COMPSTAT 2004, In: Computational Statistics. 33 | (Ed.: Antoch J.) - Heidelberg, Physica Verlag 2004, pp. 1767-1774. 34 | 35 | (c) J. Dvorak, P. Savicky. 36 | Softening Splits in Decision Trees Using Simulated Annealing. 37 | Proceedings of ICANNGA 2007, Warsaw, (Ed.: Beliczynski et. al), 38 | Part I, LNCS 4431, pp. 721-729. 39 | 40 | 4. Relevant Information: 41 | 42 | The data are MC generated (see below) to simulate registration of high energy 43 | gamma particles in a ground-based atmospheric Cherenkov gamma telescope using the 44 | imaging technique. Cherenkov gamma telescope observes high energy gamma rays, 45 | taking advantage of the radiation emitted by charged particles produced 46 | inside the electromagnetic showers initiated by the gammas, and developing in the 47 | atmosphere. This Cherenkov radiation (of visible to UV wavelengths) leaks 48 | through the atmosphere and gets recorded in the detector, allowing reconstruction 49 | of the shower parameters. The available information consists of pulses left by 50 | the incoming Cherenkov photons on the photomultiplier tubes, arranged in a 51 | plane, the camera. Depending on the energy of the primary gamma, a total of 52 | few hundreds to some 10000 Cherenkov photons get collected, in patterns 53 | (called the shower image), allowing to discriminate statistically those 54 | caused by primary gammas (signal) from the images of hadronic showers 55 | initiated by cosmic rays in the upper atmosphere (background). 56 | 57 | Typically, the image of a shower after some pre-processing is an elongated 58 | cluster. Its long axis is oriented towards the camera center if the shower axis 59 | is parallel to the telescope's optical axis, i.e. if the telescope axis is 60 | directed towards a point source. A principal component analysis is performed 61 | in the camera plane, which results in a correlation axis and defines an ellipse. 62 | If the depositions were distributed as a bivariate Gaussian, this would be 63 | an equidensity ellipse. The characteristic parameters of this ellipse 64 | (often called Hillas parameters) are among the image parameters that can be 65 | used for discrimination. 
The energy depositions are typically asymmetric 66 | along the major axis, and this asymmetry can also be used in discrimination. 67 | There are, in addition, further discriminating characteristics, like the 68 | extent of the cluster in the image plane, or the total sum of depositions. 69 | 70 | The data set was generated by a Monte Carlo program, Corsika, described in 71 | D. Heck et al., CORSIKA, A Monte Carlo code to simulate extensive air showers, 72 | Forschungszentrum Karlsruhe FZKA 6019 (1998). 73 | The program was run with parameters allowing to observe events with energies down 74 | to below 50 GeV. 75 | 76 | 5. Number of Instances: 19020 77 | 78 | 6. Number of Attributes: 11 (including the class) 79 | 80 | 7. Attribute information: 81 | 82 | 1. fLength: continuous # major axis of ellipse [mm] 83 | 2. fWidth: continuous # minor axis of ellipse [mm] 84 | 3. fSize: continuous # 10-log of sum of content of all pixels [in #phot] 85 | 4. fConc: continuous # ratio of sum of two highest pixels over fSize [ratio] 86 | 5. fConc1: continuous # ratio of highest pixel over fSize [ratio] 87 | 6. fAsym: continuous # distance from highest pixel to center, projected onto major axis [mm] 88 | 7. fM3Long: continuous # 3rd root of third moment along major axis [mm] 89 | 8. fM3Trans: continuous # 3rd root of third moment along minor axis [mm] 90 | 9. fAlpha: continuous # angle of major axis with vector to origin [deg] 91 | 10. fDist: continuous # distance from origin to center of ellipse [mm] 92 | 11. class: g,h # gamma (signal), hadron (background) 93 | 94 | 8. Missing Attribute Values: None 95 | 96 | 9. Class Distribution: 97 | 98 | g = gamma (signal): 12332 99 | h = hadron (background): 6688 100 | 101 | For technical reasons, the number of h events is underestimated. 102 | In the real data, the h class represents the majority of the events. 103 | 104 | The simple classification accuracy is not meaningful for this data, since 105 | classifying a background event as signal is worse than classifying a signal 106 | event as background. For comparison of different classifiers an ROC curve 107 | has to be used. The relevant points on this curve are those, where the 108 | probability of accepting a background event as signal is below one of the 109 | following thresholds: 0.01, 0.02, 0.05, 0.1, 0.2 depending on the required 110 | quality of the sample of the accepted events for different experiments. 
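To make the evaluation protocol described above concrete, a small sketch (not part of this repository; y_true and y_score are assumed to come from whatever classifier is being compared) that reads off the signal efficiency at the listed background-acceptance thresholds with scikit-learn:

    import numpy as np
    from sklearn.metrics import roc_curve

    def tpr_at_fpr(y_true, y_score, fpr_targets=(0.01, 0.02, 0.05, 0.1, 0.2)):
        # y_true: 1 for gamma (signal), 0 for hadron (background); y_score: classifier scores for the signal class.
        fpr, tpr, _ = roc_curve(y_true, y_score)
        # For each allowed background-acceptance rate, report the best achievable true-positive rate.
        return {target: tpr[fpr <= target].max() for target in fpr_targets}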
111 | 112 | -------------------------------------------------------------------------------- /sparse_linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.utils.prune as prune 9 | import torch.optim as optim 10 | 11 | def sparse_linear(name): 12 | if name == 'linear': 13 | return Linear 14 | elif name == 'l0': 15 | return L0Linear 16 | elif name == 'reweight': 17 | return ReweightLinear 18 | else: 19 | raise ValueError(f'{name} linear type not supported.') 20 | 21 | class Linear(nn.Linear): 22 | def __init__(self, in_features, out_features, bias=True, linear=F.linear, **kwargs): 23 | super(Linear, self).__init__(in_features, out_features, bias=bias, **kwargs) 24 | 25 | self.linear = linear 26 | 27 | def forward(self, input): 28 | output = self.linear(input, self.weight, self.bias) 29 | 30 | return output 31 | 32 | def sparsity(self): 33 | sparsity = (self.weight == 0).float().mean().item() 34 | 35 | return sparsity 36 | 37 | def masked_weight(self): 38 | masked_weight = self.weight 39 | 40 | return masked_weight 41 | 42 | def regularization(self): 43 | regularization = 0 44 | 45 | return regularization 46 | 47 | class L0Linear(nn.Linear): 48 | def __init__(self, in_features, out_features, bias=True, linear=F.linear, loc_mean=0, loc_sdev=0.01, 49 | beta=2 / 3, gamma=-0.1, zeta=1.1, fix_temp=True, **kwargs): 50 | super(L0Linear, self).__init__(in_features, out_features, bias=bias, **kwargs) 51 | 52 | self._size = self.weight.size() 53 | self.loc = nn.Parameter(torch.zeros(self._size).normal_(loc_mean, loc_sdev)) 54 | self.temp = beta if fix_temp else nn.Parameter(torch.zeros(1).fill_(beta)) 55 | self.register_buffer("uniform", torch.zeros(self._size)) 56 | self.gamma = gamma 57 | self.zeta = zeta 58 | self.gamma_zeta_ratio = math.log(-gamma / zeta) 59 | self.linear = linear 60 | 61 | self.penalty = 0 62 | 63 | def forward(self, input): 64 | mask, self.penalty = self._get_mask() 65 | masked_weight = self.weight * mask 66 | output = self.linear(input, masked_weight, self.bias) 67 | 68 | return output 69 | 70 | def sparsity(self): 71 | sparsity = (self.masked_weight() == 0).float().mean().item() 72 | 73 | return sparsity 74 | 75 | def masked_weight(self): 76 | mask, _ = self._get_mask() 77 | masked_weight = self.weight * mask 78 | 79 | return masked_weight 80 | 81 | def regularization(self, mean=True, axis=None): 82 | regularization = self.penalty 83 | if mean: 84 | regularization = regularization.mean() if axis == None else regularization.mean(axis) 85 | 86 | return regularization 87 | 88 | def _get_mask(self): 89 | def hard_sigmoid(x): 90 | return torch.min(torch.max(x, torch.zeros_like(x)), torch.ones_like(x)) 91 | 92 | if self.training: 93 | self.uniform.uniform_() 94 | u = torch.autograd.Variable(self.uniform) 95 | s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.loc) / self.temp) 96 | s = s * (self.zeta - self.gamma) + self.gamma 97 | penalty = torch.sigmoid(self.loc - self.temp * self.gamma_zeta_ratio) 98 | else: 99 | s = torch.sigmoid(self.loc) * (self.zeta - self.gamma) + self.gamma 100 | penalty = 0 101 | 102 | return hard_sigmoid(s), penalty 103 | 104 | class ReweightLinear(nn.Linear): 105 | def __init__(self, in_features, out_features, bias=True, linear=F.linear, 106 | prune_neuron=False, prune_always=True, factor=0.1): 107 | super(ReweightLinear, 
self).__init__(in_features, out_features, bias=bias) 108 | 109 | self.prune_neuron = prune_neuron 110 | self.prune_always = prune_always 111 | self.factor = factor 112 | self.linear = linear 113 | 114 | def forward(self, input): 115 | if not self.training: 116 | weight = self.masked_weight() 117 | else: 118 | weight = self.masked_weight() if self.prune_always else self.weight 119 | out = self.linear(input, weight, self.bias) 120 | 121 | return out 122 | 123 | def sparsity(self): 124 | sparsity = (self.weight.abs() <= self._threshold()).float().mean().item() 125 | 126 | return sparsity 127 | 128 | def masked_weight(self): 129 | masked_weight = self.weight.clone() 130 | masked_weight[self.weight.abs() <= self._threshold()] = 0 131 | 132 | return masked_weight 133 | 134 | def regularization(self, mean=True, axis=None): 135 | regularization = self.weight.abs() 136 | if mean: 137 | regularization = regularization.mean() if axis == None else regularization.mean(axis) 138 | 139 | return regularization 140 | 141 | def _threshold(self): 142 | if self.prune_neuron: 143 | threshold = self.factor * self.weight.std(1).unsqueeze(1) 144 | else: 145 | threshold = self.factor * self.weight.std() 146 | 147 | return threshold -------------------------------------------------------------------------------- /tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "b44bab5d-a11d-4003-96ff-4f71c083fef0", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "from datasets.dataset import transform_dataset, kfold_dataset\n", 14 | "from DRNet import train, DRNet" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "f49dd5dd-aca4-46a7-8bbe-781f1a057ed3", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "# Read datasets\n", 25 | "name = 'adult'\n", 26 | "X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare', negations=False, labels='binary')\n", 27 | "datasets = kfold_dataset(X, Y, shuffle=1)\n", 28 | "X_train, X_test, Y_train, Y_test = datasets[0]\n", 29 | "\n", 30 | "train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()), torch.Tensor(Y_train))\n", 31 | "test_set = torch.utils.data.TensorDataset(torch.Tensor(X_test.to_numpy()), torch.Tensor(Y_test))" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "id": "b60c1c84-8267-4a3a-9e86-89363f8ec696", 38 | "metadata": { 39 | "tags": [] 40 | }, 41 | "outputs": [ 42 | { 43 | "name": "stderr", 44 | "output_type": "stream", 45 | "text": [ 46 | "Epoch: 100%|██████████| 2000/2000 [07:36<00:00, 4.38it/s, loss=0.544, epoch accu=0.831, test accu=0.836, num rules=11, sparsity=0.907] \n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "# Train DR-Net\n", 52 | "# Default learning rate (1e-2), and_lam (1e-2), and or_lam (1e-5) usually work best. 
A large number of epochs (e.g. 10000) is necessary for a sparse rule set.\n", 53 | "net = DRNet(train_set[:][0].size(1), 50, 1)\n", 54 | "train(net, train_set, test_set=test_set, device='cuda', lr=1e-2, epochs=2000, batch_size=400,\n", 55 | " and_lam=1e-2, or_lam=1e-5, num_alter=500)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 5, 61 | "id": "f208a470-759a-4499-a3a8-e4cf004c8663", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "Accuracy: 0.836399801093983, num rules: 11, num conditions: 131\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "# Get accuracy and the rule net\n", 74 | "accu = (net.predict(np.array(X_test)) == Y_test).mean()\n", 75 | "rules = net.get_rules(X_headers)\n", 76 | "print(f'Accuracy: {accu}, num rules: {len(rules)}, num conditions: {sum(map(len, rules))}')" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "e1260593-6f2b-44e2-9a11-27ffdea0d14b", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | " " 87 | ] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3 (ipykernel)", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.8.10" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | --------------------------------------------------------------------------------
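As a possible follow-up to the tutorial above (a sketch reusing net, X_headers, Y_headers, and X_test from the notebook; the model path is made up), the extracted rules can be printed as IF-THEN statements and the trained network saved for later reuse:

    # Print each rule as a human-readable conjunction of binarized features.
    # The positive class (label 1) is Y_headers[1], e.g. '>50K' for the adult dataset.
    rules = net.get_rules(X_headers)
    for i, rule in enumerate(rules):
        print(f"Rule {i}: IF " + " AND ".join(rule) + f" THEN {Y_headers[1]}")

    # Persist the trained network and restore it later.
    net.save("models/adult_drnet.pt")
    restored = DRNet.load("models/adult_drnet.pt")
    assert (restored.predict(np.array(X_test)) == net.predict(np.array(X_test))).all()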