├── LICENSE
├── README.md
├── classification
│   ├── dataset.py
│   ├── main.py
│   ├── model.py
│   ├── ndf.py
│   ├── ndf_vis.py
│   ├── optimizer.py
│   ├── parse.py
│   ├── resnet.py
│   ├── trainer.py
│   └── utils.py
├── data
│   ├── CACD_split
│   │   └── place meta-data here.txt
│   └── Nexperia
│       ├── __init__.py
│       └── dataset.py
├── docs
│   └── supplementary material.pdf
├── pre-trained
│   └── place pre-trained models here.txt
├── regression
│   ├── cacd_process.py
│   ├── data_prepare.py
│   ├── main.py
│   ├── model.py
│   ├── ndf.py
│   ├── ndf_vis.py
│   ├── optimizer.py
│   ├── resnet.py
│   ├── trainer.py
│   ├── utils.py
│   └── vis_utils.py
└── teasers
    ├── cacd_final1.png
    ├── cifar10_results.pdf
    ├── cifar10_results.png
    ├── mnist_results.pdf
    └── mnist_results.png

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Shichao Li
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisualizingNDF 2 | Background: Neural decision forest (NDF) [3] combines the representation power of deep neural networks (DNNs) with the divide-and-conquer idea of traditional decision trees. It conducts inference by making decisions based on image features extracted by DNNs. This decision-making process can be traced and visualized with Decision Saliency Maps (DSMs) [1], which highlight the regions of the input that most influence the decision process. 3 | 4 | Contents: This repository contains the official PyTorch code for training and visualizing NDF. Pre-processed data and pre-trained models are also released. Both classification and regression problems are considered; specific tasks include: 5 | 1. Image classification on MNIST, CIFAR-10 and the Nexperia semiconductor dataset. The last comes from a Kaggle competition organized by [MATH 6380O (Advanced Topics in Deep Learning)](https://deeplearning-math.github.io/) at HKUST. 6 | 2. Facial age estimation on the large-scale Cross-Age Celebrity Dataset (CACD). Pre-processed data and a new model (RNDF) [2] are released. RNDF achieves state-of-the-art accuracy while consuming less memory. 7 | 8 | ## Example: decision-making for image classification 9 | The left-most column shows the input images. Each row visualizes one path from the root node to a leaf node in a soft decision tree. Each image in the row is the DSM [1] for one splitting node, where (Na, Pb) means that the input arrives at node a with probability b. For example, (N1, P1.0) indicates that the input arrives at the root node with probability 1. Each DSM highlights the spatial regions that have a larger influence on the corresponding splitting node. Note, for example, that the foreground object is more important to the NDF when it makes decisions. 10 |
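The (Na, Pb) annotations follow from how a soft tree routes an input: each splitting node sends it to one child with its sigmoid decision probability and to the other child with the complement, so the probability of arriving at a node is the product of the decisions along its path, and the tree's output mixes the leaf class distributions with the leaf probabilities. Below is a minimal pure-Python sketch of that arithmetic, assuming breadth-first node numbering from 1 (node n has children 2n and 2n + 1, and 2n is taken with probability d_n). The batched `Tree.forward` in `classification/ndf.py` indexes its splitting nodes the same way, starting at index 1; that the figure labels use exactly this numbering is an assumption. The function names and example scores are hypothetical.

```python
import math

def sigmoid(z):
    """Logistic decision function, as used by the splitting nodes (nn.Sigmoid)."""
    return 1.0 / (1.0 + math.exp(-z))

def leaf_probabilities(decisions, depth):
    """Routing probability mu of every leaf of a soft binary decision tree.

    decisions maps a splitting node n (numbered 1 .. 2**depth - 1 breadth-first)
    to the probability of moving to child 2n; child 2n + 1 receives the rest.
    """
    mu = {1: 1.0}  # the root is always reached: (N1, P1.0)
    for _ in range(depth):
        next_mu = {}
        for node, p in mu.items():
            d = decisions[node]
            next_mu[2 * node] = p * d
            next_mu[2 * node + 1] = p * (1.0 - d)
        mu = next_mu
    # leaves are numbered 2**depth .. 2**(depth + 1) - 1
    return [mu[leaf] for leaf in range(2 ** depth, 2 ** (depth + 1))]

def path_to_leaf(decisions, leaf):
    """(node, arrival probability) pairs along the path from the root to `leaf`."""
    nodes = []
    n = leaf
    while n >= 1:  # walk up the tree: the parent of node n is n // 2
        nodes.append(n)
        n //= 2
    nodes.reverse()
    path, p = [], 1.0
    for parent, child in zip(nodes, nodes[1:]):
        path.append((parent, p))
        d = decisions[parent]
        p *= d if child == 2 * parent else 1.0 - d
    path.append((leaf, p))
    return path

def tree_prediction(leaf_probs, pi):
    """Class distribution of one tree: sum over leaves of mu_l * pi_l (i.e. mu @ pi)."""
    n_class = len(pi[0])
    return [sum(mu * row[c] for mu, row in zip(leaf_probs, pi)) for c in range(n_class)]

# Usage with hypothetical splitting-node scores for a depth-2 tree.
decisions = {n: sigmoid(s) for n, s in {1: 1.4, 2: -1.1, 3: 0.0}.items()}
leaf_probs = leaf_probabilities(decisions, depth=2)
print(round(sum(leaf_probs), 6))   # 1.0
print(path_to_leaf(decisions, 5))  # starts with (1, 1.0), like (N1, P1.0)
```

Summing the leaf probabilities to 1 is a quick sanity check. The batched version in `ndf.py` replaces these loops with `repeat`, elementwise products and `view`, and computes the prediction as `torch.mm(mu, pi)`.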
16 | 17 | ## Example: decision-making for facial age estimation 18 | Note how the NDF ignores irrelevant texture (e.g., hair) during its decision-making process. 19 |
22 | 23 | ## Performance on Cross-Age Celebrity Dataset (CACD) 24 | | Model | MAE (years) | Memory Usage | FLOPs | 25 | | ----------------- | ----------- | ----------- | ----------- | 26 | | [DRFs (CVPR 2018)](https://github.com/shenwei1231/caffe-DeepRegressionForests) | 4.637 | 539.4MB | 16G | 27 | | [RNDF (Ours)](https://arxiv.org/abs/1908.10737) | 4.595 | 112.4MB | 4G | 28 | 29 | ## Dependencies 30 | * Python 3.6 (not tested with other versions) 31 | * PyTorch >= 1.0 32 | * Matplotlib 33 | * NumPy 34 | * CUDA (CPU mode is not implemented) 35 | 36 | ## Pre-trained models 37 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1DM6wVSknkYBqGf1UwHQgJNUp40sYDMrv?usp=sharing) and place them in the "pre-trained" folder. 38 | 39 | ## Usage: visualizing pre-trained NDF for image classification 40 | After downloading the pre-trained models, go to /classification and 41 | run 42 | ```bash 43 | python ndf_vis.py 44 | ``` 45 | for CIFAR-10. 46 | 47 | For MNIST, run 48 | ```bash 49 | python ndf_vis.py -dataset 'mnist' 50 | ``` 51 | ## Usage: visualizing pre-trained RNDF for facial age estimation 52 | To visualize RNDF on the CACD dataset: 53 | 1. Download the pre-processed images [here](https://drive.google.com/file/d/1_xb5E_f_vmfZN_9ymmrBhZQVKdaAsubj/view?usp=sharing) and decompress them into the "/data" folder. 54 | 2. Download the metadata folder [here](https://drive.google.com/drive/folders/1s_Ml82O4FVkC34PCE4ttrYhta3EKeYdo?usp=sharing) and place it under "/data". 55 | 3. Go to /regression and run 56 | ```bash 57 | python ndf_vis.py 58 | ``` 59 | Please refer to the classification code for detailed comments; more comments for the regression code will be added in future updates.
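The visualization scripts render DSMs with helper code in `utils.py` / `vis_utils.py`, which is not shown here. As a hedged illustration of the general idea behind a gradient-based saliency map (an assumption about the approach, not the repository's exact computation), take one hypothetical splitting node whose decision is d(x) = sigmoid(w · x) for a flattened input x and fixed weights w. Its gradient is ∂d/∂x_i = d(1 - d) w_i, so pixels with a larger |∂d/∂x_i| have more influence on that node's routing. In the real model the linear term is a CNN feature and the gradient would come from autograd.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def node_decision(x, w):
    """Decision of a hypothetical splitting node: sigmoid of a linear feature w . x."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))

def saliency(x, w):
    """Per-element |d decision / d x_i| via the analytic gradient d * (1 - d) * w_i."""
    d = node_decision(x, w)
    return [abs(d * (1.0 - d) * wi) for wi in w]

def numerical_saliency(x, w, eps=1e-6):
    """Central finite differences; used only to check the analytic gradient."""
    grads = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        diff = node_decision(x_plus, w) - node_decision(x_minus, w)
        grads.append(abs(diff / (2 * eps)))
    return grads

# A flattened 2x2 "image"; the node only looks at pixels 1 and 3 (hypothetical weights).
x = [0.2, 0.9, 0.1, 0.7]
w = [0.0, 2.0, 0.0, -1.0]
print(saliency(x, w))            # zeros at pixels 0 and 2, largest value at pixel 1
print(numerical_saliency(x, w))  # agrees with the analytic values up to rounding
```

Reshaping the per-pixel values back to the image grid gives a heat map. For a real network one would instead call `backward()` on the chosen node's decision and read the gradient stored on the input tensor.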
60 | 61 | ## Usage: training NDF for image classification 62 | To train a deep neural decision forest for CIFAR-10, go to /classification and run 63 | ```bash 64 | python main.py 65 | ``` 66 | For MNIST, run 67 | ```bash 68 | python main.py -dataset 'mnist' -epochs 50 69 | ``` 70 | 71 | ## Usage: training NDF for facial age estimation 72 | To train an RNDF (R stands for residual) on the CACD dataset, follow steps 1 and 2 of the visualization instructions above, then go to /regression and run 73 | ```bash 74 | python main.py -train True 75 | ``` 76 | To test the pre-trained model on CACD, go to /regression and run 77 | ```bash 78 | python main.py -evaluate True -test_model_path "YourPATH/CACD_MAE_4.59.pth" 79 | ``` 80 | The released model should give an MAE of 4.59. 81 | 82 | ## License 83 | MIT 84 | 85 | ## Citation 86 | Please consider citing the related papers in your publications if they help your research: 87 | 88 | @inproceedings{li2019visualizing, 89 | title={Visualizing the Decision-making Process in Deep Neural Decision Forest}, 90 | author={Li, Shichao and Cheng, Kwang-Ting}, 91 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 92 | pages={114--117}, 93 | year={2019} 94 | } 95 | 96 | @article{li2019facial, 97 | title={Facial age estimation by deep residual decision making}, 98 | author={Li, Shichao and Cheng, Kwang-Ting}, 99 | journal={arXiv preprint arXiv:1908.10737}, 100 | year={2019} 101 | } 102 | 103 | @inproceedings{kontschieder2015deep, 104 | title={Deep neural decision forests}, 105 | author={Kontschieder, Peter and Fiterau, Madalina and Criminisi, Antonio and Rota Bulo, Samuel}, 106 | booktitle={Proceedings of the IEEE international conference on computer vision}, 107 | pages={1467--1475}, 108 | year={2015} 109 | } 110 | 111 | Links to the papers: 112 | 113 | 1.
[Visualizing the decision-making process in deep neural decision forest](http://openaccess.thecvf.com/content_CVPRW_2019/papers/Explainable%20AI/Li_Visualizing_the_Decision-making_Process_in_Deep_Neural_Decision_Forest_CVPRW_2019_paper.pdf) 114 | 2. [Facial age estimation by deep residual decision making](https://arxiv.org/abs/1908.10737) 115 | 3. [Deep neural decision forests](http://openaccess.thecvf.com/content_iccv_2015/papers/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.pdf) 116 | -------------------------------------------------------------------------------- /classification/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | def prepare_db(opt): 7 | """ 8 | prepare the PyTorch dataset objects for classification. 9 | args: 10 | opt: the experiment configuration object. 11 | return: 12 | a dictionary containing the training and evaluation datasets 13 | """ 14 | logging.info("Use %s dataset"%(opt.dataset)) 15 | 16 | # prepare MNIST dataset 17 | if opt.dataset == 'mnist': 18 | opt.n_class = 10 19 | # training set 20 | train_dataset = torchvision.datasets.MNIST('../data/mnist', train=True, 21 | download=True, 22 | transform=transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.1307,), 25 | (0.3081,)) 26 | ])) 27 | 28 | # evaluation set 29 | eval_dataset = torchvision.datasets.MNIST('../data/mnist', train=False, 30 | download=True, 31 | transform=transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.1307,), 34 | (0.3081,)) 35 | ])) 36 | # prepare CIFAR-10 dataset 37 | elif opt.dataset == 'cifar10': 38 | opt.n_class = 10 39 | # define the image transformation operators 40 | transform_train = transforms.Compose([ 41 | transforms.RandomCrop(32, padding=4), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.4914, 0.4822, 
0.4465), 45 | (0.2023, 0.1994, 0.2010)), 46 | ]) 47 | transform_test = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4914, 0.4822, 0.4465), 50 | (0.2023, 0.1994, 0.2010)), 51 | ]) 52 | train_dataset = torchvision.datasets.CIFAR10(root='../data/cifar10', 53 | train=True, 54 | download=True, 55 | transform=transform_train) 56 | eval_dataset = torchvision.datasets.CIFAR10(root='../data/cifar10', 57 | train=False, 58 | download=True, 59 | transform=transform_test) 60 | elif opt.dataset == 'Nexperia': 61 | # Nexperia image classification 62 | # Updated 2019/12/01 63 | # Reference: https://www.kaggle.com/c/semi-conductor-image-classification-first 64 | # Task: classify whether a semiconductor image is abnormal 65 | opt.n_class = 2 66 | if not os.path.exists(opt.nexperia_root): 67 | logging.error("Data not found at %s. Please download it first."%(opt.nexperia_root)) 68 | raise ValueError("Nexperia data not found") 69 | import sys 70 | sys.path.insert(0, "../data/") 71 | from Nexperia.dataset import get_datasets 72 | return get_datasets(opt) 73 | else: 74 | raise NotImplementedError 75 | return {'train':train_dataset, 'eval':eval_dataset} 76 | -------------------------------------------------------------------------------- /classification/main.py: -------------------------------------------------------------------------------- 1 | import parse 2 | import model 3 | import dataset 4 | import trainer 5 | import optimizer 6 | 7 | import logging 8 | import torch 9 | 10 | def main(): 11 | # logging configuration 12 | logging.basicConfig( 13 | level=logging.INFO, 14 | format="[%(asctime)s]: %(levelname)s: %(message)s" 15 | ) 16 | 17 | # command line parser 18 | opt = parse.parse_arg() 19 | 20 | # GPU 21 | opt.cuda = opt.gpuid >= 0 22 | if opt.gpuid >= 0: 23 | torch.cuda.set_device(opt.gpuid) 24 | else: 25 | logging.info("WARNING: RUN WITHOUT GPU") 26 | 27 | # prepare dataset 28 | db = dataset.prepare_db(opt) 29 | 30 | # initialize neural decision forest 31 | NDF =
model.prepare_model(opt) 32 | 33 | # prepare optimizer 34 | optim, sche = optimizer.prepare_optim(NDF, opt) 35 | 36 | # train the neural decision forest 37 | best_metric = trainer.train(NDF, optim, sche, db, opt) 38 | logging.info('The best evaluation metric is %f'%best_metric) 39 | 40 | if __name__ == '__main__': 41 | main() 42 | 43 | -------------------------------------------------------------------------------- /classification/model.py: -------------------------------------------------------------------------------- 1 | import ndf 2 | 3 | def prepare_model(opt): 4 | """ 5 | prepare the neural decision forest model. The model is composed of the 6 | feature extractor (a CNN) and a decision forest. The feature extractor 7 | extracts features from input, which are sent to the decision forest for 8 | inference. 9 | args: 10 | opt: experiment configuration object 11 | """ 12 | # initialize feature extractor 13 | if opt.dataset == 'mnist': 14 | feat_layer = ndf.MNISTFeatureLayer(opt.feat_dropout, 15 | opt.feature_length) 16 | elif opt.dataset == 'cifar10': 17 | feat_layer = ndf.CIFAR10FeatureLayer(opt.feat_dropout, 18 | feat_length=opt.feature_length, 19 | archi_type=opt.model_type) 20 | elif opt.dataset == 'Nexperia': 21 | feat_layer = ndf.NexperiaFeatureLayer(opt.feat_dropout, 22 | feat_length=opt.feature_length, 23 | archi_type=opt.model_type) 24 | else: 25 | raise NotImplementedError 26 | # initialize the decision forest 27 | forest = ndf.Forest(n_tree = opt.n_tree, tree_depth = opt.tree_depth, 28 | feature_length = opt.feature_length, 29 | vector_length = opt.n_class, use_cuda = opt.cuda) 30 | model = ndf.NeuralDecisionForest(feat_layer, forest) 31 | if opt.cuda: 32 | model = model.cuda() 33 | else: 34 | model = model.cpu() 35 | 36 | return model -------------------------------------------------------------------------------- /classification/ndf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn
as nn 3 | from torch.nn.parameter import Parameter 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import resnet 8 | # float32 machine epsilon, used as a small positive constant (not the smallest positive float) 9 | FLT_MIN = float(np.finfo(np.float32).eps) 10 | FLT_MAX = float(np.finfo(np.float32).max) 11 | 12 | class MNISTFeatureLayer(nn.Module): 13 | def __init__(self, dropout_rate, feat_length = 512, shallow=False): 14 | super(MNISTFeatureLayer, self).__init__() 15 | self.shallow = shallow 16 | self.feat_length = feat_length 17 | if shallow: 18 | self.add_module('conv1', nn.Conv2d(1, 32, kernel_size=15,padding=1,stride=5)) 19 | else: 20 | self.conv_layers = nn.Sequential() 21 | self.conv_layers.add_module('conv1', nn.Conv2d(1, 32, kernel_size=3, padding=1)) 22 | self.conv_layers.add_module('bn1', nn.BatchNorm2d(32)) 23 | self.conv_layers.add_module('relu1', nn.ReLU()) 24 | self.conv_layers.add_module('pool1', nn.MaxPool2d(kernel_size=2)) 25 | #self.add_module('drop1', nn.Dropout(dropout_rate)) 26 | self.conv_layers.add_module('conv2', nn.Conv2d(32, 64, kernel_size=3, padding=1)) 27 | self.conv_layers.add_module('bn2', nn.BatchNorm2d(64)) 28 | self.conv_layers.add_module('relu2', nn.ReLU()) 29 | self.conv_layers.add_module('pool2', nn.MaxPool2d(kernel_size=2)) 30 | #self.add_module('drop2', nn.Dropout(dropout_rate)) 31 | self.conv_layers.add_module('conv3', nn.Conv2d(64, 128, kernel_size=3, padding=1)) 32 | self.conv_layers.add_module('bn3', nn.BatchNorm2d(128)) 33 | self.conv_layers.add_module('relu3', nn.ReLU()) 34 | self.conv_layers.add_module('pool3', nn.MaxPool2d(kernel_size=2)) 35 | #self.add_module('drop3', nn.Dropout(dropout_rate)) 36 | self.linear_layer = nn.Linear(self.get_conv_size(), 37 | feat_length, 38 | bias=True) 39 | def get_out_feature_size(self): 40 | return self.feat_length 41 | 42 | def get_conv_size(self): 43 | if self.shallow: 44 | return 64*4*4 45 | else: 46 | return 128*3*3 47 | 48 | def forward(self, x): 49 | feats = self.conv_layers(x) 50 |
feats = feats.view(x.size()[0], -1) 51 | return self.linear_layer(feats) 52 | 53 | class CIFAR10FeatureLayer(nn.Sequential): 54 | def __init__(self, dropout_rate, feat_length = 512, archi_type='resnet18'): 55 | super(CIFAR10FeatureLayer, self).__init__() 56 | self.archi_type = archi_type 57 | self.feat_length = feat_length 58 | if self.archi_type == 'default': 59 | self.add_module('conv1', nn.Conv2d(3, 32, kernel_size=3, padding=1)) 60 | self.add_module('bn1', nn.BatchNorm2d(32)) 61 | self.add_module('relu1', nn.ReLU()) 62 | self.add_module('pool1', nn.MaxPool2d(kernel_size=2)) 63 | #self.add_module('drop1', nn.Dropout(dropout_rate)) 64 | self.add_module('conv2', nn.Conv2d(32, 32, kernel_size=3, padding=1)) 65 | self.add_module('bn2', nn.BatchNorm2d(32)) 66 | self.add_module('relu2', nn.ReLU()) 67 | self.add_module('pool2', nn.MaxPool2d(kernel_size=2)) 68 | #self.add_module('drop2', nn.Dropout(dropout_rate)) 69 | self.add_module('conv3', nn.Conv2d(32, 64, kernel_size=3, padding=1)) 70 | self.add_module('bn3', nn.BatchNorm2d(64)) 71 | self.add_module('relu3', nn.ReLU()) 72 | self.add_module('pool3', nn.MaxPool2d(kernel_size=2)) 73 | #self.add_module('drop3', nn.Dropout(dropout_rate)) 74 | elif self.archi_type == 'resnet18': 75 | self.add_module('resnet18', resnet.ResNet18(feat_length)) 76 | elif self.archi_type == 'resnet50': 77 | self.add_module('resnet50', resnet.ResNet50(feat_length)) 78 | elif self.archi_type == 'resnet152': 79 | self.add_module('resnet152', resnet.ResNet152(feat_length)) 80 | else: 81 | raise NotImplementedError 82 | 83 | def get_out_feature_size(self): 84 | if self.archi_type == 'default': 85 | return 64*4*4 86 | else: 87 | return self.feat_length 88 | 89 | class NexperiaFeatureLayer(nn.Sequential): 90 | def __init__(self, dropout_rate, feat_length = 512, archi_type='resnet18'): 91 | super(NexperiaFeatureLayer, self).__init__() 92 | self.archi_type = archi_type 93 | self.feat_length = feat_length 94 | if self.archi_type in ['resnet18', 
'resnet50', 'resnet152']: 95 | self.add_module(archi_type, self.get_resnet_pretrained(self.archi_type, self.feat_length)) 96 | elif self.archi_type in ['densenet121', 'densenet161', 'densenet169', 'densenet201']: 97 | self.add_module(archi_type, self.get_densenet_pretrained(self.archi_type, self.feat_length)) 98 | else: 99 | raise NotImplementedError 100 | 101 | def get_resnet_pretrained(self, archi_type, feat_length, grayscale=True): 102 | from torchvision import models 103 | model = getattr(models, archi_type)(pretrained=True) 104 | in_features = model.fc.in_features 105 | if grayscale: 106 | # replace the first convolution layer with a single-channel (grayscale) one 107 | stride = model.conv1.stride 108 | padding = model.conv1.padding 109 | kernel_size = model.conv1.kernel_size 110 | out_channels = model.conv1.out_channels 111 | del model.conv1 112 | model.conv1 = nn.Conv2d(1, out_channels, kernel_size, stride, padding) 113 | # replace the FC layer 114 | del model.fc 115 | model.fc = nn.Linear(in_features, feat_length, bias=True) 116 | return model 117 | 118 | def get_densenet_pretrained(self, archi_type, feat_length, grayscale=True): 119 | from torchvision import models 120 | model = getattr(models, archi_type)(pretrained=True) 121 | in_features = model.classifier.in_features 122 | if grayscale: 123 | # replace the first convolution layer with a single-channel (grayscale) one 124 | stride = model.features.conv0.stride 125 | padding = model.features.conv0.padding 126 | kernel_size = model.features.conv0.kernel_size 127 | out_channels = model.features.conv0.out_channels 128 | model.features[0] = nn.Conv2d(1, out_channels, kernel_size, stride, padding) 129 | model.classifier = nn.Linear(in_features, feat_length, bias=True) 130 | return model 131 | 132 | def get_out_feature_size(self): 133 | return self.feat_length 134 | 135 | class Tree(nn.Module): 136 | def __init__(self, depth, feature_length, vector_length, use_cuda = False): 137 | """ 138 | Args: 139 | depth (int): depth of the neural decision tree.
140 | feature_length (int): number of neurons in the last feature layer 141 | vector_length (int): length of the prediction vector stored at each leaf node (the number of classes for classification) 142 | """ 143 | super(Tree, self).__init__() 144 | self.depth = depth 145 | self.n_leaf = 2 ** depth 146 | self.feature_length = feature_length 147 | self.vector_length = vector_length 148 | self.is_cuda = use_cuda 149 | # used in leaf node update 150 | self.mu_cache = [] 151 | 152 | onehot = np.eye(feature_length) 153 | # randomly select n_leaf neurons of the feature layer to compute the decision functions 154 | self.using_idx = np.random.choice(feature_length, self.n_leaf, replace=False) 155 | self.feature_mask = onehot[self.using_idx].T 156 | self.feature_mask = Parameter(torch.from_numpy(self.feature_mask).type(torch.FloatTensor), requires_grad=False) 157 | # each leaf node stores a prediction vector (a class distribution for classification) 158 | self.pi = np.zeros((self.n_leaf, self.vector_length)) 159 | if not use_cuda: 160 | self.pi = Parameter(torch.from_numpy(self.pi).type(torch.FloatTensor), requires_grad=False) 161 | else: 162 | self.pi = Parameter(torch.from_numpy(self.pi).type(torch.FloatTensor).cuda(), requires_grad=False) 163 | # use sigmoid function as the decision function 164 | self.decision = nn.Sequential(OrderedDict([ 165 | ('sigmoid', nn.Sigmoid()), 166 | ])) 167 | 168 | def forward(self, x, save_flag = False): 169 | """ 170 | Args: 171 | param x (Tensor): input feature batch of size [batch_size,n_features] 172 | Return: 173 | (Tensor): routing probability of size [batch_size,n_leaf] 174 | """ 175 | # def debug_hook(grad): 176 | # print('This is a debug hook') 177 | # print(grad.shape) 178 | # print(grad) 179 | cache = {} # save some intermediate results for analysis 180 | if x.is_cuda and not self.feature_mask.is_cuda: 181 | self.feature_mask = Parameter(self.feature_mask.cuda(), requires_grad=False) 182 | 183 | feats = torch.mm(x, self.feature_mask) # ->[batch_size,n_leaf] 184 | decision = self.decision(feats) # passed sigmoid->[batch_size,n_leaf] 185 |
186 | decision = torch.unsqueeze(decision,dim=2) # ->[batch_size,n_leaf,1] 187 | decision_comp = 1-decision 188 | decision = torch.cat((decision,decision_comp),dim=2) # -> [batch_size,n_leaf,2] 189 | # for debug 190 | #decision.register_hook(debug_hook) 191 | # compute route probability 192 | # note: decision[:,0] is never used; splitting nodes are indexed from 1 193 | # save some intermediate results for analysis 194 | if save_flag: 195 | cache['decision'] = decision[:,:,0] 196 | batch_size = x.size()[0] 197 | 198 | mu = x.data.new(batch_size,1,1).fill_(1.) 199 | begin_idx = 1 200 | end_idx = 2 201 | for n_layer in range(0, self.depth): 202 | # mu stores the probability that a sample is routed to each node of the current layer 203 | # repeat it so that it can be multiplied by the left and right routing probabilities 204 | mu = mu.repeat(1, 1, 2) 205 | # the routing probability at n_layer 206 | _decision = decision[:, begin_idx:end_idx, :] # -> [batch_size,2**n_layer,2] 207 | mu = mu*_decision # -> [batch_size,2**n_layer,2] 208 | begin_idx = end_idx 209 | end_idx = begin_idx + 2 ** (n_layer+1) 210 | # merge left and right nodes to the same layer 211 | mu = mu.view(batch_size, -1, 1) 212 | 213 | mu = mu.view(batch_size, -1) 214 | if save_flag: 215 | return mu, cache 216 | else: 217 | return mu 218 | 219 | def pred(self, x): 220 | """ 221 | Predict a vector from the stored leaf vectors and the routing probabilities 222 | Args: 223 | param x (Tensor): input feature batch of size [batch_size, feature_length] 224 | Return: 225 | (Tensor): prediction [batch_size,vector_length] 226 | """ 227 | p = torch.mm(self(x), self.pi) 228 | return p 229 | 230 | def get_pi(self): 231 | return self.pi 232 | 233 | def cal_prob(self, mu, pi): 234 | """ 235 | 236 | :param mu [batch_size,n_leaf] 237 | :param pi [n_leaf,n_class] 238 | :return: label probability [batch_size,n_class] 239 | """ 240 | p = torch.mm(mu,pi) 241 | return p 242 | 243 | def update_label_distribution(self, target_batches): 244 | """ 245 | compute new leaf prediction vectors with a simple update rule inspired by
traditional regression trees 246 | Args: 247 | param target_batches (list of Tensor): target batches, each of size [batch_size, vector_length], 248 | aligned one-to-one with the routing probabilities stored in self.mu_cache 249 | """ 250 | with torch.no_grad(): 251 | new_pi = self.pi.data.new(self.n_leaf, self.vector_length).fill_(0.) # Tensor [n_leaf,n_class] 252 | 253 | for mu, target in zip(self.mu_cache, target_batches): 254 | prob = torch.mm(mu, self.pi) # [batch_size,n_class] 255 | 256 | _target = target.unsqueeze(1) # [batch_size,1,n_class] 257 | _pi = self.pi.unsqueeze(0) # [1,n_leaf,n_class] 258 | _mu = mu.unsqueeze(2) # [batch_size,n_leaf,1] 259 | _prob = torch.clamp(prob.unsqueeze(1),min=1e-6,max=1.) # [batch_size,1,n_class] 260 | 261 | _new_pi = torch.mul(torch.mul(_target,_pi),_mu)/_prob # [batch_size,n_leaf,n_class] 262 | new_pi += torch.sum(_new_pi,dim=0) 263 | new_pi = F.softmax(new_pi, dim=1).data 264 | self.pi = Parameter(new_pi, requires_grad = False) 265 | return 266 | 267 | class Forest(nn.Module): 268 | def __init__(self, n_tree, tree_depth, feature_length, vector_length, use_cuda = False): 269 | super(Forest, self).__init__() 270 | self.trees = nn.ModuleList() 271 | self.n_tree = n_tree 272 | self.tree_depth = tree_depth 273 | self.feature_length = feature_length 274 | self.vector_length = vector_length 275 | for _ in range(n_tree): 276 | tree = Tree(tree_depth, feature_length, vector_length, use_cuda) 277 | self.trees.append(tree) 278 | 279 | def forward(self, x, save_flag = False): 280 | predictions = [] 281 | cache = [] 282 | for tree in self.trees: 283 | if save_flag: 284 | mu, cache_tree = tree(x, save_flag = True) 285 | p = tree.cal_prob(mu, tree.get_pi()) 286 | cache.append(cache_tree) 287 | else: 288 | p = tree.pred(x) 289 | predictions.append(p.unsqueeze(2)) 290 | prediction = torch.cat(predictions, dim=2) 291 | prediction = torch.sum(prediction, dim=2)/self.n_tree 292 | if save_flag: 293 | return prediction, cache 294 | else: 295 | return
prediction 296 | 297 | class NeuralDecisionForest(nn.Module): 298 | def __init__(self, feature_layer, forest): 299 | super(NeuralDecisionForest, self).__init__() 300 | self.feature_layer = feature_layer 301 | self.forest = forest 302 | 303 | def forward(self, x, save_flag = False): 304 | feats = self.feature_layer(x) 305 | 306 | if save_flag: 307 | pred, cache = self.forest(feats, save_flag = True) 308 | return pred, cache, 0 309 | else: 310 | pred = self.forest(feats) 311 | return pred -------------------------------------------------------------------------------- /classification/ndf_vis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualizing the decision saliency maps for pre-trained deep neural decision forest models. 3 | """ 4 | import utils 5 | from dataset import prepare_db 6 | 7 | import torch 8 | import argparse 9 | 10 | def parse_arg(): 11 | """ 12 | argument parser. 13 | """ 14 | parser = argparse.ArgumentParser(description='ndf_vis.py') 15 | # which dataset to use: mnist, cifar10 or Nexperia 16 | parser.add_argument('-dataset', choices=['mnist', 'cifar10', 'Nexperia'], default='cifar10') 17 | # which GPU to use 18 | parser.add_argument('-gpuid', type=int, default=0) 19 | # root path for Nexperia dataset 20 | parser.add_argument('-nexperia_root', type=str, default='../data/Nexperia/semi-conductor-image-classification-first') 21 | # use all images from Nexperia dataset 22 | parser.add_argument('-train_all', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False) 23 | # if not, specify the used fraction for Nexperia dataset 24 | parser.add_argument('-train_ratio', type=float, default=0.9) 25 | return parser.parse_args() 26 | 27 | # parse arguments 28 | opt = parse_arg() 29 | 30 | # For now only GPU version is supported 31 | torch.cuda.set_device(opt.gpuid) 32 | 33 | # please place the downloaded pre-trained models in the following directory 34 | if opt.dataset == 'mnist': 35 | model_path = "../pre-trained/mnist_depth_9_tree_1_acc_0.993.pth" 36
| elif opt.dataset == 'cifar10': 37 | model_path = "../pre-trained/cifar10_depth_9_tree_1_ResNet50_acc_0.9341.pth" 38 | elif opt.dataset == 'Nexperia': 39 | model_path = "../pre-trained/nexperia_depth_9_tree_1_resnet50_acc_0.955.pth" 40 | else: 41 | raise NotImplementedError 42 | 43 | # load model 44 | model = torch.load(model_path).cuda() 45 | 46 | # prepare dataset 47 | db = prepare_db(opt) 48 | # use only the evaluation subset; use db['train'] to fetch the training subset 49 | dataset = db['eval'] 50 | 51 | # ================================================================================== 52 | # compute saliency maps for different inputs for one splitting node 53 | # pick a tree index and splitting node index 54 | # tree_idx = 0 55 | # node_idx = 0 # 0 - 510 for the 511 splitting nodes in a tree of depth 9 56 | # get saliency maps for a specified node for different input tensors 57 | # utils.get_node_saliency_map(dataset, model, tree_idx, node_idx, name=opt.dataset) 58 | # ================================================================================== 59 | 60 | # visualize for the first tree (modify to others if needed) 61 | tree_idx = 0 62 | 63 | # get the computational paths for some random inputs 64 | sample, paths, class_pred = utils.get_paths(dataset, model, tree_idx, name=opt.dataset) 65 | 66 | # for each sample, compute and plot the decision saliency map, which reflects how the input influences the 67 | # decision-making process 68 | utils.get_path_saliency(sample, paths, class_pred, model, tree_idx, name=opt.dataset) 69 | 70 | -------------------------------------------------------------------------------- /classification/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def prepare_optim(model, opt): 4 | """ 5 | prepare the optimizer for the trainable parameters of the model.
6 | args: 7 | model: the neural decision forest to be trained 8 | opt: experiment configuration object 9 | """ 10 | params = [ p for p in model.parameters() if p.requires_grad] 11 | optimizer = torch.optim.Adam(params, lr=opt.lr, weight_decay=1e-5) 12 | # MultiStepLR multiplies the learning rate by 0.3 at milestones 150 and 250 (tuned for CIFAR-10 but applied to every dataset) 13 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 14 | milestones=[150, 250], 15 | gamma=0.3) 16 | 17 | return optimizer, scheduler -------------------------------------------------------------------------------- /classification/parse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arg(): 4 | parser = argparse.ArgumentParser(description='parse.py') 5 | # choice of dataset 6 | parser.add_argument('-dataset', choices=['mnist', 'cifar10', 'Nexperia'], 7 | default='cifar10') 8 | # batch size used when training the feature extractor 9 | parser.add_argument('-batch_size', type=int, default = 256) 10 | parser.add_argument('-feat_dropout', type=float, default = 0) 11 | # number of trees in the forest 12 | parser.add_argument('-n_tree', type=int, default=1) 13 | # tree depth 14 | parser.add_argument('-tree_depth', type=int, default=9) 15 | # # number of classes for the dataset 16 | # parser.add_argument('-n_class', type=int, default=10) 17 | parser.add_argument('-tree_feature_rate', type=float, default = 1) 18 | # learning rate 19 | parser.add_argument('-lr', type=float, default=0.001, help="sgd: 10, adam: 0.001") 20 | # choice of GPU 21 | parser.add_argument('-gpuid', type=int, default=0) 22 | # total number of training epochs 23 | parser.add_argument('-epochs', type=int, default=350) 24 | # log every N batches 25 | parser.add_argument('-report_every', type=int, default=20) 26 | parser.add_argument('-eval_metric', type=str, default='accuracy') 27 | # whether to save the trained model 28 | parser.add_argument('-save', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True) 29 | # path to save the trained model 30
| parser.add_argument('-save_dir', type=str, default='../pre-trained') 31 | # root path for Nexperia dataset 32 | parser.add_argument('-nexperia_root', type=str, default='../data/Nexperia/semi-conductor-image-classification-first') 33 | # use all images from Nexperia dataset for training 34 | parser.add_argument('-train_all', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False) 35 | # if not, specify the used training fraction for Nexperia dataset 36 | parser.add_argument('-train_ratio', type=float, default=0.9) 37 | # network architecture to use 38 | parser.add_argument('-model_type', type=str, default='resnet18') 39 | parser.add_argument('-init_feat_map_num', type=int, default=64) 40 | # batch size used when updating the leaf node prediction vectors 41 | parser.add_argument('-label_batch_size', type=int, default= 2240) 42 | # representation length after the last FC layer 43 | parser.add_argument('-feature_length', type=int, default=1024) 44 | # number of worker processes for data loading 45 | parser.add_argument('-num_worker', type=int, default=4) 46 | opt = parser.parse_args() 47 | return opt -------------------------------------------------------------------------------- /classification/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition.
arXiv:1512.03385 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 25 | nn.BatchNorm2d(self.expansion*planes) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(self.conv1(x))) 30 | out = self.bn2(self.conv2(out)) 31 | out += self.shortcut(x) 32 | out = F.relu(out) 33 | return out 34 | 35 | 36 | class Bottleneck(nn.Module): 37 | expansion = 4 38 | 39 | def __init__(self, in_planes, planes, stride=1): 40 | super(Bottleneck, self).__init__() 41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != self.expansion*planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 52 | nn.BatchNorm2d(self.expansion*planes) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(self.conv1(x))) 57 | out = F.relu(self.bn2(self.conv2(out))) 58 | out = self.bn3(self.conv3(out)) 59 | out += 
self.shortcut(x) 60 | out = F.relu(out) 61 | return out 62 | 63 | 64 | class ResNet(nn.Module): 65 | def __init__(self, block, num_blocks, num_classes=10, rgb=True): 66 | super(ResNet, self).__init__() 67 | self.in_planes = 64 68 | self.input_channels = 3 if rgb else 1 69 | self.conv1 = nn.Conv2d(self.input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(64) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = F.relu(self.bn1(self.conv1(x))) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def ResNet18(num_output, rgb=True): 98 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_output, rgb=rgb) 99 | 100 | def ResNet34(num_output, rgb=True): 101 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_output, rgb=rgb) 102 | 103 | def ResNet50(num_output, rgb=True): 104 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_output, rgb=rgb) 105 | 106 | def ResNet101(num_output, rgb=True): 107 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_output, rgb=rgb) 108 | 109 | def ResNet152(num_output, rgb=True): 110 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_output, rgb=rgb) 111 | 112 | 113 
| def test(): 114 | net = ResNet18() 115 | y = net(torch.randn(1,3,32,32)) 116 | print(y.size()) -------------------------------------------------------------------------------- /classification/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os 5 | import logging 6 | from sklearn.metrics import roc_auc_score 7 | # minimum float number 8 | FLT_MIN = float(np.finfo(np.float32).eps) 9 | 10 | 11 | def prepare_batches(model, dataset, num_of_batches, opt): 12 | """ 13 | prepare some feature vectors for leaf node update. 14 | args: 15 | model: the neural decison forest to be trained 16 | dataset: the used dataset 17 | num_of_batches: total number of batches to prepare 18 | opt: experiment configuration object 19 | return: target vectors used for leaf node update 20 | """ 21 | cls_onehot = torch.eye(opt.n_class) 22 | target_batches = [] 23 | with torch.no_grad(): 24 | # the features are prepared from the feature layer 25 | train_loader = torch.utils.data.DataLoader(dataset, 26 | batch_size = opt.batch_size, 27 | shuffle = True) 28 | 29 | for batch_idx, (data, target) in enumerate(train_loader): 30 | if batch_idx == num_of_batches: 31 | # enough batches 32 | break 33 | if opt.cuda: 34 | # move tensors to GPU if needed 35 | data, target, cls_onehot = data.cuda(), target.cuda(), \ 36 | cls_onehot.cuda() 37 | # get the feature vectors 38 | feats = model.feature_layer(data) 39 | # release some memory 40 | del data 41 | feats = feats.view(feats.size()[0],-1) 42 | for tree in model.forest.trees: 43 | # compute routing probability for each tree and cache them 44 | mu = tree(feats) 45 | mu += FLT_MIN 46 | tree.mu_cache.append(mu) 47 | del feats 48 | target_batches.append(cls_onehot[target]) 49 | return target_batches 50 | 51 | def evaluate(model, dataset, opt): 52 | """ 53 | evaluate the neural decison forest. 
54 |     args:
55 |         dataset: the evaluation dataset
56 |         opt: experiment configuration object
57 |     return:
58 |         record: evaluation statistics
59 |     """
60 |     # set the model in evaluation mode
61 |     model.eval()
62 |     # average evaluation loss
63 |     test_loss = 0.0
64 |     # total correct predictions
65 |     correct = 0
66 |     # used for calculating AUC of ROC
67 |     y_true = []
68 |     y_score = []
69 |     test_loader = torch.utils.data.DataLoader(dataset,
70 |                                               batch_size = opt.batch_size,
71 |                                               shuffle = False)
72 |     for data, target in test_loader:
73 |         with torch.no_grad():
74 |             if opt.cuda:
75 |                 data, target = data.cuda(), target.cuda()
76 |             # get the output vector
77 |             output = model(data)
78 |             # loss function
79 |             test_loss += F.nll_loss(torch.log(output), target, reduction='sum').data.item() # sum up batch loss
80 |             # get class prediction
81 |             pred = output.data.max(1, keepdim = True)[1] # index of the largest class probability
82 |             # count correct prediction
83 |             correct += pred.eq(target.data.view_as(pred)).cpu().sum()
84 |             if opt.eval_metric == "AUC":
85 |                 y_true.append(target.data.cpu().numpy())
86 |                 y_score.append(output.data.cpu().numpy()[:,1])
87 |     test_loss /= len(test_loader.dataset)
88 |     test_acc = int(correct) / len(dataset)
89 |     # get AUC of ROC curve
90 |     if opt.eval_metric == "AUC":
91 |         y_true = np.concatenate(y_true, axis=0)
92 |         y_score = np.concatenate(y_score, axis=0)
93 |         auc = roc_auc_score(y_true, y_score)
94 |     else:
95 |         auc = None
96 |     record = {'loss':test_loss, 'accuracy':test_acc, 
97 |               'correct number':int(correct), 'AUC':auc}
98 |     return record
99 | 
100 | def inference(model, dataset, opt, save=True):
101 |     if dataset.name not in ['Nexperia']:
102 |         raise NotImplementedError
103 |     model.eval()
104 |     all_preds = []
105 |     test_loader = torch.utils.data.DataLoader(dataset,
106 |                                               batch_size = opt.batch_size,
107 |                                               shuffle = False)
108 |     for data, target in test_loader:
109 |         with torch.no_grad():
110 |             if opt.cuda:
111 |                 data, target = data.cuda(), target.cuda()
112 |             output = model(data)
113 |             pred = output.data.cpu().numpy()
114 |             if dataset.name == 'Nexperia':
115 |                 pred = list(pred[:,1])
116 |                 #pred = list(np.argmax(pred, axis=1))
117 |             all_preds += pred
118 |     dataset.write_preds(all_preds)
119 |     return
120 | 
121 | def get_loss(output, target, dataset_name, loss_type='nll', reweight=True):
122 |     if loss_type == 'nll' and reweight and dataset_name == 'Nexperia':
123 |         # weight class 1 ('bad') 9 times more than class 0 ('good') to counter class imbalance
124 |         weight = torch.Tensor([1, 9])
125 |         weight = weight.to(target.device)
126 |         loss = F.nll_loss(torch.log(output), target, weight=weight)
127 |     elif loss_type == 'nll':
128 |         loss = F.nll_loss(torch.log(output), target)
129 |     else:
130 |         raise NotImplementedError
131 |     return loss
132 | 
133 | def report(eval_record):
134 |     logging.info('Evaluation summary:')
135 |     for key in eval_record.keys():
136 |         if eval_record[key] is not None:
137 |             logging.info("{:s}: {:.3f}".format(key, eval_record[key]))
138 |     return
139 | 
140 | def metric_init(metric):
141 |     if metric in ['accuracy', 'AUC']:
142 |         value = 0.0
143 |     else:
144 |         raise NotImplementedError
145 |     return value
146 | 
147 | def metric_comparison(current, best, metric):
148 |     if metric in ['accuracy', 'AUC']:
149 |         flag = current > best
150 |     else:
151 |         raise NotImplementedError
152 |     return flag
153 | 
154 | def train(model, optim, sche, db, opt):
155 |     """
156 |     model training function.
157 |     args:
158 |         model: the neural decision forest to be trained
159 |         optim: the optimizer
160 |         sche: learning rate scheduler
161 |         db: dataset object
162 |         opt: experiment configuration object
163 |     return:
164 |         best_eval_metric: best value of the evaluation metric (opt.eval_metric)
165 |     """
166 |     # some initialization
167 |     iteration_num = 0
168 |     # number of batches to use for leaf node update
169 |     num_of_batches = int(opt.label_batch_size/opt.batch_size)
170 |     # number of images
171 |     num_train = len(db['train'])
172 |     # best evaluation metric
173 |     best_eval_metric = metric_init(opt.eval_metric)
174 |     # start training
175 |     for epoch in range(1, opt.epochs + 1):
176 |         # the learning rate scheduler is stepped at the end of each epoch,
177 |         # after optimizer.step(), which is the order PyTorch >= 1.1 expects
178 | 
179 |         # Update leaf node prediction vector
180 |         logging.info("Epoch %d : update leaf node distribution"%(epoch))
181 | 
182 |         # prepare feature vectors for leaf node update
183 |         target_batches = prepare_batches(model, db['train'], 
184 |                                          num_of_batches, opt)
185 | 
186 |         # update leaf node prediction vectors for every tree
187 |         for tree in model.forest.trees:
188 |             for _ in range(20):
189 |                 tree.update_label_distribution(target_batches)
190 |             # clear the cache for routing probabilities
191 |             del tree.mu_cache
192 |             tree.mu_cache = []
193 | 
194 |         # optimize decision functions
195 |         model.train()
196 |         train_loader = torch.utils.data.DataLoader(db['train'],
197 |                                                    batch_size=opt.batch_size, 
198 |                                                    shuffle=True)
199 |         for batch_idx, (data, target) in enumerate(train_loader):
200 |             if opt.cuda:
201 |                 # move tensors to GPU
202 |                 with torch.no_grad():
203 |                     data, target = data.cuda(), target.cuda()
204 |             iteration_num += 1
205 |             optim.zero_grad()
206 |             output = model(data)
207 |             output = output.clamp(min=1e-6, max=1) # keep probabilities away from 0 so log() stays finite
208 |             # loss function
209 |             loss = get_loss(output, target, opt.dataset)
210 |             # compute gradients
211 |             loss.backward()
212 |             # update network parameters
213 |             optim.step()
214 |             # logging
215 |             if batch_idx % opt.report_every == 0:
216 |                 logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(\
217 |                     epoch, batch_idx * len(data), num_train,\
218 |                     100. * batch_idx / len(train_loader), loss.data.item()))
219 | 
220 |         # Evaluate after every epoch
221 |         eval_record = evaluate(model, db['eval'], opt)
222 |         if metric_comparison(eval_record[opt.eval_metric], best_eval_metric, opt.eval_metric):
223 |             best_eval_metric = eval_record[opt.eval_metric]
224 |             best_eval_acc = eval_record["accuracy"]
225 |             # save prediction results for Nexperia testing set
226 |             if opt.save and opt.dataset == "Nexperia":
227 |                 inference(model, db['test'], opt)
228 |             # save a snapshot of model when hitting a higher score
229 |             if opt.save:
230 |                 save_path = os.path.join(opt.save_dir, 
231 |                                'depth_' + str(opt.tree_depth) + 
232 |                                'n_tree' + str(opt.n_tree) + \
233 |                                'archi_type_' + opt.model_type + '_' + str(best_eval_acc) + \
234 |                                '.pth')
235 |                 if not os.path.exists(opt.save_dir):
236 |                     os.makedirs(opt.save_dir)
237 |                 torch.save(model, save_path)
238 |         report(eval_record)  # logging
239 |         sche.step()  # update learning rate by the scheduler
240 |     return best_eval_metric
--------------------------------------------------------------------------------
/classification/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | some utility functions for data processing and visualization.
3 | """
4 | import matplotlib.pyplot as plt
5 | from matplotlib.patches import ConnectionPatch
6 | import numpy as np
7 | import torch
8 | 
9 | # class names for each supported dataset
10 | cifar10_class_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 
11 |                       'frog', 'horse', 'ship', 'truck']
12 | mnist_class_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
13 | nexperia_class_name = ['good', 'bad']
14 | class_names = {'mnist': mnist_class_name,
15 |                'cifar10': cifar10_class_name,
16 |                'Nexperia': nexperia_class_name}
17 | def show_data(dataset, name):
18 |     """
19 |     show some images from the dataset.
20 |     args:
21 |         dataset: dataset to show
22 |         name: name of the dataset
23 |     """
24 |     if name == 'mnist':
25 |         num_test = len(dataset)
26 |         num_shown = 100
27 |         cols = 10
28 |         rows = int(num_shown/cols)
29 |         indices = np.random.choice(num_test, num_shown, replace=False)
30 |         plt.figure()
31 |         for i in range(num_shown):
32 |             plt.subplot(rows, cols, i+1)
33 |             plt.imshow(dataset[indices[i]][0].squeeze().numpy())
34 |             plt.axis('off')
35 |             plt.title(str(int(dataset[indices[i]][1])))
36 |         plt.gcf().tight_layout()
37 |         plt.show()
38 |     else:
39 |         raise NotImplementedError
40 |     return
41 | 
42 | def get_sample(dataset, sample_num, name):
43 |     """
44 |     get a batch of random images from the dataset
45 |     args:
46 |         dataset: Pytorch dataset object to use
47 |         sample_num: number of samples to draw
48 |         name: name of the dataset
49 |     return:
50 |         selected sample tensor
51 |     """
52 |     # random seed (uncomment for reproducible sampling)
53 |     #np.random.seed(2019)
54 |     # get random indices
55 |     indices = np.random.choice(list(range(len(dataset))), sample_num)
56 |     if name in ['mnist', 'cifar10', 'Nexperia']:
57 |         # for MNIST, CIFAR-10 and Nexperia datasets
58 |         sample = [dataset[indices[i]][0].unsqueeze(0) for i in range(len(indices))]
59 |         # concatenate the samples as one tensor
60 |         sample = torch.cat(sample, dim = 0)
61 |     else:
62 |         raise ValueError
63 |     return sample
64 | 
65 | def revert_preprocessing(data_tensor, name):
66 |     """
67 |     unnormalize the data tensor by multiplying the standard deviation and adding the mean.
68 |     args:
69 |         data_tensor: input data tensor
70 |         name: name of the dataset
71 |     return:
72 |         data_tensor: unnormalized data tensor
73 |     """
74 |     if name == 'mnist':
75 |         data_tensor = data_tensor*0.3081 + 0.1307
76 |     elif name == 'cifar10':
77 |         std = torch.tensor([0.2023, 0.1994, 0.2010], device=data_tensor.device).view(1, 3, 1, 1)
78 |         mean = torch.tensor([0.4914, 0.4822, 0.4465], device=data_tensor.device).view(1, 3, 1, 1)
79 |         data_tensor = data_tensor*std + mean  # out-of-place: the caller's tensor is left unchanged
80 |     elif name == 'Nexperia':
81 |         data_tensor = data_tensor*0.1657 + 0.2484
82 |     else:
83 |         raise NotImplementedError
84 |     return data_tensor
85 | 
86 | def normalize(gradient, name):
87 |     """
88 |     normalize the gradient to a 0 to 1 range for display
89 |     args:
90 |         gradient: input gradient tensor
91 |         name: name of the dataset
92 |     return:
93 |         gradient: normalized gradient tensor
94 |     """
95 |     if name in ['mnist', 'Nexperia']:
96 |         pass
97 |     elif name == 'cifar10':
98 |         # take the maximum gradient from the 3 channels
99 |         gradient = (gradient.max(dim=1)[0]).unsqueeze(dim=1)
100 |     # get the per-sample maximum and minimum gradient
101 |     max_gradient = torch.max(gradient.view(len(gradient), -1), dim=1)[0]
102 |     max_gradient = max_gradient.view(len(gradient), 1, 1, 1)
103 |     min_gradient = torch.min(gradient.view(len(gradient), -1), dim=1)[0]
104 |     min_gradient = min_gradient.view(len(gradient), 1, 1, 1)
105 |     # do normalization
106 |     gradient = (gradient - min_gradient)/(max_gradient - min_gradient + 1e-3)
107 |     return gradient
108 | 
109 | def trace(record):
110 |     """
111 |     get a likely path for the input by greedily following the more probable child at every splitting node.
112 |     For each splitting node along the path the probability of arriving at it is also computed.
113 |     args:
114 |         record: record of the routing probabilities of the splitting nodes
115 |     return:
116 |         path: the traced computational path as (node index, arrival probability) pairs
117 |     """
118 |     path = []
119 |     # probability of arriving at the root node is just 1
120 |     prob = 1
121 |     # the starting index
122 |     node_idx = 1
123 |     while node_idx < len(record):
124 |         path.append((node_idx, prob))
125 |         # find the child node with larger visiting probability
126 |         if record[node_idx] >= 0.5:
127 |             prob *= record[node_idx]
128 |             # go to left sub-tree
129 |             node_idx = node_idx*2
130 |         else:
131 |             prob *= 1 - record[node_idx]
132 |             # go to right sub-tree
133 |             node_idx = node_idx*2 + 1
134 |     return path
135 | 
136 | def get_paths(dataset, model, tree_idx, name):
137 |     """
138 |     compute the computational paths for the input tensors
139 |     args:
140 |         dataset: Pytorch dataset object
141 |         model: pre-trained deep neural decision forest for visualizing
142 |         tree_idx: which tree to use if there are multiple trees in the forest.
143 |         name: name of the dataset
144 |     return:
145 |         sample: randomly drawn sample
146 |         paths: computational paths for the samples
147 |         class_pred: model predictions for the samples
148 |     """
149 |     sample_num = 5
150 |     # get some random input images
151 |     sample = get_sample(dataset, sample_num, name)
152 |     # forward pass to get the routing probability
153 |     pred, cache, _ = model(sample.cuda(), save_flag = True)
154 |     class_pred = pred.max(dim=1)[1]
155 |     # take the routing probabilities of the tree selected by tree_idx
156 |     # please refer to ndf.py if you are interested in how the forward pass is implemented
157 |     decision = cache[tree_idx]['decision'].data.cpu().numpy()
158 |     paths = []
159 |     # trace the computational path for every input image
160 |     for sample_idx in range(len(decision)):
161 |         paths.append(trace(decision[sample_idx, :]))
162 |     return sample, paths, class_pred
163 | 
164 | def get_node_saliency_map(dataset, model, tree_idx, node_idx, name):
165 |     """
166 |     get decision saliency maps for one specific splitting node
167 |     args:
168 |         dataset: Pytorch dataset object
169 |         model: pre-trained neural decision forest to visualize
170 |         tree_idx: index of the tree
171 |         node_idx: index of the splitting node
172 |         name: name of the dataset
173 |     return:
174 |         gradient: computed decision saliency maps
175 |     """
176 |     # pick some samples from the dataset
177 |     sample_num = 5
178 |     sample = get_sample(dataset, sample_num, name)
179 |     # For now only GPU code is supported
180 |     sample = sample.cuda()
181 |     # enable gradient computation (the input tensor will require gradients in the backward computational graph)
182 |     sample.requires_grad = True
183 |     # get the feature vectors for the drawn samples
184 |     feats = model.feature_layer(sample)
185 |     # using_idx gives the indices of the neurons in the last FC layer that are used to compute routing probabilities
186 |     using_idx = model.forest.trees[tree_idx].using_idx[node_idx + 1]
187 |     # for sample_idx in range(len(feats)):
188 |     #     feats[sample_idx, using_idx].backward(retain_graph=True)
189 |     # equivalent to the commented-out loop above
190 |     feats[:, using_idx].sum(dim = 0).backward()
191 |     # get the gradient data
192 |     gradient = sample.grad.data
193 |     # get the magnitude
194 |     gradient = torch.abs(gradient)
195 |     # normalize the gradient for visualizing
196 |     gradient = normalize(gradient, name)
197 |     # plot the input data and their corresponding decision saliency maps
198 |     plt.figure()
199 |     # unnormalize the images for display (detached, since they only need to be plotted)
200 |     sample = revert_preprocessing(sample.detach(), name)
201 |     # plot for every input image
202 |     for sample_idx in range(sample_num):
203 |         plt.subplot(2, sample_num, sample_idx + 1)
204 |         sample_to_show = sample[sample_idx].squeeze().data.cpu().numpy()
205 |         if name == 'cifar10':
206 |             # re-order the channels
207 |             sample_to_show = sample_to_show.transpose((1,2,0))
208 |             plt.imshow(sample_to_show)
209 |         elif name in ['mnist', 'Nexperia']:
210 |             plt.imshow(sample_to_show, cmap='gray')
211 |         else:
212 |             raise NotImplementedError
213 |         plt.subplot(2, sample_num, sample_idx + 1 + sample_num)
214 |         plt.imshow(gradient[sample_idx].squeeze().cpu().numpy())
215 |         plt.axis('off')
216 |     plt.show()
217 |     return gradient
218 | 
219 | def get_map(model, sample, node_idx, tree_idx, name):
220 |     """
221 |     helper function for computing the saliency map for a specified sample and splitting node
222 |     args:
223 |         model: pre-trained neural decision forest to visualize
224 |         sample: input image tensors
225 |         node_idx: index of the splitting node
226 |         tree_idx: index of the decision tree
227 |         name: name of the dataset
228 |     return:
229 |         saliency_map: computed decision saliency map
230 |     """
231 |     # move to GPU
232 |     sample = sample.unsqueeze(dim=0).cuda()
233 |     # enable gradient computation for the input tensor
234 |     sample.requires_grad = True
235 |     # get feature vectors of the input samples
236 |     feat = model.feature_layer(sample)
237 |     # using_idx gives the indices of the neurons in the last FC layer that are used to compute routing probabilities
238 |     using_idx = model.forest.trees[tree_idx].using_idx[node_idx]
239 |     # compute gradient by a backward pass
240 |     feat[:, using_idx].backward()
241 |     # get the gradient data
242 |     gradient = sample.grad.data
243 |     # normalize the gradient
244 |     gradient = normalize(torch.abs(gradient), name)
245 |     saliency_map = gradient.squeeze().cpu().numpy()
246 |     return saliency_map
247 | 
248 | def get_path_saliency(samples, paths, class_pred, model, tree_idx, name, orientation = 'horizontal'):
249 |     """
250 |     show the saliency maps for the input samples with their pre-computed computational paths
251 |     args:
252 |         samples: input image tensor
253 |         paths: pre-computed computational paths for the inputs
254 |         class_pred: model predictions for the inputs
255 |         model: pre-trained neural decision forest
256 |         tree_idx: index of the decision tree
257 |         name: name of the dataset
258 |         orientation: layout of the figure
259 |     """
260 |     #plt.ioff()
261 |     # plotting parameters
262 |     plt.figure(figsize=(20,5))
263 |     plt.rcParams.update({'font.size': 12})
264 |     # number of input samples
265 |     num_samples = len(samples)
266 |     # length of the computational path
267 |     path_length = len(paths[0])
268 |     # iterate for every input sample
269 |     for sample_idx in range(num_samples):
270 |         sample = samples[sample_idx]
271 |         # plot the sample
272 |         plt.subplot(num_samples, path_length + 1, sample_idx*(path_length + 1) + 1)
273 |         # unnormalize a copy of the input, so the sample used for the saliency maps stays normalized
274 |         sample_to_plot = revert_preprocessing(sample.unsqueeze(dim=0).clone(), name)
275 |         if name in ['mnist', 'Nexperia']:
276 |             plt.imshow(sample_to_plot.squeeze().cpu().numpy(), cmap='gray')
277 |         else:
278 |             plt.imshow(sample_to_plot.squeeze().cpu().numpy().transpose((1,2,0)))
279 |         pred_class_name = class_names[name][int(class_pred[sample_idx])]
280 |         plt.axis('off')
281 |         plt.title('Pred:{:s}'.format(pred_class_name))
282 |         # computational path for this sample
283 |         path = paths[sample_idx]
284 |         for node_idx in range(path_length):
285 |             # compute and plot decision saliency map for each splitting node along the path
286 |             node = path[node_idx][0]
287 |             # probability of arriving at this node
288 |             prob = path[node_idx][1]
289 |             # compute the saliency map
290 |             saliency_map = get_map(model, sample, node, tree_idx, name)
291 |             if orientation == 'horizontal':
292 |                 sub_plot_idx = sample_idx*(path_length + 1) + node_idx + 2
293 |                 plt.subplot(num_samples, path_length + 1, sub_plot_idx)
294 |             elif orientation == 'vertical':
295 |                 raise NotImplementedError
296 |             else:
297 |                 raise NotImplementedError
298 |             plt.imshow(saliency_map)
299 |             plt.title('(N{:d}, P{:.2f})'.format(node, prob))
300 |             plt.axis('off')
301 |     # draw some arrows
302 |     for arrow_idx in range(num_samples*(path_length + 1) - 1):
303 |         if (arrow_idx+1) % (path_length+1) == 0 and arrow_idx != 0:
304 |             continue
305 |         ax1 = plt.subplot(num_samples, path_length + 1, arrow_idx + 1)
306 |         ax2 = plt.subplot(num_samples, path_length + 1, arrow_idx + 2)
307 |         arrow = ConnectionPatch(xyA=[1.1,0.5], xyB=[-0.1, 0.5], coordsA='axes fraction', coordsB='axes fraction',
308 |                                 axesA=ax1, axesB=ax2, arrowstyle="fancy")
309 |         ax1.add_artist(arrow)
310 |     left = 0      # the left side of the subplots of the figure
311 |     right = 1     # the right side of the subplots of the figure
312 |     bottom = 0.01 # the bottom of the subplots of the figure
313 |     top = 0.95    # the top of the subplots of the figure
314 |     wspace = 0.0  # the amount of width reserved for space between subplots,
315 |                   # expressed as a fraction of the average axis width
316 |     hspace = 0.4  # the amount of height reserved for space between subplots,
317 |                   # expressed as a fraction of the average axis height
318 |     plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
319 |     plt.show()
320 |     # to save the figure, uncomment the line below and move it before plt.show()
321 |     #plt.savefig('saved_fig.png',dpi=1200)
322 |     return
323 | 
--------------------------------------------------------------------------------
/data/CACD_split/place meta-data here.txt:
--------------------------------------------------------------------------------
1 | place the CACD meta-data (train.npy, valid.npy, test.npy) in this directory.
2 | 
--------------------------------------------------------------------------------
/data/Nexperia/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/data/Nexperia/__init__.py
--------------------------------------------------------------------------------
/data/Nexperia/dataset.py:
--------------------------------------------------------------------------------
1 | # Nexperia Pytorch dataloader
2 | import numpy as np
3 | import os
4 | import torch
5 | import torch.utils.data
6 | import imageio
7 | import logging
8 | import csv
9 | 
10 | image_extension = ".jpg"
11 | 
12 | class NexperiaDataset(torch.utils.data.Dataset):
13 |     def __init__(self, root, paths, imgs, labels=None, split=None, mean=None,
14 |                  std=None):
15 |         self.root = root
16 |         self.paths = paths
17 |         self.names = [path.split(os.sep)[-1][:-len(image_extension)] for path in paths]
18 |         self.imgs = imgs
19 |         if len(self.imgs.shape) == 3:
20 |             self.imgs = np.expand_dims(self.imgs, axis=1)
21 |         self.labels = labels
22 |         self.split = split
23 |         self.name = 'Nexperia'
24 |         logging.info('{:s} {:s} set contains {:d} images'.format(self.name,
25 |                      self.split, len(self.paths)))
26 |         self.mean, self.std = self.get_stats(mean, std)
27 |         self.normalize(self.mean, self.std)
28 |     def __len__(self):
29 |         return len(self.paths)
30 | 
31 |     def __getitem__(self, idx):
32 |         return torch.from_numpy(self.imgs[idx]), self.labels[idx]
33 | 
34 |     def get_stats(self, mean=None, std=None, verbose=True):
35 |         if mean is not None and std is not None:
36 |             return mean, std
37 |         # get normalization statistics
38 |         if verbose:
39 |             logging.info("Calculating normalizing statistics...")
40 |         self.mean = np.mean(self.imgs)
41 |         self.std = np.std(self.imgs)
42 |         if verbose:
43 |             logging.info("Calculation done for {:s} {:s} set.".format(self.name,
44 |                          self.split))
45 |         return self.mean, self.std
46 | 
47 |     def normalize(self, mean, std, verbose=True):
48 |         if verbose:
49 |             logging.info("Normalizing images...")
50 |         self.imgs = (self.imgs - mean)/std
51 |         if verbose:
52 |             logging.info("Normalization done for {:s} {:s} set.".format(self.name,
53 |                          self.split))
54 |         return
55 | 
56 |     def visualize(self, count=3):
57 |         for idx in range(1, count+1):
58 |             visualize_grid(imgs = self.imgs, labels=self.labels, title=self.split + str(idx))
59 |         return
60 | 
61 |     def write_preds(self, preds):
62 |         input_file = os.path.join(self.root, "template.csv")
63 |         assert os.path.exists(input_file), "Please download the submission template."
64 |         output_file = os.path.join(self.root, "submission.csv")
65 |         save_csv(input_file, output_file, self.names, preds)
66 |         np.save(os.path.join(self.root, 'submission.npy'), {'path':self.names, 'pred':preds})
67 |         return
68 | 
69 | def save_csv(input_file, output_file, test_list, test_labels):
70 |     """
71 |     save the test-set predictions as a csv file that can be submitted to the Kaggle competition
72 |     """
73 |     assert len(test_list) == len(test_labels)
74 |     with open(input_file) as csv_file:
75 |         with open(output_file, mode='w', newline='') as out_csv:
76 |             csv_reader = csv.reader(csv_file, delimiter=',')
77 |             csv_writer = csv.writer(out_csv)
78 |             line_count = 0
79 |             for row in csv_reader:
80 |                 if line_count == 0:
81 |                     # the first row is the header: copy it unchanged
82 |                     csv_writer.writerow(row)
83 |                     line_count += 1
84 |                 else:
85 |                     # row[0] is an image name listed in the submission template
86 |                     image_name = row[0]
87 |                     assert image_name in test_list, 'Missing prediction!'
88 |                     index = test_list.index(image_name)
89 |                     label = test_labels[index]
90 |                     csv_writer.writerow([image_name, str(label)])
91 |                     line_count += 1
92 |     logging.info('Saved prediction. Processed {:d} lines.'.format(line_count))
93 |     return
94 | 
95 | def visualize_grid(imgs, nrows=5, ncols=5, labels = None, title=""):
96 |     """
97 |     imgs: collection of images that supports indexing
98 |     """
99 |     import matplotlib.pyplot as plt
100 |     assert nrows*ncols <= len(imgs), 'Not enough images'
101 |     # chosen indices
102 |     cis = np.random.choice(len(imgs), nrows*ncols, replace=False)
103 |     fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
104 |     fig.suptitle(title)
105 |     for row_idx in range(nrows):
106 |         for col_idx in range(ncols):
107 |             idx = row_idx*ncols + col_idx
108 |             axes[row_idx][col_idx].imshow(imgs[cis[idx]].squeeze(), cmap='gray')
109 |             axes[row_idx][col_idx].set_axis_off()
110 |             if labels is not None:
111 |                 axes[row_idx][col_idx].set_title(str(labels[cis[idx]]))
112 |     plt.show()
113 |     return
114 | 
115 | def load_data(folders):
116 |     lgood = 0
117 |     lbad = 1
118 |     ltest = -1
119 |     paths = []
120 |     imgs = []
121 |     labels = []
122 |     for folder in folders:
123 |         if 'good' in folder:
124 |             label = lgood
125 |         elif 'bad' in folder:
126 |             label = lbad
127 |         else:
128 |             label = ltest
129 |         for filename in os.listdir(folder):
130 |             filepath = os.path.join(folder, filename)
131 |             if filename.endswith(image_extension):
132 |                 paths.append(filepath)
133 |                 img = imageio.imread(filepath)
134 |                 img = img.astype('float32') / 255.
135 |                 imgs.append(img)
136 |                 labels.append(label)
137 |     return np.array(paths), np.array(imgs), np.array(labels)
138 | 
139 | def get_datasets(opt, visualize=False):
140 |     root = opt.nexperia_root
141 |     train_ratio = opt.train_ratio
142 |     dirs = {}
143 |     dirs['good'] = os.path.join(root, 'train/good_0')
144 |     dirs['bad'] = os.path.join(root, 'train/bad_1')
145 |     dirs['test'] = os.path.join(root, 'test/all_tests')
146 |     train_paths, train_imgs, train_lbs = load_data([dirs['good'], dirs['bad']])
147 |     test_paths, test_imgs, test_lbs = load_data([dirs['test']])
148 |     # split the labeled data into training and evaluation set
149 |     ntu = num_train_used = int(len(train_paths)*train_ratio)
150 |     cis = chosen_indices = np.random.choice(len(train_paths), len(train_paths), replace=False)
151 |     used_paths, used_imgs, used_lbs = train_paths[cis[:ntu]], train_imgs[cis[:ntu]], train_lbs[cis[:ntu]]
152 |     eval_paths, eval_imgs, eval_lbs = train_paths[cis[ntu:]], train_imgs[cis[ntu:]], train_lbs[cis[ntu:]]
153 |     if opt.train_all:
154 |         train_set = NexperiaDataset(root, train_paths, train_imgs, train_lbs, 'train')
155 |     else:
156 |         train_set = NexperiaDataset(root, used_paths, used_imgs, used_lbs, 'train')
157 |     eval_set = NexperiaDataset(root, eval_paths, eval_imgs, eval_lbs, 'eval', 
158 |                                mean=train_set.mean, std=train_set.std)
159 |     test_set = NexperiaDataset(root, test_paths, test_imgs, test_lbs, 'test', 
160 |                                mean=train_set.mean, std=train_set.std)
161 |     if visualize:
162 |         # visualize the images with annotation
163 |         train_set.visualize()
164 |         eval_set.visualize()
165 |     return {'train':train_set, 'eval':eval_set, 'test':test_set}
--------------------------------------------------------------------------------
/docs/supplementary material.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/docs/supplementary material.pdf
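`get_datasets` in data/Nexperia/dataset.py works in three steps. It shuffles the labelled images and keeps the first `train_ratio` fraction for training. It then normalizes the training part with its own scalar mean and standard deviation. Finally, it passes those same statistics to the evaluation and test sets, so their pixels are never used to compute the preprocessing. The sketch below reproduces that logic on plain NumPy arrays for experimentation. It is not the repository's code: `split_and_normalize` and its `seed` argument are hypothetical names, and it uses `numpy.random.default_rng` instead of the global `np.random.choice` call shown above.

```python
import numpy as np


def split_and_normalize(imgs, labels, train_ratio=0.9, seed=0):
    """Hypothetical helper mirroring get_datasets: shuffle, split by ratio,
    then normalize both parts with statistics of the training part only."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(imgs))          # random order of all indices
    n_train = int(len(imgs) * train_ratio)      # same truncation as get_datasets
    train_idx, eval_idx = order[:n_train], order[n_train:]
    train_imgs, eval_imgs = imgs[train_idx], imgs[eval_idx]
    # one scalar mean/std over every pixel of the training images
    mean, std = float(np.mean(train_imgs)), float(np.std(train_imgs))
    train_imgs = (train_imgs - mean) / std
    eval_imgs = (eval_imgs - mean) / std        # reuse the training statistics
    return (train_imgs, labels[train_idx]), (eval_imgs, labels[eval_idx]), (mean, std)


if __name__ == "__main__":
    imgs = np.random.rand(20, 1, 8, 8).astype("float32")
    labels = np.arange(20) % 2
    (tr_x, tr_y), (ev_x, ev_y), (mean, std) = split_and_normalize(imgs, labels)
    print(tr_x.shape, ev_x.shape)               # (18, 1, 8, 8) (2, 1, 8, 8)
```

A fixed `seed` makes the train/eval split reproducible across runs. The global `np.random.choice` call in `get_datasets` only gives that guarantee if the global NumPy seed is set beforehand.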
-------------------------------------------------------------------------------- /pre-trained/place pre-trained models here.txt: -------------------------------------------------------------------------------- 1 | place the pre-trained models in this directory. 2 | -------------------------------------------------------------------------------- /regression/cacd_process.py: -------------------------------------------------------------------------------- 1 | # data-preprocessing for CACD 2 | # Reference: 3 | # https://github.com/shamangary/Keras-MORPH2-age-estimation/blob/master/TYY_MORPH_create_db.py 4 | import numpy as np 5 | import cv2 6 | import scipy.io 7 | import imageio as io 8 | import argparse 9 | from tqdm import tqdm 10 | import os 11 | from os import listdir 12 | from os.path import isfile, join 13 | import sys 14 | import dlib 15 | from moviepy.editor import * 16 | 17 | def get_dic(data_list, img_path): 18 | # split into different individuals 19 | data_dic = {} 20 | for idx in range(len(data_list)): 21 | file_name = data_list[idx][:-4] 22 | annotation = file_name.split('_') 23 | age = float(annotation[0]) 24 | identity = '' 25 | for i in range(1, len(annotation) - 1): 26 | identity += annotation[i] + ' ' 27 | file_path = os.path.join(img_path, data_list[idx]) 28 | assert os.path.exists(file_path), 'Image not found!' 
29 | if identity not in data_dic: 30 | temp = {'path':[file_path], 31 | 'age_list':[age]} 32 | data_dic[identity] = temp 33 | else: 34 | data_dic[identity]['path'].append(file_path) 35 | data_dic[identity]['age_list'].append(age) 36 | return data_dic 37 | 38 | def get_counts(data_dic): 39 | SUM = 0 40 | for key in data_dic: 41 | SUM += len(data_dic[key]['path']) 42 | return SUM 43 | 44 | def get_data(img_path): 45 | # pre-process the data for CACD 46 | train_list = np.load('../data/CACD_split/train.npy', allow_pickle=True) 47 | valid_list = np.load('../data/CACD_split/valid.npy', allow_pickle=True) 48 | test_list = np.load('../data/CACD_split/test.npy', allow_pickle=True) 49 | train_dic = get_dic(train_list, img_path) 50 | print('Training images: %d'%get_counts(train_dic)) 51 | valid_dic = get_dic(valid_list, img_path) 52 | print('Validation images: %d'%get_counts(valid_dic)) 53 | test_dic = get_dic(test_list, img_path) 54 | print('Testing images: %d'%get_counts(test_dic)) 55 | return train_dic, valid_dic, test_dic 56 | 57 | def warp_im(im, M, dshape): 58 | output_im = np.zeros(dshape, dtype=im.dtype) 59 | cv2.warpAffine(im, 60 | M[:2], 61 | (dshape[1], dshape[0]), 62 | dst=output_im, 63 | borderMode=cv2.BORDER_TRANSPARENT, 64 | flags=cv2.WARP_INVERSE_MAP) 65 | return output_im 66 | 67 | def transformation_from_points(points1, points2): 68 | """ 69 | Return an affine transformation [s * R | T] such that: 70 | sum ||s*R*p1,i + T - p2,i||^2 71 | is minimized. 72 | """ 73 | # Solve the procrustes problem by subtracting centroids, scaling by the 74 | # standard deviation, and then using the SVD to calculate the rotation. 
See 75 | # the following for more details: 76 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem 77 | 78 | points1 = np.matrix(points1).astype(np.float64) 79 | points2 = np.matrix(points2).astype(np.float64) 80 | 81 | c1 = np.mean(points1, axis=0) 82 | c2 = np.mean(points2, axis=0) 83 | points1 -= c1 84 | points2 -= c2 85 | 86 | s1 = np.std(points1) 87 | s2 = np.std(points2) 88 | points1 /= s1 89 | points2 /= s2 90 | 91 | U, S, Vt = np.linalg.svd(points1.T * points2) 92 | 93 | # The R we seek is in fact the transpose of the one given by U * Vt. This 94 | # is because the above formulation assumes the matrix goes on the right 95 | # (with row vectors) whereas our solution requires the matrix to be on the 96 | # left (with column vectors). 97 | R = (U * Vt).T 98 | 99 | return np.vstack([np.hstack(((s2 / s1) * R, 100 | c2.T - (s2 / s1) * R * c1.T)), 101 | np.matrix([0., 0., 1.])]) 102 | 103 | def normalize(landmarks): 104 | center = landmarks.mean(axis = 0) 105 | deviation = landmarks - center 106 | norm_fac = np.abs(deviation).max() 107 | normalized_lm = deviation/norm_fac 108 | return normalized_lm 109 | 110 | def get_landmarks(img_name, args): 111 | file_path = args.annotation + img_name[:-4] + '.landmark' 112 | annotation = open(file_path, 'r').read().splitlines() 113 | num_lm = len(annotation) 114 | landmarks = np.matrix([[float(annotation[i].split(' ')[0]), 115 | float(annotation[i].split(' ')[1])] 116 | for i in range(num_lm)]) 117 | normalized_lm = normalize(landmarks) 118 | return (landmarks, normalized_lm) 119 | 120 | 121 | def get_args(): 122 | parser = argparse.ArgumentParser(description="This script aligns CACD face images " 123 | "to a mean face using facial landmarks.", 124 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 125 | parser.add_argument("--output", "-o", type=str, 126 | help="path to the output directory for the aligned images", default='../data/CACD2000_processed') 127 | parser.add_argument("--img_size", type=int, default=256, 128 |
help="output image size") 129 | parser.add_argument("-annotation", type=str, 130 | help="path to .landmark files", default='../data/CACD_landmark/landmark/') 131 | args = parser.parse_args() 132 | return args 133 | 134 | def get_mean_face(landmark_list, img_size): 135 | SUM = normalize(landmark_list[0][0]) 136 | for i in range(1, len(landmark_list)): 137 | SUM += normalize(landmark_list[i][0]) 138 | normalized_mean_face = SUM/len(landmark_list) 139 | face_size = img_size*0.3 140 | return normalized_mean_face*face_size + 0.5*img_size 141 | 142 | def process_split_file(): 143 | file_path = '/home/nicholas/Documents/Project/DRF_Age_Estimation/data/CACD_split/' 144 | train = open(file_path + 'train.txt', 'r').read().splitlines() 145 | valid = open(file_path + 'valid.txt', 'r').read().splitlines() 146 | test = open(file_path + 'test.txt', 'r').read().splitlines() 147 | return train, valid, test 148 | 149 | def main(): 150 | args = get_args() 151 | output_path = args.output 152 | img_size = args.img_size 153 | 154 | mypath = '../data/CACD2000' 155 | isPlot = False 156 | onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] 157 | # landmark_list = [] 158 | # for i in tqdm(range(len(onlyfiles))): 159 | # landmark_list.append(get_landmarks(onlyfiles[i], args)) 160 | 161 | landmark_ref = np.matrix(np.load('../data/CACD_mean_face.npy', allow_pickle=True)) 162 | 163 | # Points used to line up the images. 
164 | ALIGN_POINTS = list(range(16)) 165 | 166 | for i in tqdm(range(len(onlyfiles))): 167 | 168 | img_name = onlyfiles[i] 169 | input_img = cv2.imread(mypath+'/'+img_name) 170 | input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB) 171 | img_h, img_w, _ = np.shape(input_img) 172 | 173 | landmark = get_landmarks(img_name, args)[0] 174 | M = transformation_from_points(landmark_ref[ALIGN_POINTS], 175 | landmark[ALIGN_POINTS]) 176 | input_img = warp_im(input_img, M, (256, 256, 3)) 177 | io.imsave(args.output +'/'+ img_name, input_img) 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /regression/data_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script prepares a pytorch dataset for facial age estimation. 3 | """ 4 | import torch 5 | import torch.utils.data 6 | import torchvision.transforms.functional as transform_f 7 | import imageio as io 8 | import numpy as np 9 | import logging 10 | import PIL 11 | 12 | class FacialAgeDataset(torch.utils.data.Dataset): 13 | def __init__(self, dictionary, opt, split): 14 | self.dict = dictionary 15 | self.name = opt.dataset_name 16 | self.split = split 17 | self.image_size = opt.image_size 18 | self.crop_size = opt.crop_size 19 | self.crop_limit = self.image_size - self.crop_size 20 | assert self.name in ['FGNET', 'CACD', 'Morph'], 'Dataset not supported yet.' 
21 | self.cache = opt.cache 22 | self.image_channel = 1 if opt.gray_scale else 3 23 | self.transform = opt.transform 24 | self.img_path_list = [] 25 | self.label = [] 26 | self.scale_factor = 100 27 | if self.name == 'FGNET': 28 | self.mean = [0.425, 0.342, 0.314] 29 | self.std = [0.218, 0.191, 0.182] 30 | for key in dictionary: 31 | if len(key) == 3: 32 | self.img_path_list += dictionary[key]['path'] 33 | self.label += dictionary[key]['age_list'] 34 | elif self.name == 'CACD': 35 | self.mean = [0.432, 0.359, 0.320] 36 | self.std = [0.30, 0.264, 0.252] 37 | for key in dictionary: 38 | self.img_path_list += dictionary[key]['path'] 39 | self.label += dictionary[key]['age_list'] 40 | elif self.name == 'Morph': 41 | self.mean = [0.564, 0.521, 0.508] 42 | self.std = [0.281, 0.255, 0.246] 43 | self.img_path_list = dictionary['path'] 44 | self.label = dictionary['age_list'] 45 | logging.info('{:s} {:s} set contains {:d} images'.format(self.name, 46 | self.split, len(self.img_path_list))) 47 | self.label = torch.FloatTensor(self.label) 48 | self.label /= self.scale_factor 49 | 50 | def __len__(self): 51 | return len(self.img_path_list) 52 | 53 | def __getitem__(self, idx): 54 | # read image from the disk 55 | image_path = self.img_path_list[idx] 56 | image = PIL.Image.open(image_path) 57 | # transformation for data augmentation 58 | if self.transform: 59 | # Use PIL and transformation provided by Pytorch 60 | if np.random.rand() > 0.5 and self.split == 'train': 61 | image = transform_f.hflip(image) 62 | # only crop if input image size is large enough 63 | if self.crop_limit > 1: 64 | # random cropping 65 | if self.split == 'train': 66 | x_start = int(self.crop_limit*np.random.rand()) 67 | y_start = int(self.crop_limit*np.random.rand()) 68 | else: 69 | # only apply central-crop for evaluation set 70 | x_start = 15 71 | y_start = 15 72 | image = transform_f.crop(image, y_start, x_start, 73 | self.crop_size, 74 | self.crop_size) 75 | image = transform_f.to_tensor(image) 76 | 
image = transform_f.normalize(image, mean=self.mean, 77 | std=self.std) 78 | sample = {'image': image, 79 | 'age': self.label[idx], 80 | 'index': idx} 81 | return sample 82 | 83 | def convert(self, img): 84 | # convert grayscale to RGB image if needed 85 | if len(img.shape) == 2: 86 | img = np.expand_dims(img, axis=2) 87 | img = np.repeat(img, 3, axis=2) 88 | return img 89 | 90 | def get_label(self, idx): 91 | # this function only returns the label 92 | return self.label[idx] 93 | 94 | def get_image(self, idx): 95 | # returns the raw image with imageio 96 | image_path = self.img_path_list[idx] 97 | image = io.imread(image_path) 98 | return image 99 | 100 | def prepare_db(opt): 101 | # Prepare a list of datasets for training and evaluation 102 | train_list = [] 103 | eval_list = [] 104 | if opt.dataset_name == "FGNET": 105 | # not released for now 106 | raise NotImplementedError 107 | elif opt.dataset_name == "CACD": 108 | # testing set 109 | eval_dic = np.load('../data/CACD_split/test_cacd_processed.npy', allow_pickle=True).item() 110 | if opt.cacd_train: 111 | # use the official training set for training 112 | train_dic = np.load('../data/CACD_split/train_cacd_processed.npy', allow_pickle=True).item() 113 | logging.info('Preparing CACD dataset (training with the training set).') 114 | else: 115 | # use the official evaluation set for training 116 | train_dic = np.load('../data/CACD_split/valid_cacd_processed.npy', allow_pickle=True).item() 117 | logging.info('Preparing CACD dataset (training with the validation set).') 118 | train_list.append(FacialAgeDataset(train_dic, opt, 'train')) 119 | eval_list.append(FacialAgeDataset(eval_dic, opt, 'eval')) 120 | return {'train':train_list, 'eval':eval_list} 121 | elif opt.dataset_name == "Morph": 122 | raise ValueError 123 | else: 124 | raise NotImplementedError 125 | -------------------------------------------------------------------------------- /regression/main.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Summary: This script trains a residual neural decision forest (RNDF) for facial age 3 | estimation (on CACD dataset). It applies the simple idea of residual learning 4 | to neural decision forest (NDF). Residual learning is widely adopted in CNN, 5 | here let's use it for NDF. 6 | Author: Nicholas Li 7 | Contact: Nicholas.li@connect.ust.hk 8 | License: MIT 9 | """ 10 | 11 | # utility functions 12 | import data_prepare # dataset preparation 13 | import model # model implementation 14 | import trainer # training functions 15 | import optimizer # optimization functions 16 | import utils # other utilities 17 | 18 | # public libraries 19 | import torch 20 | import logging 21 | import numpy as np 22 | import time 23 | 24 | def main(): 25 | # logging configuration 26 | logging.basicConfig(level=logging.INFO, 27 | format="[%(asctime)s]: %(message)s" 28 | ) 29 | 30 | # parse command line input 31 | opt = utils.parse_arg() 32 | 33 | # Set GPU 34 | opt.cuda = opt.gpuid>=0 35 | if opt.cuda: 36 | torch.cuda.set_device(opt.gpuid) 37 | else: 38 | # please use GPU for training, CPU version is not supported for now. 39 | raise NotImplementedError 40 | #logging.info("GPU acceleration is disabled.") 41 | 42 | # prepare training and validation dataset 43 | db = data_prepare.prepare_db(opt) 44 | 45 | # sanity check for FG-NET dataset, not used for now 46 | # assertion: the total images in the eval set lists should be 1002 47 | total_eval_imgs = sum([len(db['eval'][i]) for i in range(len(db['eval']))]) 48 | print(total_eval_imgs) 49 | if db['train'][0].name == 'FGNET': 50 | assert total_eval_imgs == 1002, 'The preparation of the evalset is incorrect.' 
51 | 52 | # training 53 | if opt.train: 54 | best_MAEs = [] 55 | last_MAEs = [] 56 | # append the current time to the save directory 57 | opt.save_dir += time.asctime(time.localtime(time.time())) 58 | # for FG-NET, do training multiple times for leave-one-out validation 59 | # for CACD, do training just once 60 | for exp_id in range(len(db['train'])): 61 | # initialize the model 62 | model_train = model.prepare_model(opt) 63 | 64 | # configure the optimizer and learning rate scheduler 65 | optim, sche = optimizer.prepare_optim(model_train, opt) 66 | 67 | # train the model and record the mean absolute error (MAE) 68 | model_train, MAE, last_MAE = trainer.train(model_train, optim, sche, db, opt, exp_id) 69 | best_MAEs += MAE 70 | last_MAEs.append(last_MAE.data.item()) 71 | 72 | # remove the trained model for leave-one-out validation 73 | if exp_id != len(db['train']) - 1: 74 | del model_train 75 | 76 | #np.save('./MAE.npy', np.array(best_MAEs)) 77 | #np.save('./Last_MAE.npy', np.array(last_MAEs)) 78 | # save the final trained model 79 | #utils.save_model(model_train, opt) 80 | 81 | # testing a pre-trained model 82 | elif opt.evaluate: 83 | # path to the pre-trained model 84 | save_dir = opt.test_model_path 85 | #example: save_dir = '../model/CACD_MAE_4.59.pth' 86 | model_loaded = torch.load(save_dir) 87 | # test the model on the evaluation set 88 | # the last subject is the test set (compatible with FG-NET) 89 | trainer.evaluate(model_loaded, db['eval'][-1], opt) 90 | return 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /regression/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Initialize a RNDF. 3 | """ 4 | import ndf 5 | 6 | def prepare_model(opt): 7 | # RNDF consists of two parts: 8 | #1. a feature extraction model using residual learning 9 | #2.
a neural decision forest 10 | feat_layer = ndf.FeatureLayer(model_type = opt.model_type, 11 | num_output = opt.num_output, 12 | gray_scale = opt.gray_scale, 13 | input_size = opt.image_size, 14 | pretrained = opt.pretrained) 15 | forest = ndf.Forest(opt.n_tree, opt.tree_depth, opt.num_output, 16 | 1, opt.cuda) 17 | model = ndf.NeuralDecisionForest(feat_layer, forest) 18 | if opt.cuda: 19 | model = model.cuda() 20 | else: 21 | raise NotImplementedError 22 | 23 | return model -------------------------------------------------------------------------------- /regression/ndf.py: -------------------------------------------------------------------------------- 1 | import resnet 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parameter import Parameter 6 | from collections import OrderedDict 7 | import numpy as np 8 | 9 | # float32 machine epsilon (used as a small positive lower bound when clamping) and the largest float32 value 10 | FLT_MIN = float(np.finfo(np.float32).eps) 11 | FLT_MAX = float(np.finfo(np.float32).max) 12 | 13 | class FeatureLayer(nn.Sequential): 14 | def __init__(self, model_type = 'resnet34', num_output = 256, 15 | input_size = 224, pretrained = False, 16 | gray_scale = False): 17 | """ 18 | Args: 19 | model_type (string): type of model to be used.
20 | num_output (int): number of neurons in the last feature layer 21 | input_size (int): input image size 22 | pretrained (boolean): whether to use a pre-trained model from ImageNet 23 | gray_scale (boolean): whether the input is a grayscale image 24 | """ 25 | super(FeatureLayer, self).__init__() 26 | self.model_type = model_type 27 | self.num_output = num_output 28 | if self.model_type == 'hybrid': 29 | # a model using a resnet-like backbone is used for feature extraction 30 | model = resnet.Hybridmodel(self.num_output) 31 | self.add_module('hybrid_model', model) 32 | else: 33 | raise NotImplementedError 34 | 35 | def get_out_feature_size(self): 36 | return self.num_output 37 | 38 | class Tree(nn.Module): 39 | def __init__(self, depth, feature_length, vector_length, use_cuda = True): 40 | """ 41 | Args: 42 | depth (int): depth of the neural decision tree. 43 | feature_length (int): number of neurons in the last feature layer 44 | vector_length (int): length of the mean vector stored at each tree leaf node 45 | use_cuda (boolean): whether to use GPU 46 | """ 47 | super(Tree, self).__init__() 48 | self.depth = depth 49 | self.n_leaf = 2 ** depth 50 | self.feature_length = feature_length 51 | self.vector_length = vector_length 52 | self.is_cuda = use_cuda 53 | 54 | onehot = np.eye(feature_length) 55 | # randomly select n_leaf neurons of the feature layer as inputs to the decision functions 56 | using_idx = np.random.choice(feature_length, self.n_leaf, replace=False) 57 | self.feature_mask = onehot[using_idx].T 58 | self.feature_mask = Parameter(torch.from_numpy(self.feature_mask).type(torch.FloatTensor),requires_grad=False) 59 | # a leaf node contains a mean vector and a covariance matrix 60 | self.mean = np.ones((self.n_leaf, self.vector_length)) 61 | # TODO: use k-means clustering to perform leaf node initialization 62 | self.mu_cache = [] 63 | # use sigmoid function as the decision function 64 | self.decision = nn.Sequential(OrderedDict([ 65 | ('sigmoid', nn.Sigmoid()), 66
| ])) 67 | # used for leaf node update 68 | self.covmat = np.array([np.eye(self.vector_length) for i in range(self.n_leaf)]) 69 | # also stores the inverse of the covariance matrix for efficiency 70 | self.covmat_inv = np.array([np.eye(self.vector_length) for i in range(self.n_leaf)]) 71 | # also stores the normalization factor 1/sqrt(det) of the covariance matrix for efficiency 72 | self.factor = np.ones((self.n_leaf)) 73 | if not use_cuda: 74 | raise NotImplementedError 75 | else: 76 | self.mean = Parameter(torch.from_numpy(self.mean).type(torch.FloatTensor).cuda(), requires_grad=False) 77 | self.covmat = Parameter(torch.from_numpy(self.covmat).type(torch.FloatTensor).cuda(), requires_grad=False) 78 | self.covmat_inv = Parameter(torch.from_numpy(self.covmat_inv).type(torch.FloatTensor).cuda(), requires_grad=False) 79 | self.factor = Parameter(torch.from_numpy(self.factor).type(torch.FloatTensor).cuda(), requires_grad=False) 80 | 81 | 82 | def forward(self, x, save_flag = False): 83 | """ 84 | Args: 85 | param x (Tensor): input feature batch of size [batch_size, n_features] 86 | Return: 87 | (Tensor): routing probability of size [batch_size, n_leaf] 88 | """ 89 | cache = {} 90 | if x.is_cuda and not self.feature_mask.is_cuda: 91 | self.feature_mask = self.feature_mask.cuda() 92 | feats = torch.mm(x, self.feature_mask) 93 | decision = self.decision(feats) 94 | decision = torch.unsqueeze(decision,dim=2) 95 | decision_comp = 1-decision 96 | decision = torch.cat((decision,decision_comp),dim=2) 97 | 98 | # save some intermediate results for analysis if needed 99 | if save_flag: 100 | cache['decision'] = decision[:,:,0] 101 | batch_size = x.size()[0] 102 | 103 | mu = x.data.new(batch_size,1,1).fill_(1.)
104 | begin_idx = 1 105 | end_idx = 2 106 | for n_layer in range(0, self.depth): 107 | # mu stores the probability that a sample is routed to certain node 108 | # repeat it to be multiplied for left and right routing 109 | mu = mu.repeat(1, 1, 2) 110 | # the routing probability at n_layer 111 | _decision = decision[:, begin_idx:end_idx, :] # -> [batch_size,2**n_layer,2] 112 | mu = mu*_decision # -> [batch_size,2**n_layer,2] 113 | begin_idx = end_idx 114 | end_idx = begin_idx + 2 ** (n_layer+1) 115 | # merge left and right nodes to the same layer 116 | mu = mu.view(batch_size, -1, 1) 117 | mu = mu.view(batch_size, -1) 118 | 119 | if save_flag: 120 | cache['mu'] = mu 121 | return mu, cache 122 | else: 123 | return mu 124 | 125 | def pred(self, x): 126 | p = torch.mm(self(x), self.mean) 127 | return p 128 | 129 | def update_label_distribution(self, target_batch, check=False): 130 | """ 131 | fix the feature extractor of RNDF and update leaf node mean vectors and covariance matrices 132 | based on a multivariate gaussian distribution 133 | Args: 134 | param target_batch (Tensor): a batch of regression targets of size [batch_size, vector_length] 135 | """ 136 | target_batch = torch.cat(target_batch, dim = 0) 137 | mu = torch.cat(self.mu_cache, dim = 0) 138 | batch_size = len(mu) 139 | # no need for gradient computation 140 | with torch.no_grad(): 141 | leaf_prob_density = mu.data.new(batch_size, self.n_leaf) 142 | for leaf_idx in range(self.n_leaf): 143 | # vectorized code is used for efficiency 144 | temp = target_batch - self.mean[leaf_idx, :] 145 | leaf_prob_density[:, leaf_idx] = (self.factor[leaf_idx]*torch.exp(-0.5*(torch.mm(temp, self.covmat_inv[leaf_idx, :,:])*temp).sum(dim = 1))).clamp(FLT_MIN, FLT_MAX) # Tensor [batch_size, 1] 146 | nominator = (mu * leaf_prob_density).clamp(FLT_MIN, FLT_MAX) # [batch_size, n_leaf] 147 | denomenator = (nominator.sum(dim = 1).unsqueeze(1)).clamp(FLT_MIN, FLT_MAX) # add dimension for broadcasting 148 | zeta = 
nominator/denomenator # [batch_size, n_leaf] 149 | # new_mean is a weighted average of all training targets 150 | new_mean = (torch.mm(target_batch.transpose(0, 1), zeta)/(zeta.sum(dim = 0).unsqueeze(0))).transpose(0, 1) # [n_leaf, vector_length] 151 | # allocate for new parameters 152 | new_covmat = new_mean.data.new(self.n_leaf, self.vector_length, self.vector_length) 153 | new_covmat_inv = new_mean.data.new(self.n_leaf, self.vector_length, self.vector_length) 154 | new_factor = new_mean.data.new(self.n_leaf) 155 | for leaf_idx in range(self.n_leaf): 156 | # new covariance matrix is a weighted average of the outer products of the centered targets 157 | weights = zeta[:, leaf_idx].unsqueeze(0) 158 | temp = target_batch - new_mean[leaf_idx, :] 159 | new_covmat[leaf_idx, :,:] = torch.mm(weights*(temp.transpose(0, 1)), temp)/(weights.sum()) 160 | # update cache (factor and inverse) for future use 161 | new_covmat_inv[leaf_idx, :,:] = new_covmat[leaf_idx, :,:].inverse() 162 | if check and new_covmat[leaf_idx, :,:].det() <= 0: 163 | print('Warning: singular matrix %d'%leaf_idx) 164 | new_factor[leaf_idx] = 1.0/max((torch.sqrt(new_covmat[leaf_idx, :,:].det())), FLT_MIN) 165 | # update parameters 166 | self.mean = Parameter(new_mean, requires_grad = False) 167 | self.covmat = Parameter(new_covmat, requires_grad = False) 168 | self.covmat_inv = Parameter(new_covmat_inv, requires_grad = False) 169 | self.factor = Parameter(new_factor, requires_grad = False) 170 | return 171 | 172 | class Forest(nn.Module): 173 | # a neural decision forest is an ensemble of neural decision trees 174 | def __init__(self, n_tree, tree_depth, feature_length, vector_length, use_cuda = False): 175 | super(Forest, self).__init__() 176 | self.trees = nn.ModuleList() 177 | self.n_tree = n_tree 178 | self.tree_depth = tree_depth 179 | self.feature_length = feature_length 180 | self.vector_length = vector_length 181 | for _ in range(n_tree): 182 | tree = Tree(tree_depth, feature_length, vector_length, use_cuda) 183
| self.trees.append(tree) 184 | 185 | def forward(self, x, save_flag = False): 186 | predictions = [] 187 | cache = [] 188 | for tree in self.trees: 189 | if save_flag: 190 | # record some intermediate results 191 | mu, cache_tree = tree(x, save_flag = True) 192 | p = torch.mm(mu, tree.mean) 193 | cache.append(cache_tree) 194 | else: 195 | p = tree.pred(x) 196 | predictions.append(p.unsqueeze(2)) 197 | prediction = torch.cat(predictions,dim=2) 198 | prediction = torch.sum(prediction, dim=2)/self.n_tree 199 | if save_flag: 200 | return prediction, cache 201 | else: 202 | return prediction 203 | 204 | class NeuralDecisionForest(nn.Module): 205 | def __init__(self, feature_layer, forest): 206 | super(NeuralDecisionForest, self).__init__() 207 | self.feature_layer = feature_layer 208 | self.forest = forest 209 | 210 | def forward(self, x, debug = False, save_flag = False): 211 | feats, reg_loss = self.feature_layer(x) 212 | if save_flag: 213 | # return some intermediate results 214 | pred, cache = self.forest(feats, save_flag = True) 215 | return pred, reg_loss, cache 216 | else: 217 | pred = self.forest(feats) 218 | return pred, reg_loss -------------------------------------------------------------------------------- /regression/ndf_vis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize Deep Neural Decision Forest For Facial Age Estimation 3 | @author: Shichao (Nicholas) Li 4 | Contact: nicholas.li@connect.ust.hk 5 | License: MIT 6 | """ 7 | import torch 8 | import argparse 9 | import vis_utils 10 | from data_prepare import prepare_db 11 | # This script visualizes deep neural decision forests and helps in understanding them.
12 | 13 | def parse_arg(): 14 | parser = argparse.ArgumentParser(description='ndf_vis.py') 15 | parser.add_argument('-gpuid', type=int, default=0) 16 | parser.add_argument('-dataset_name', type=str, default='CACD') 17 | parser.add_argument('-image_size', type=int, default=256) 18 | parser.add_argument('-crop_size', type=int, default=224) 19 | # whether to cache the images to avoid repeated disk reads 20 | parser.add_argument('-cache', type=bool, default=False) 21 | # whether to apply data augmentation by applying random transformation 22 | parser.add_argument('-transform', type=bool, default=True) 23 | # whether to apply data augmentation by multiple shape initialization 24 | parser.add_argument('-augment', type=bool, default=False) 25 | 26 | # whether to use the training set of CACD dataset for training 27 | parser.add_argument('-cacd_train', type=bool, default=True) 28 | # whether to plot images after dataset initialization 29 | parser.add_argument('-visualize', type=bool, default=False) 30 | parser.add_argument('-gray_scale', type=bool, default=False) 31 | return parser.parse_args() 32 | 33 | # get configuration 34 | opt = parse_arg() 35 | 36 | # use GPU 37 | torch.cuda.set_device(opt.gpuid) 38 | 39 | if opt.dataset_name == 'Morph': 40 | # Sorry that the MORPH dataset is currently not freely available. 41 | # For now I cannot release my pre-processed dataset without permission.
42 | raise ValueError 43 | elif opt.dataset_name == 'CACD': 44 | model_path = "../pre-trained/CACD_MAE_4.59.pth" 45 | else: 46 | raise NotImplementedError 47 | 48 | # load model 49 | model = torch.load(model_path) 50 | model.cuda() 51 | 52 | # prepare dataset 53 | db = prepare_db(opt) 54 | dataset = db['eval'][0] 55 | 56 | # compute saliency map 57 | # pick a tree within the forest 58 | tree_idx = 0 59 | depth = model.forest.trees[0].depth 60 | # pick a splitting node index (optional) 61 | #node_idx = 0 # 0 - 510 for 511 splitting nodes 62 | 63 | # get saliency maps for a specified node for different input tensors 64 | # vis_utils.get_node_saliency_map(dataset, model, tree_idx, node_idx, name=opt.dataset_name) 65 | 66 | # get the computational paths for the input 67 | sample, labels, paths, class_pred = vis_utils.get_paths(dataset, model, 68 | name=opt.dataset_name, 69 | depth=depth) 70 | 71 | # for each sample, plot the saliency and visualize how the input influences the 72 | # decision-making process 73 | vis_utils.get_path_saliency(sample, labels, paths, class_pred, model, tree_idx, 74 | name=opt.dataset_name) 75 | -------------------------------------------------------------------------------- /regression/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimizer preparation 3 | """ 4 | import torch 5 | 6 | def prepare_optim(model, opt): 7 | params = [ p for p in model.parameters() if p.requires_grad] 8 | if opt.optim_type == 'adam': 9 | optimizer = torch.optim.Adam(params, lr = opt.lr, 10 | weight_decay = opt.weight_decay) 11 | elif opt.optim_type == 'sgd': 12 | optimizer = torch.optim.SGD(params, lr = opt.lr, 13 | momentum = opt.momentum, 14 | weight_decay = opt.weight_decay) 15 | # scheduler with pre-defined learning rate decay 16 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 17 | # milestones = opt.milestones, 18 | # gamma = opt.gamma) 19 | # automatically decrease learning rate 20 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 21 | mode='min', 22 | factor=0.5, 23 | patience=10, 24 | verbose=True, 25 | min_lr=0.01) 26 | return optimizer, scheduler -------------------------------------------------------------------------------- /regression/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | from torchvision import models 6 | 7 | class Hybrid(nn.Module): 8 | # this feature extractor has a resnet50-like backbone, where the building 9 | # blocks can be replaced if needed 10 | def __init__(self, block, num_blocks = [6,8,12,6], replace = [False, False, 11 | False, False], num_classes=10, attention=False): 12 | # a resnet with optional simple attention 13 | super(Hybrid, self).__init__() 14 | self.sub_model = models.resnet50(pretrained=True) 15 | # replace some building blocks if needed 16 | # the default is no replacement 17 | self.in_planes = 64 18 | # first block 19 | if replace[0]: 20 | del self.sub_model.layer1 21 | self.sub_model.layer1 = self._make_layer(block, 64, num_blocks[0], 22 | stride=1) 23 | # second block 24 | if replace[1]: 25 | del self.sub_model.layer2 26 | self.sub_model.layer2 = self._make_layer(block, 128, num_blocks[1], 27 | stride=2) 28 | # third block 29 | if replace[2]: 30 | self.in_planes = 128 31 | del self.sub_model.layer3 32 | self.sub_model.layer3 = self._make_layer(block, 256, num_blocks[2], 33 | stride=2) 34 | # fourth block 35 | if replace[3]: 36 | self.in_planes = 256 37 | del self.sub_model.layer4 38 | self.sub_model.layer4 = self._make_layer(block, 512, num_blocks[3], 39 | stride=2) 40 | # re-initialize the FC layer 41 | del self.sub_model.fc 42 | # a two-layer fully-connected module 43 | self.sub_model.fc = nn.Sequential( 44 | nn.Linear(2048, 2048), 45 | nn.ReLU(True), 46 | nn.Dropout(0.5), 47 | nn.Linear(2048, num_classes)) 48 | 49 | #
an optional spatial attention model 50 | self.attention = attention 51 | if self.attention: 52 | self.gamma1 = 0 53 | self.gamma2 = 0 54 | self.attention_model = nn.Conv2d(2048, 1, kernel_size=1, stride=1) 55 | 56 | def _make_layer(self, block, planes, num_blocks, stride): 57 | strides = [stride] + [1]*(num_blocks-1) 58 | layers = [] 59 | for i in range(len(strides)): 60 | stride = strides[i] 61 | if i == 0: 62 | block_ = BasicBlock 63 | else: 64 | block_ = block 65 | layers.append(block_(self.in_planes, planes, stride)) 66 | self.in_planes = planes * block.expansion 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.sub_model.bn1(self.sub_model.conv1(x))) 71 | out = self.sub_model.maxpool(out) 72 | out = self.sub_model.layer1(out) 73 | out = self.sub_model.layer2(out) 74 | out = self.sub_model.layer3(out) 75 | out = self.sub_model.layer4(out) 76 | reg_loss = 0. 77 | if self.attention: 78 | # soft attention 79 | # mask = torch.sigmoid(self.attention_model(out)) 80 | # out = out*mask 81 | # hard attention 82 | mask = torch.tanh(self.attention_model(out)) 83 | out = F.relu(out*mask) 84 | reg_loss = self.gamma1*mask.mean() + self.gamma2*(1 - mask**2).mean() 85 | out = self.sub_model.avgpool(out) 86 | out = out.view(out.size(0), -1) 87 | out = self.sub_model.fc(out) 88 | return out, reg_loss 89 | 90 | def Hybridmodel(num_output): 91 | # the default feature extractor is an unmodified resnet50 backbone (no building blocks are replaced) 92 | return Hybrid(HierRes, [6,8,12,5], num_classes=num_output) 93 | #-----------------------------------------------------------------------------# 94 | #-----------------------------some deprecated functions-----------------------# 95 | class HierRes(nn.Module): 96 | # a hierarchical block designed to use fewer model parameters 97 | expansion = 1 98 | def __init__(self, in_channels, out_channels, stride=1): 99 | super(HierRes, self).__init__() 100 | if out_channels % 16 != 0: 101 | raise NotImplementedError 102 | self.stride = stride 103 |
self.conv1 = nn.Conv2d(in_channels, int(out_channels/2), kernel_size=1, padding=0, stride=stride) 104 | self.bn1 = nn.BatchNorm2d(int(out_channels/2)) 105 | self.relu1 = nn.ReLU(inplace=True) 106 | self.conv2 = nn.Conv2d(int(out_channels/2), int(out_channels/4), kernel_size=3, padding=1, stride=1) 107 | self.bn2 = nn.BatchNorm2d(int(out_channels/4)) 108 | self.relu2 = nn.ReLU(inplace=True) 109 | self.conv3 = nn.Conv2d(int(out_channels/4), int(out_channels/8), kernel_size=3, padding=1, stride=1) 110 | self.bn3 = nn.BatchNorm2d(int(out_channels/8)) 111 | self.relu3 = nn.ReLU(inplace=True) 112 | self.conv4 = nn.Conv2d(int(out_channels/8), int(out_channels/8), kernel_size=3, padding=1, stride=1) 113 | self.bn4 = nn.BatchNorm2d(int(out_channels/8)) 114 | self.relu4 = nn.ReLU(inplace=True) 115 | self.in_num = in_channels 116 | self.out_num = out_channels 117 | if in_channels != out_channels or stride != 1: 118 | self.map = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride) 119 | self.bn_map = nn.BatchNorm2d(out_channels) 120 | 121 | def forward(self, x): 122 | if self.in_num != self.out_num or self.stride != 1: 123 | origin = self.bn_map(self.map(x)) 124 | else: 125 | origin = x 126 | out1 = self.conv1(x) 127 | out1 = self.bn1(out1) 128 | out1 = self.relu1(out1) 129 | out2 = self.conv2(out1) 130 | out2 = self.bn2(out2) 131 | out2 = self.relu2(out2) 132 | out3 = self.conv3(out2) 133 | out3 = self.bn3(out3) 134 | out3 = self.relu3(out3) 135 | out4 = self.conv4(out3) 136 | out4 = self.bn4(out4) 137 | out4 = self.relu4(out4) 138 | out = torch.cat((out1, out2, out3, out4), dim=1) + origin 139 | return out 140 | 141 | class Inception(nn.Module): 142 | # inception module 143 | expansion = 1 144 | def __init__(self, in_channels, out_channels, stride=1): 145 | super(Inception, self).__init__() 146 | if out_channels % 16 != 0: 147 | raise NotImplementedError 148 | self.stride = stride 149 | self.conv1 = nn.Conv2d(in_channels, int(out_channels/2), kernel_size=1, 
padding=0, stride=stride) 150 | self.bn1 = nn.BatchNorm2d(int(out_channels/2)) 151 | self.relu1 = nn.ReLU(inplace=True) 152 | self.conv2 = nn.Conv2d(in_channels, int(out_channels/4), kernel_size=3, padding=1, stride=stride) 153 | self.bn2 = nn.BatchNorm2d(int(out_channels/4)) 154 | self.relu2 = nn.ReLU(inplace=True) 155 | self.conv3 = nn.Conv2d(in_channels, int(out_channels/4), kernel_size=5, padding=2, stride=stride) 156 | self.bn3 = nn.BatchNorm2d(int(out_channels/4)) 157 | self.relu3 = nn.ReLU(inplace=True) 158 | self.in_num = in_channels 159 | self.out_num = out_channels 160 | if in_channels != out_channels or stride != 1: 161 | self.map = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride) 162 | self.bn_map = nn.BatchNorm2d(out_channels) 163 | 164 | def forward(self, x): 165 | if self.in_num != self.out_num or self.stride != 1: 166 | origin = self.bn_map(self.map(x)) 167 | else: 168 | origin = x 169 | out1 = F.relu(self.bn1(self.conv1(x))) 170 | out2 = F.relu(self.bn2(self.conv2(x))) 171 | out3 = F.relu(self.bn3(self.conv3(x))) 172 | out = torch.cat((out1, out2, out3), dim=1) + origin 173 | return out 174 | 175 | class BasicBlock(nn.Module): 176 | expansion = 1 177 | def __init__(self, in_planes, planes, stride=1): 178 | super(BasicBlock, self).__init__() 179 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 180 | self.bn1 = nn.BatchNorm2d(planes) 181 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 182 | self.bn2 = nn.BatchNorm2d(planes) 183 | 184 | self.shortcut = nn.Sequential() 185 | if stride != 1 or in_planes != self.expansion*planes: 186 | self.shortcut = nn.Sequential( 187 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 188 | nn.BatchNorm2d(self.expansion*planes) 189 | ) 190 | 191 | def forward(self, x): 192 | out = F.relu(self.bn1(self.conv1(x))) 193 | out = self.bn2(self.conv2(out)) 194 | out += 
self.shortcut(x) 195 | out = F.relu(out) 196 | return out 197 | 198 | 199 | class Bottleneck(nn.Module): 200 | expansion = 4 201 | 202 | def __init__(self, in_planes, planes, stride=1): 203 | super(Bottleneck, self).__init__() 204 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 205 | self.bn1 = nn.BatchNorm2d(planes) 206 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 207 | self.bn2 = nn.BatchNorm2d(planes) 208 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 209 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 210 | 211 | self.shortcut = nn.Sequential() 212 | if stride != 1 or in_planes != self.expansion*planes: 213 | self.shortcut = nn.Sequential( 214 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 215 | nn.BatchNorm2d(self.expansion*planes) 216 | ) 217 | 218 | def forward(self, x): 219 | out = F.relu(self.bn1(self.conv1(x))) 220 | out = F.relu(self.bn2(self.conv2(out))) 221 | out = self.bn3(self.conv3(out)) 222 | out += self.shortcut(x) 223 | out = F.relu(out) 224 | return out 225 | 226 | class ResNet(nn.Module): 227 | def __init__(self, block, num_blocks, num_classes=10): 228 | super(ResNet, self).__init__() 229 | self.in_planes = 64 230 | 231 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 232 | self.bn1 = nn.BatchNorm2d(64) 233 | self.mp = nn.MaxPool2d(kernel_size=3, stride=2) 234 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 235 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 236 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 237 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 238 | self.linear = nn.Linear(512*block.expansion, num_classes) 239 | 240 | def _make_layer(self, block, planes, num_blocks, stride): 241 | strides = [stride] + [1]*(num_blocks-1) 242 | layers = [] 243 | for stride in 
strides: 244 | layers.append(block(self.in_planes, planes, stride)) 245 | self.in_planes = planes * block.expansion 246 | return nn.Sequential(*layers) 247 | 248 | def forward(self, x): 249 | out = F.relu(self.bn1(self.conv1(x))) 250 | out = self.mp(out) 251 | out = self.layer1(out) 252 | out = self.layer2(out) 253 | out = self.layer3(out) 254 | out = self.layer4(out) 255 | # 224 by 224 input, the output size is 7 by 7 256 | out = F.avg_pool2d(out, 7) 257 | out = out.view(out.size(0), -1) 258 | out = self.linear(out) 259 | return out 260 | 261 | 262 | def ResNet18(num_output): 263 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_output) 264 | 265 | def ResNet34(num_output, pretrained = False): 266 | model = ResNet(Inception, [6,8,12,6], num_classes=num_output) # note: despite its name, this builds the network from Inception blocks 267 | # copy the first convolution kernel from a model pre-trained on ImageNet 268 | if pretrained: 269 | model.conv1.weight.data = get_first_conv_layer() 270 | return model 271 | 272 | def ResNet50(num_output): 273 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_output) 274 | 275 | def ResNet101(num_output): 276 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_output) 277 | 278 | def ResNet152(num_output): 279 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_output) 280 | 281 | def get_first_conv_layer(): 282 | # get the first convolution layer from resnet18 pre-trained on ImageNet 283 | temp_model = models.resnet18(pretrained=True) 284 | weight = temp_model.conv1.weight.data.clone() 285 | del temp_model 286 | return weight 287 | 288 | def visualize_filters(weight): 289 | assert len(weight.shape) == 4 290 | assert weight.shape[1] == 3 291 | col_num = 10 292 | row_num = int(weight.shape[0]/col_num) + 1 293 | for filter_idx in range(weight.shape[0]): 294 | plt.subplot(row_num, col_num, filter_idx+1) 295 | plt.imshow(normalize(weight[filter_idx,:]).numpy().transpose((1,2,0))) 296 | 297 | def 
test(): 298 | net = ResNet34(128) 299 | y = net(torch.randn(1,3,224,224)) 300 | print(y.size())
301 | 302 | def normalize(tensor): 303 | return (tensor/(tensor.max()-tensor.min()) + 1)/2 # maps 0 to 0.5; output lies in [0, 1] when min <= 0 <= max 304 | -------------------------------------------------------------------------------- /regression/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | training utilities 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import logging 8 | import os 9 | 10 | import utils 11 | # float32 machine epsilon (not the smallest positive float), used as a small constant 12 | FLT_MIN = float(np.finfo(np.float32).eps) 13 | 14 | def prepare_batches(model, dataset, opt): 15 | # prepare some feature batches for leaf node distribution update 16 | with torch.no_grad(): 17 | target_batches = [] 18 | train_loader = torch.utils.data.DataLoader(dataset, 19 | batch_size = opt.batch_size, 20 | shuffle = True, 21 | num_workers = opt.num_threads) 22 | num_batch = int(np.ceil(opt.label_batch_size/opt.batch_size)) 23 | for batch_idx, batch in enumerate(train_loader): 24 | if batch_idx == num_batch: 25 | break 26 | data = batch['image'] 27 | targets = batch['age'] 28 | targets = targets.view(len(targets), -1) 29 | if opt.cuda: 30 | data, targets = data.cuda(), targets.cuda() 31 | # Get feats 32 | feats, _ = model.feature_layer(data) 33 | # release data Tensor to save memory 34 | del data 35 | for tree in model.forest.trees: 36 | mu = tree(feats) 37 | # add a small constant so that no routing probability is exactly zero 38 | mu += FLT_MIN # [batch_size, n_leaf] 39 | # store the routing probability for each tree 40 | tree.mu_cache.append(mu) 41 | # release memory 42 | del feats 43 | # the update rule will use both the routing probability and the 44 | # target values 45 | target_batches.append(targets) 46 | return target_batches 47 | 48 | def train(model, optim, sche, db, opt, exp_id): 49 | """ 50 | Args: 51 | model: the model to be trained 52 | optim, sche: pytorch optimizer and its learning-rate scheduler 53 | db: dict with 'train' and 'eval' lists 
of prepared torch datasets 54 | opt: command line input from the user 55 | exp_id:
experiment id 56 | """ 57 | 58 | best_model_dir = os.path.join(opt.save_dir, str(exp_id)) 59 | if not os.path.exists(best_model_dir): 60 | os.makedirs(best_model_dir) 61 | 62 | # (For FG-NET only) carry out leave-one-out validation according to the list length 63 | assert len(db['train']) == len(db['eval']) 64 | 65 | # record for each training experiment 66 | best_MAE = [] 67 | train_set = db['train'][exp_id] 68 | eval_set = db['eval'][exp_id] 69 | eval_loss, min_MAE, _ = evaluate(model, eval_set, opt) 70 | # in dropout mode, only the leaf nodes of one tree are updated at a time 71 | if opt.dropout: 72 | current_tree = 0 73 | 74 | # save training and validation history 75 | if opt.history: 76 | train_loss_history = [] 77 | eval_loss_history = [] 78 | 79 | for epoch in range(1, opt.epochs + 1): 80 | # At each epoch, train the neural decision forest and update 81 | # the leaf node distribution separately 82 | 83 | # Train neural decision forest 84 | # set the model in the training mode 85 | model.train() 86 | # data loader 87 | train_loader = torch.utils.data.DataLoader(train_set, 88 | batch_size = opt.batch_size, 89 | shuffle = True, 90 | num_workers = opt.num_threads) 91 | 92 | for batch_idx, batch in enumerate(train_loader): 93 | data = batch['image'] 94 | target = batch['age'] 95 | target = target.view(len(target), -1) 96 | if opt.cuda: 97 | with torch.no_grad(): 98 | # move to GPU 99 | data, target = data.cuda(), target.cuda() 100 | # reset all accumulated gradients 101 | optim.zero_grad() 102 | #prediction, decision_loss = model(data) 103 | 104 | # forward pass to get prediction 105 | prediction, reg_loss = model(data) 106 | 107 | loss = F.mse_loss(prediction, target) + reg_loss 108 | 109 | # compute gradient in the computational graph 110 | loss.backward() 111 | 112 | # update parameters in the model 113 | optim.step() 114 | 115 | # logging 116 | if batch_idx % opt.report_every == 0: 117 | logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} '.format(
epoch, batch_idx * opt.batch_size, len(train_set), 119 | 100. * batch_idx / len(train_loader), loss.data.item())) 120 | # record loss 121 | if opt.history: 122 | train_loss_history.append((epoch, batch_idx, loss.data.item())) 123 | 124 | # Update the leaf node estimation 125 | if opt.leaf_node_type == 'simple' and batch_idx % opt.update_every == 0: 126 | logging.info("Epoch %d : Update leaf node prediction"%(epoch)) 127 | target_batches = prepare_batches(model, train_set, opt) 128 | # Update label prediction for each tree 129 | logging.info("Update leaf node prediction...") 130 | for i in range(opt.label_iter_time): 131 | # prepare_batches has cached the routing probabilities (mu) of each tree 132 | # in tree.mu_cache; they are reused in every update iteration 133 | if opt.dropout: 134 | model.forest.trees[current_tree].update_label_distribution(target_batches) 135 | current_tree = (current_tree + 1)%opt.n_tree 136 | else: 137 | for tree in model.forest.trees: 138 | tree.update_label_distribution(target_batches) 139 | # release cache 140 | for tree in model.forest.trees: 141 | del tree.mu_cache 142 | tree.mu_cache = [] 143 | 144 | 145 | if opt.eval and batch_idx!= 0 and batch_idx % opt.eval_every == 0: 146 | # evaluate model 147 | eval_loss, MAE, CS = evaluate(model, eval_set, opt) 148 | # update learning rate 149 | sche.step(MAE.data.item()) 150 | # record the final MAE 151 | if epoch == opt.epochs: 152 | last_MAE = MAE 153 | # record the best MAE 154 | if MAE < min_MAE: 155 | min_MAE = MAE 156 | # save the best model 157 | model_name = opt.model_type + train_set.name 158 | best_model_path = os.path.join(best_model_dir, model_name) 159 | utils.save_best_model(model.cpu(), best_model_path) 160 | if opt.cuda: model.cuda() 161 | # update log 162 | utils.update_log(best_model_dir, (str(MAE.data.item()), 163 | str(min_MAE.data.item())), 164 | str(CS)) 165 | if opt.history: 166 | eval_loss_history.append((epoch, batch_idx, eval_loss, MAE)) 167 | # reset to training mode 168 | model.train()
best_MAE.append(min_MAE.data.item()) 170 | if opt.history: 171 | utils.save_history(np.array(train_loss_history), np.array(eval_loss_history), opt) 172 | logging.info('Training finished.') 173 | return model, best_MAE, last_MAE 174 | 175 | def evaluate(model, dataset, opt, report_loss = True, predict = False): 176 | model.eval() 177 | if opt.cuda: 178 | model.cuda() 179 | loader = torch.utils.data.DataLoader(dataset, 180 | batch_size = opt.batch_size, 181 | shuffle=False, 182 | num_workers=opt.num_threads) 183 | eval_loss = 0 184 | MAE = 0 185 | # used to compute cumulative score (CS) 186 | threshold = opt.threshold/dataset.scale_factor 187 | counts_below_threshold = 0 188 | predicted_ages = [] 189 | for batch_idx, batch in enumerate(loader): 190 | data = batch['image'] 191 | target = batch['age'] 192 | target = target.view(len(target), -1) 193 | with torch.no_grad(): 194 | if opt.cuda: 195 | data, target = data.cuda(), target.cuda() 196 | prediction, reg_loss = model(data) 197 | predicted_ages += [prediction[i].data.item() for i in range(len(prediction))] 198 | age = prediction.view(len(prediction), -1) 199 | if report_loss: 200 | # accumulate absolute errors in the normalized label space (rescaled by scale_factor below) 201 | MAE += torch.abs((age - target)).sum(dim = 1).sum(dim = 0) 202 | counts_below_threshold += (torch.abs(age-target) < threshold).sum().data.item() 203 | eval_loss += F.mse_loss(prediction, target.view(len(target), -1), reduction='sum').data.item() 204 | if report_loss and not predict: 205 | eval_loss = eval_loss/len(dataset) 206 | MAE /= len(dataset) 207 | MAE *= dataset.scale_factor 208 | CS = counts_below_threshold/len(dataset) 209 | logging.info('{:s} set: Average loss: {:.4f}.'.format(dataset.split, eval_loss)) 210 | logging.info('{:s} set: Mean absolute error: {:.4f}.'.format(dataset.split, MAE)) 211 | logging.info('{:s} set: Cumulative score: {:.4f}.'.format(dataset.split, CS)) 212 | return eval_loss, MAE, CS
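`evaluate` compares predictions and targets in the normalized label space (age divided by `dataset.scale_factor`), then multiplies the MAE back by `scale_factor` and converts the CS threshold into the normalized space. The framework-free sketch below reproduces only that arithmetic on plain Python lists so both metrics can be checked by hand. It is an illustration, not code from the repository: the function name `mae_and_cs` is hypothetical, and the scale factor of 100 in the usage lines is an assumption (suggested by the `*100` in the plot titles of `get_path_saliency`, not confirmed here).

```python
def mae_and_cs(predictions, targets, threshold, scale_factor):
    """Hypothetical helper: MAE (original age scale) and cumulative score.

    predictions/targets are normalized labels (age / scale_factor), mirroring
    evaluate(); threshold is given on the original age scale, like opt.threshold.
    """
    if not predictions or len(predictions) != len(targets):
        raise ValueError("predictions and targets must be non-empty and equally long")
    # convert the CS threshold into the normalized label space, as evaluate() does
    norm_threshold = threshold / scale_factor
    abs_errors = [abs(p - t) for p, t in zip(predictions, targets)]
    # average in normalized space, then rescale to the original age scale
    mae = sum(abs_errors) / len(abs_errors) * scale_factor
    # fraction of samples whose error is strictly below the threshold
    cs = sum(1 for e in abs_errors if e < norm_threshold) / len(abs_errors)
    return mae, cs


# Assumed scale factor of 100: normalized 0.30 corresponds to age 30.
preds = [0.30, 0.42, 0.55]
targets = [0.32, 0.40, 0.65]
mae, cs = mae_and_cs(preds, targets, threshold=5, scale_factor=100)
print(mae, cs)
```

With these toy values the per-sample errors are 2, 2 and 10 years, so the MAE is about 4.67 years and two of the three samples fall under the 5-year threshold (CS = 2/3).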
-------------------------------------------------------------------------------- /regression/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | miscellaneous utility functions 3 | """ 4 | import argparse 5 | import time 6 | import logging 7 | 8 | import numpy as np 9 | from torch import save 10 | from os import path as path 11 | import os 12 | 13 | # function for parsing input arguments 14 | def parse_arg(): 15 | parser = argparse.ArgumentParser(description='train.py') 16 | ## paths 17 | parser.add_argument('-img_path', type=str, default='../data/CACD/crop') 18 | parser.add_argument('-save_dir', type=str, default='../model/') 19 | parser.add_argument('-save_his_dir', type=str, default='../history/CACD') 20 | parser.add_argument('-test_model_path', type=str, default='../model/CACD_MAE_4.59.pth') 21 | ##-----------------------------------------------------------------------## 22 | ## model settings 23 | parser.add_argument('-save_name', type=str, default='trained_model') 24 | parser.add_argument('-num_output', type=int, default=128) # only used for coupled routing functions 25 | parser.add_argument('-model_type', type=str, default='hybrid') # only used for coupled routing functions (hierarchy = False) 26 | parser.add_argument('-n_tree', type=int, default=5) 27 | parser.add_argument('-tree_depth', type=int, default=6) 28 | parser.add_argument('-leaf_node_type', type=str, default='simple') 29 | parser.add_argument('-pretrained', type=bool, default=False) # only used for coupled routing functions 30 | ## training settings 31 | # choose one tree for update at a time 32 | parser.add_argument('-dropout', type=bool, default=False) 33 | parser.add_argument('-batch_size', type=int, default=30) 34 | # random seed 35 | parser.add_argument('-seed', type=int, default=2019) 36 | # batch size used when updating label prediction 37 | parser.add_argument('-label_batch_size', type=int, default=500) 38 | # number of threads to use 
when loading data 39 | parser.add_argument('-num_threads', type=int, default=4) 40 | # update the leaf node distribution every this many training batches 41 | parser.add_argument('-update_every', type=int, default=50) 42 | # how many iterations to update prediction in each leaf node 43 | parser.add_argument('-label_iter_time', type=int, default=20) 44 | parser.add_argument('-gpuid', type=int, default=0) 45 | parser.add_argument('-epochs', type=int, default=15) 46 | parser.add_argument('-report_every', type=int, default=40) 47 | # whether to perform evaluation on evaluation set during training 48 | parser.add_argument('-eval', type=bool, default=True) 49 | # whether to record and report loss history at the end of training 50 | parser.add_argument('-history', type=bool, default=False) 51 | parser.add_argument('-eval_every', type=int, default=100) 52 | # threshold (on the original age scale) used for computing CS 53 | parser.add_argument('-threshold', type=int, default=5) 54 | ##-----------------------------------------------------------------------## 55 | ## dataset settings 56 | parser.add_argument('-dataset_name', type=str, default='CACD') 57 | parser.add_argument('-image_size', type=int, default=256) 58 | parser.add_argument('-crop_size', type=int, default=224) 59 | # whether to cache the images in memory to avoid repeated disk reads 60 | parser.add_argument('-cache', type=bool, default=False) 61 | # whether to apply data augmentation by applying random transformation 62 | parser.add_argument('-transform', type=bool, default=True) 63 | # whether to apply data augmentation by multiple shape initialization 64 | parser.add_argument('-augment', type=bool, default=False) 65 | # whether to use the training set of CACD dataset for training 66 | parser.add_argument('-cacd_train', type=bool, default=True) 67 | # whether to use gray-scale images 68 | parser.add_argument('-gray_scale', type=bool, default=False) 69 |
##-----------------------------------------------------------------------## 70 | # Optimizer settings 71 | parser.add_argument('-optim_type', type=str, default='sgd') 72 | parser.add_argument('-lr', type=float, default=0.5, help="sgd: 0.5, adam: 0.001") 73 | parser.add_argument('-weight_decay', type=float, default=0.0) 74 | parser.add_argument('-momentum', type=float, default=0.9, help="sgd: 0.9") 75 | # reduce the learning rate after each milestone 76 | #parser.add_argument('-milestones', type=list, default=[6, 12, 18]) 77 | parser.add_argument('-milestones', type=list, default=[2,4,6,8]) 78 | # how much to reduce the learning rate 79 | parser.add_argument('-gamma', type=float, default=0.5) 80 | ##-----------------------------------------------------------------------## 81 | ## usage configuration 82 | parser.add_argument('-train', type=bool, default=False) 83 | parser.add_argument('-evaluate', type=bool, default=False) 84 | opt = parser.parse_args() 85 | return opt 86 | 87 | def get_save_dir(opt, str_type=None): 88 | if str_type == 'his': 89 | root = opt.save_his_dir 90 | else: 91 | root = opt.save_dir 92 | save_name = path.join(root, opt.save_name) 93 | save_name += '_model_type_' 94 | save_name += opt.model_type 95 | save_name += '_RNDF_' 96 | save_name += '_depth{:d}_tree{:d}_output{:d}'.format(opt.tree_depth, opt.n_tree, opt.num_output) 97 | save_name += time.asctime(time.localtime(time.time())) 98 | save_name += '.pth' 99 | return save_name 100 | 101 | def save_model(model, opt): 102 | save_name = get_save_dir(opt) 103 | save(model, save_name) 104 | return 105 | 106 | def save_best_model(model, path): 107 | save(model, path) 108 | return 109 | 110 | def update_log(best_model_dir, MAE, CS): 111 | text = time.asctime(time.localtime(time.time())) + ' ' 112 | text += "Current MAE: " + MAE[0] + " Current CS: " + CS + " " 113 | text += "Best MAE: " + MAE[1] + "\r\n" 114 | with open(os.path.join(best_model_dir, "log.txt"), "a") as myfile: 115 | myfile.write(text) 
116 | return 117 | 118 | def save_history(train_his, eval_his, opt): 119 | save_name = get_save_dir(opt, 'his') 120 | train_his_name = save_name +'train_his_stage' 121 | eval_his_name = save_name + 'eval_his_stage' 122 | if not path.exists(opt.save_his_dir): 123 | os.mkdir(opt.save_his_dir) 124 | save(train_his, train_his_name) 125 | save(eval_his, eval_his_name) 126 | 127 | def split_dic(data_dic): 128 | img_path_list = [] 129 | age_list = [] 130 | for key in data_dic: 131 | img_path_list += data_dic[key]['path'] 132 | age_list += data_dic[key]['age_list'] 133 | img_path_list = np.array(img_path_list) 134 | age_list = np.array(age_list) 135 | total_imgs = len(img_path_list) 136 | random_indices = np.random.choice(total_imgs, total_imgs, replace=False) 137 | num_train = int(len(img_path_list)*0.8) 138 | train_path_list = list(img_path_list[random_indices[:num_train]]) 139 | train_age_list = list(age_list[random_indices[:num_train]]) 140 | valid_path_list = list(img_path_list[random_indices[num_train:]]) 141 | valid_age_list = list(age_list[random_indices[num_train:]]) 142 | train_dic = {'path':train_path_list, 'age_list':train_age_list} 143 | valid_dic = {'path':valid_path_list, 'age_list':valid_age_list} 144 | return train_dic, valid_dic 145 | 146 | def check_split(train_dic, eval_dic): 147 | train_num = len(train_dic['path']) 148 | valid_num = len(eval_dic['path']) 149 | logging.info("Image split: {:d} training, {:d} validation".format(train_num, valid_num)) 150 | logging.info("Total unique image num: {:d} ".format(len(set(train_dic['path'] + eval_dic['path'])))) 151 | return -------------------------------------------------------------------------------- /regression/vis_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.patches import ConnectionPatch 3 | import numpy as np 4 | import torch 5 | 6 | def get_sample(dataset, sample_num, name): 7 | # random seed 8 | 
#np.random.seed(2019) 9 | # get a batch of random sample images from the dataset 10 | indices = np.random.choice(list(range(len(dataset))), sample_num) 11 | if name in ['CACD', 'Morph', 'FGNET']: 12 | sample = [dataset[indices[i]]['image'].unsqueeze(0) for i in range(len(indices))] 13 | sample = torch.cat(sample, dim = 0) 14 | label = [dataset[indices[i]]['age'] for i in range(len(indices))] 15 | else: 16 | raise ValueError 17 | return sample, label 18 | 19 | def revert_preprocessing(data_tensor, name): 20 | if name == 'FGNET': 21 | data_tensor[:,0,:,:] = data_tensor[:,0,:,:]*0.218 + 0.425 22 | data_tensor[:,1,:,:] = data_tensor[:,1,:,:]*0.191 + 0.342 23 | data_tensor[:,2,:,:] = data_tensor[:,2,:,:]*0.182 + 0.314 24 | elif name == 'CACD': 25 | data_tensor[:,0,:,:] = data_tensor[:,0,:,:]*0.3 + 0.432 26 | data_tensor[:,1,:,:] = data_tensor[:,1,:,:]*0.264 + 0.359 27 | data_tensor[:,2,:,:] = data_tensor[:,2,:,:]*0.252 + 0.32 28 | elif name == 'Morph': 29 | data_tensor[:,0,:,:] = data_tensor[:,0,:,:]*0.281 + 0.564 30 | data_tensor[:,1,:,:] = data_tensor[:,1,:,:]*0.255 + 0.521 31 | data_tensor[:,2,:,:] = data_tensor[:,2,:,:]*0.246 + 0.508 32 | else: 33 | raise NotImplementedError 34 | return data_tensor 35 | 36 | def normalize(gradient, name): 37 | # take the maximum gradient from the 3 channels 38 | gradient = (gradient.max(dim=1)[0]).unsqueeze(dim=1) 39 | # normalize the gradient map to 0-1 range 40 | # get the maximum gradient 41 | max_gradient = torch.max(gradient.view(len(gradient), -1), dim=1)[0] 42 | max_gradient = max_gradient.view(len(gradient), 1, 1, 1) 43 | min_gradient = torch.min(gradient.view(len(gradient), -1), dim=1)[0] 44 | min_gradient = min_gradient.view(len(gradient), 1, 1, 1) 45 | # Do normalization 46 | gradient = (gradient - min_gradient)/(max_gradient - min_gradient) 47 | return gradient 48 | 49 | def get_parents_path(leaf_idx): 50 | parent_list = [] 51 | while leaf_idx > 1: 52 | parent = int(leaf_idx/2) 53 | parent_list = [parent] + parent_list 
54 | leaf_idx = int(leaf_idx/2) 55 | return parent_list 56 | 57 | def trace(record, mu, depth): 58 | # trace the path from the root to the leaf with the largest routing 59 | # probability (mu), using the forward pass record of one input sample 60 | path = [] 61 | strongest_leaf_idx = np.argmax(mu) 62 | # the root node is reached with probability 1 63 | path.append((1,1)) 64 | prob = 1 65 | parent_list = get_parents_path(strongest_leaf_idx + 2**depth) 66 | for i in range(1, len(parent_list)): 67 | current_idx = parent_list[i] 68 | parent_idx = parent_list[i-1] 69 | if current_idx == (parent_idx*2 + 1): 70 | prob *= 1 - record[parent_idx] 71 | else: 72 | prob *= record[parent_idx] 73 | path.append((current_idx, prob)) 74 | return path 75 | 76 | def get_paths(dataset, model, name, depth, sample_num = 5): 77 | # compute the paths for the input 78 | sample, label = get_sample(dataset, sample_num, name) 79 | # forward pass 80 | pred, _, cache = model(sample.cuda(), save_flag = True) 81 | # pick the path that has the largest probability of being visited 82 | paths = [] 83 | for sample_idx in range(len(sample)): 84 | max_prob = 0 85 | for tree_idx in range(len(cache)): 86 | decision = cache[tree_idx]['decision'].data.cpu().numpy() 87 | mu = cache[tree_idx]['mu'].data.cpu().numpy() 88 | temp_path = trace(decision[sample_idx], mu[sample_idx], depth) 89 | if temp_path[-1][1] > max_prob: 90 | max_prob = temp_path[-1][1] 91 | best_path = temp_path 92 | paths.append(best_path) 93 | return sample, label, paths, pred
gradient = normalize(torch.abs(gradient), name) 106 | saliency_map = gradient.squeeze().cpu().numpy() 107 | return saliency_map 108 | 109 | def get_path_saliency(samples, labels, paths, pred, model, tree_idx, name, orientation = 'horizontal'): 110 | # show the saliency maps for the input samples with their 111 | # computational paths 112 | plt.figure(figsize=(20,4)) 113 | plt.rcParams.update({'font.size': 18}) 114 | num_samples = len(samples) 115 | path_length = len(paths[0]) 116 | for sample_idx in range(num_samples): 117 | sample = samples[sample_idx] 118 | # plot the sample 119 | plt.subplot(num_samples, path_length + 1, sample_idx*(path_length + 1) + 1) 120 | sample_to_plot = revert_preprocessing(sample.unsqueeze(dim=0).clone(), name) # clone: revert_preprocessing works in place and would otherwise alter the input used by get_map 121 | plt.imshow(sample_to_plot.squeeze().cpu().numpy().transpose((1,2,0))) 122 | plt.axis('off') 123 | plt.title('Pred:{:.2f}, GT:{:.0f}'.format(pred[sample_idx].data.item()*100, 124 | labels[sample_idx]*100)) 125 | path = paths[sample_idx] 126 | for node_idx in range(path_length): 127 | # compute and plot saliency for each node 128 | node = path[node_idx][0] 129 | # probability of arriving at this node 130 | prob = path[node_idx][1] 131 | saliency_map = get_map(model, sample, node, tree_idx, name) 132 | if orientation == 'horizontal': 133 | sub_plot_idx = sample_idx*(path_length + 1) + node_idx + 2 134 | plt.subplot(num_samples, path_length + 1, sub_plot_idx) 135 | elif orientation == 'vertical': 136 | raise NotImplementedError 137 | else: 138 | raise NotImplementedError 139 | plt.imshow(saliency_map,cmap='hot') 140 | plt.title('(N{:d}, P{:.2f})'.format(node, prob)) 141 | plt.axis('off') 142 | # draw some arrows 143 | for arrow_idx in range(num_samples*(path_length + 1) - 1): 144 | if (arrow_idx+1) % (path_length+1) == 0 and arrow_idx != 0: 145 | continue 146 | ax1 = plt.subplot(num_samples, path_length + 1, arrow_idx + 1) 147 | ax2 = plt.subplot(num_samples, path_length + 1, arrow_idx + 2) 
148 | arrow = ConnectionPatch(xyA=[1.1,0.5],
xyB=[-0.1, 0.5], coordsA='axes fraction', coordsB='axes fraction', 149 | axesA=ax1, axesB=ax2, arrowstyle="fancy") 150 | ax1.add_artist(arrow) 151 | left = 0.02 # the left side of the subplots of the figure 152 | right = 1 # the right side of the subplots of the figure 153 | bottom = 0.01 # the bottom of the subplots of the figure 154 | top = 0.90 # the top of the subplots of the figure 155 | wspace = 0.20 # the amount of width reserved for space between subplots, 156 | # expressed as a fraction of the average axis width 157 | hspace = 0.24 # the amount of height reserved for space between subplots, 158 | # expressed as a fraction of the average axis height 159 | plt.subplots_adjust(left, bottom, right, top, wspace, hspace) 160 | plt.show() 161 | return -------------------------------------------------------------------------------- /teasers/cacd_final1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/teasers/cacd_final1.png -------------------------------------------------------------------------------- /teasers/cifar10_results.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/teasers/cifar10_results.pdf -------------------------------------------------------------------------------- /teasers/cifar10_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/teasers/cifar10_results.png -------------------------------------------------------------------------------- /teasers/mnist_results.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/teasers/mnist_results.pdf -------------------------------------------------------------------------------- /teasers/mnist_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nicholasli1995/VisualizingNDF/8209a75ad55201c1ec712580201b440669fcca73/teasers/mnist_results.png --------------------------------------------------------------------------------
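`get_parents_path` and `trace` in `regression/vis_utils.py` rely on heap-style node numbering: the root is node 1, the children of node i are 2i and 2i+1, and leaf k of a depth-d tree has index k + 2^d. In `trace`, moving to the right child (2i+1) multiplies the arrival probability by 1 - record[i], so record[i] acts as the probability of routing left. The pure-Python sketch below restates that logic for one sample so it can be run without a trained model. The names `parents_path` and `trace_path` and the toy probabilities are hypothetical; as in `trace`, `decision` is assumed to be indexed by node number with entry 0 unused.

```python
def parents_path(leaf_idx):
    """Ancestors of a heap-indexed node, ordered from the root (1) downwards."""
    parents = []
    while leaf_idx > 1:
        leaf_idx //= 2
        parents.insert(0, leaf_idx)
    return parents


def trace_path(decision, mu, depth):
    """Splitting nodes on the way to the most probable leaf, with arrival probabilities.

    decision[i]: probability that splitting node i routes the sample left (to 2*i);
    entry 0 is unused. mu[k]: routing probability of leaf k (heap index k + 2**depth).
    """
    # first index of the maximum, matching np.argmax
    strongest_leaf = max(range(len(mu)), key=lambda k: mu[k])
    parents = parents_path(strongest_leaf + 2 ** depth)
    path = [(1, 1.0)]  # the root is reached with probability 1
    prob = 1.0
    for parent, node in zip(parents, parents[1:]):
        if node == 2 * parent + 1:  # right child
            prob *= 1 - decision[parent]
        else:  # left child
            prob *= decision[parent]
        path.append((node, prob))
    return path


# Toy depth-2 tree: splitting nodes 1-3, leaves at heap indices 4-7.
decision = [0.0, 0.8, 0.3, 0.6]
print(trace_path(decision, mu=[0.1, 0.5, 0.2, 0.2], depth=2))  # most probable leaf: index 5
print(trace_path(decision, mu=[0.0, 0.0, 0.0, 1.0], depth=2))  # most probable leaf: index 7
```

Like `trace`, the returned path stops at the last splitting node rather than the leaf itself, which fits `get_path_saliency`: decision saliency maps are drawn only for splitting nodes.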