├── images
│   ├── Training-Results.png
│   ├── ResNet18-Architecture.png
│   ├── Training-Results-Python.png
│   └── Steps-Loading-Data-PyTorch.png
├── Transfer-Learning-on-Dogs-vs-Cats.pdf
├── README.md
├── CMakeLists.txt
├── scripts
│   └── convert.py
├── classify.cpp
├── main.py
├── main.h
├── main.cpp
├── Blog.md
└── Transfer-Learning-on-Dogs-vs-Cats.ipynb

/images/Training-Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krshrimali/Transfer-Learning-Dogs-Cats-Libtorch/HEAD/images/Training-Results.png
--------------------------------------------------------------------------------

/images/ResNet18-Architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krshrimali/Transfer-Learning-Dogs-Cats-Libtorch/HEAD/images/ResNet18-Architecture.png
--------------------------------------------------------------------------------

/images/Training-Results-Python.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krshrimali/Transfer-Learning-Dogs-Cats-Libtorch/HEAD/images/Training-Results-Python.png
--------------------------------------------------------------------------------

/Transfer-Learning-on-Dogs-vs-Cats.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krshrimali/Transfer-Learning-Dogs-Cats-Libtorch/HEAD/Transfer-Learning-on-Dogs-vs-Cats.pdf
--------------------------------------------------------------------------------

/images/Steps-Loading-Data-PyTorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krshrimali/Transfer-Learning-Dogs-Cats-Libtorch/HEAD/images/Steps-Loading-Data-PyTorch.png
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Transfer-Learning-Dogs-Cats-Libtorch

Transfer Learning on the Dogs vs Cats dataset using the PyTorch C++ API

**Implementation**

1. `mkdir build && cd build`
2. `cmake -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch ..`
3. `make`
4. `./example <path-to-traced-model>`

TODOs:

1. Load the dataset lazily (one image at a time), as suggested, to prevent OOM.
2. ~~Test accuracy and prediction samples.~~
--------------------------------------------------------------------------------

/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.0)
project(example)

find_package(Torch REQUIRED)
find_package(OpenCV 4.1.0 REQUIRED)

include_directories(${OpenCV_INCLUDE_DIRS})

add_executable(example main.cpp main.h)
add_executable(classify classify.cpp)

target_link_libraries(example ${OpenCV_LIBS})
target_link_libraries(example "${TORCH_LIBRARIES}")
target_link_libraries(classify ${OpenCV_LIBS})
target_link_libraries(classify "${TORCH_LIBRARIES}")

set_property(TARGET classify PROPERTY CXX_STANDARD 11)
set_property(TARGET example PROPERTY CXX_STANDARD 11)
--------------------------------------------------------------------------------

/scripts/convert.py:
--------------------------------------------------------------------------------
"""
This Python script traces the network into a TorchScript ScriptModule.
"""

import torch
import torchvision.models as models

# Download the pre-trained ResNet18 and freeze all of its parameters.
resnet18 = models.resnet18(pretrained=True)
for param in resnet18.parameters():
    param.requires_grad = False

# Replace the final FC layer (512 -> 1000) with a new 512 -> 2 layer and unfreeze it.
resnet18.fc = torch.nn.Linear(512, 2)
for param in resnet18.fc.parameters():
    param.requires_grad = True

# Put the network in eval mode so batch-norm uses inference behavior when traced.
resnet18.eval()

example_input = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(resnet18, example_input)
script_module.save('resnet18_with_last_layer.pt')

# print(list(resnet18.children()))
# resnet18 = torch.nn.Sequential(*list(resnet18.children()))  # Take all layers except the last one
# list(resnet18.children())[-1] = torch.nn.Linear(512, 2)
# print(list(resnet18.children())[-1])
'''
example_input = torch.rand(1, 3, 224, 224)

script_module = torch.jit.trace(resnet18, example_input)
script_module.save('resnet18_without_lastlayer.pt')
'''
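
# Optional sanity check (a sketch, not part of the original script): the traced
# module should reproduce the eager model's outputs on a fresh input.
# check_input = torch.rand(1, 3, 224, 224)
# assert torch.allclose(resnet18(check_input), script_module(check_input), atol=1e-5)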
--------------------------------------------------------------------------------

/classify.cpp:
--------------------------------------------------------------------------------
//
// classify.cpp
// transfer-learning
//
// Created by Kushashwa Ravi Shrimali on 15/08/19.
// Copyright © 2019 Kushashwa Ravi Shrimali. All rights reserved.
//

#include <iostream>
#include <string>
#include <torch/script.h>
#include <opencv2/opencv.hpp>

// #include "model.h"

// Usage: ./classify <path-to-image> <path-to-traced-model>
int main(int argc, char** argv)
{
    if(argc != 3) {
        std::cout << "Usage: ./classify <path-to-image> <path-to-traced-model>" << std::endl;
        return 1;
    }

    std::string loc = argv[1];

    // Load image with OpenCV.
    cv::Mat img = cv::imread(loc);
    cv::resize(img, img, cv::Size(224, 224), 0, 0, cv::INTER_CUBIC);
    // Convert the image to a tensor of shape {1, H, W, C} ...
    torch::Tensor img_tensor = torch::from_blob(img.data, {1, img.rows, img.cols, 3}, torch::kByte);
    img_tensor = img_tensor.permute({0, 3, 1, 2}); // ... and reorder to NCHW
    img_tensor = img_tensor.to(torch::kF32);

    // Load the model.
    torch::jit::script::Module model;
    model = torch::jit::load(argv[2]);

    std::cout << "Model loaded" << std::endl;
    // Predict the probabilities for the classes: the network emits raw scores,
    // so apply softmax to turn them into probabilities.
    std::vector<torch::jit::IValue> input;
    input.push_back(img_tensor);
    torch::Tensor prob = torch::softmax(model.forward(input).toTensor(), 1);

    std::cout << "Probability of being cat: " << *(prob.data<float>()) * 100.
              << ", of being dog: " << *(prob.data<float>() + 1) * 100. << std::endl;

    return 0;
}
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import torch
import torchvision
from torchvision import datasets, transforms, models
import os
import numpy as np
# import matplotlib.pyplot as plt
import time

folder_path = "/Users/krshrimali/Documents/krshrimali-blogs/dataset/train/train_python/"
transform = transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor()])
data = datasets.ImageFolder(root=os.path.join(folder_path), transform=transform)

batch_size = 4
data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=True)

model = models.resnet18(pretrained=True)

# Freeze all layers of the pre-trained network...
for param in model.parameters():
    param.requires_grad = False

# ...and replace the final FC layer (512 -> 1000) with a trainable 512 -> 2 layer.
model.fc = torch.nn.Linear(512, 2)

for param in model.fc.parameters():
    param.requires_grad = True

cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters())

n_epochs = 15

for epoch in range(n_epochs):
    mse = 0.0
    acc = 0
    batch_index = 0

    for data_batch in data_loader:
        batch_index += 1
        image, label = data_batch

        optimizer.zero_grad()

        output = model(image)
        _, predicted_label = torch.max(output.data, 1)

        loss = cost(output, label)

        loss.backward()
        optimizer.step()

        mse += loss.item()
        acc += torch.sum(predicted_label == label.data).item()

    mse = mse / len(data)
    acc = 100 * acc / len(data)

    print("Epoch: {}/{}, Loss: {:.4f}, Accuracy: {:.4f}".format(epoch + 1, n_epochs, mse, acc))
--------------------------------------------------------------------------------
/main.h:
--------------------------------------------------------------------------------
//
// main.h
// transfer-learning
//
// Created by Kushashwa Ravi Shrimali on 15/08/19.
// Copyright © 2019 Kushashwa Ravi Shrimali. All rights reserved.
//

#ifndef main_h
#define main_h

#include <iostream>
#include <dirent.h>
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/opencv.hpp>

// Function to return image read at location given as type torch::Tensor
// Resizes image to (224, 224, 3)
torch::Tensor read_data(std::string location);

// Function to return label from int (0, 1 for binary and 0, 1, ..., n-1 for n-class classification) as type torch::Tensor
torch::Tensor read_label(int label);

// Function returns vector of tensors (images) read from the list of images in a folder
std::vector<torch::Tensor> process_images(std::vector<std::string> list_images);

// Function returns vector of tensors (labels) read from the list of labels
std::vector<torch::Tensor> process_labels(std::vector<int> list_labels);

// Function to load data from given folder(s) name(s) (folders_name)
// Returns pair of vectors of string (image locations) and int (respective labels)
std::pair<std::vector<std::string>, std::vector<int>> load_data_from_folder(std::vector<std::string> folders_name);

// Function to train the network on train data
template<typename Dataloader>
void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size);

// Function to test the network on test data
template<typename Dataloader>
void test(torch::jit::script::Module network, torch::nn::Linear lin, Dataloader& loader, size_t data_size);

// Custom Dataset class
class CustomDataset : public torch::data::Dataset<CustomDataset> {
private:
    /* data */
    // Two parallel vectors of tensors: one per image, one per label
    std::vector<torch::Tensor> states, labels;
    size_t ds_size;
public:
    CustomDataset(std::vector<std::string> list_images, std::vector<int> list_labels) {
        states = process_images(list_images);
        labels = process_labels(list_labels);
        ds_size = states.size();
    };

    torch::data::Example<> get(size_t index) override {
        /* This should return {torch::Tensor, torch::Tensor} */
        torch::Tensor sample_img = states.at(index);
        torch::Tensor sample_label = labels.at(index);
        return {sample_img.clone(), sample_label.clone()};
    };

    torch::optional<size_t> size() const override {
        return ds_size;
    };
};

#endif /* main_h */
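
// How the pieces above compose (a sketch, not part of the build; the two image
// file names are hypothetical):
//
//     std::vector<std::string> imgs = {"cat.0.jpg", "dog.0.jpg"};
//     std::vector<int> lbls = {0, 1};
//     auto ds = CustomDataset(imgs, lbls).map(torch::data::transforms::Stack<>());
//     auto loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(ds), 2);
//     for (auto& batch : *loader) {
//         // batch.data: [2, 3, 224, 224] (kByte until converted), batch.target: [2, 1]
//         std::cout << batch.data.sizes() << " " << batch.target.sizes() << std::endl;
//     }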
--------------------------------------------------------------------------------

/main.cpp:
--------------------------------------------------------------------------------
//
// main.cpp
// transfer-learning
//
// Created by Kushashwa Ravi Shrimali on 12/08/19.
// Copyright © 2019 Kushashwa Ravi Shrimali. All rights reserved.
//

#include "main.h"

torch::Tensor read_data(std::string location) {
    /*
     Function to return image read at location given as type torch::Tensor
     Resizes image to (224, 224, 3)
     Parameters
     ===========
     1. location (std::string type) - required to load image from the location

     Returns
     ===========
     torch::Tensor type - image read as tensor
     */
    cv::Mat img = cv::imread(location, 1);
    cv::resize(img, img, cv::Size(224, 224), 0, 0, cv::INTER_CUBIC);
    torch::Tensor img_tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
    img_tensor = img_tensor.permute({2, 0, 1}); // HWC -> CHW
    return img_tensor.clone();
}

torch::Tensor read_label(int label) {
    /*
     Function to return label from int (0, 1 for binary and 0, 1, ..., n-1 for n-class classification) as type torch::Tensor
     Parameters
     ===========
     1. label (int type) - required to convert int to tensor

     Returns
     ===========
     torch::Tensor type - label read as tensor
     */
    torch::Tensor label_tensor = torch::full({1}, label);
    return label_tensor.clone();
}

std::vector<torch::Tensor> process_images(std::vector<std::string> list_images) {
    /*
     Function returns vector of tensors (images) read from the list of images in a folder
     Parameters
     ===========
     1. list_images (std::vector<std::string> type) - list of image paths in a folder to be read

     Returns
     ===========
     std::vector<torch::Tensor> type - Images read as tensors
     */
    std::vector<torch::Tensor> states;
    for(std::vector<std::string>::iterator it = list_images.begin(); it != list_images.end(); ++it) {
        torch::Tensor img = read_data(*it);
        states.push_back(img);
    }
    return states;
}

std::vector<torch::Tensor> process_labels(std::vector<int> list_labels) {
    /*
     Function returns vector of tensors (labels) read from the list of labels
     Parameters
     ===========
     1. list_labels (std::vector<int> type) - list of integer labels to be converted

     Returns
     ===========
     std::vector<torch::Tensor> type - returns vector of tensors (labels)
     */
    std::vector<torch::Tensor> labels;
    for(std::vector<int>::iterator it = list_labels.begin(); it != list_labels.end(); ++it) {
        torch::Tensor label = read_label(*it);
        labels.push_back(label);
    }
    return labels;
}

std::pair<std::vector<std::string>, std::vector<int>> load_data_from_folder(std::vector<std::string> folders_name) {
    /*
     Function to load data from given folder(s) name(s) (folders_name)
     Returns pair of vectors of string (image locations) and int (respective labels)
     Parameters
     ===========
     1. folders_name (std::vector<std::string> type) - name of folders as a vector to load data from

     Returns
     ===========
     std::pair<std::vector<std::string>, std::vector<int>> type - returns pair of vector of strings (image paths) and respective labels' vector (int label)
     */
    std::vector<std::string> list_images;
    std::vector<int> list_labels;
    int label = 0;
    for(auto const& value: folders_name) {
        std::string base_name = value + "/";
        // cout << "Reading from: " << base_name << endl;
        DIR* dir;
        struct dirent *ent;
        if((dir = opendir(base_name.c_str())) != NULL) {
            while((ent = readdir(dir)) != NULL) {
                std::string filename = ent->d_name;
                if(filename.length() > 4 && filename.substr(filename.length() - 3) == "jpg") {
                    // cout << base_name + ent->d_name << endl;
                    // cv::Mat temp = cv::imread(base_name + "/" + ent->d_name, 1);
                    list_images.push_back(base_name + ent->d_name);
                    list_labels.push_back(label);
                }
            }
            closedir(dir);
        } else {
            std::cout << "Could not open directory" << std::endl;
            // return EXIT_FAILURE;
        }
        label += 1;
    }
    return std::make_pair(list_images, list_labels);
}
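
// A C++17 alternative using std::filesystem (a sketch; the repo targets C++11,
// which is why dirent.h is used above):
//
//     #include <filesystem>
//     namespace fs = std::filesystem;
//
//     std::pair<std::vector<std::string>, std::vector<int>>
//     load_data_from_folder_fs(const std::vector<std::string>& folders) {
//         std::vector<std::string> images;
//         std::vector<int> labels;
//         int label = 0;
//         for (const auto& folder : folders) {
//             for (const auto& entry : fs::directory_iterator(folder)) {
//                 if (entry.path().extension() == ".jpg") {
//                     images.push_back(entry.path().string());
//                     labels.push_back(label);
//                 }
//             }
//             ++label;
//         }
//         return {images, labels};
//     }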

template<typename Dataloader>
void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size) {
    /*
     This function trains the network on our data loader using optimizer.

     Also saves the model as model.pt after every epoch.
     Parameters
     ===========
     1. net (torch::jit::script::Module type) - Pre-trained model without last FC layer
     2. lin (torch::nn::Linear type) - last FC layer with revised out_features depending on the number of classes
     3. data_loader (Dataloader& type) - Training data loader
     4. optimizer (torch::optim::Optimizer& type) - Optimizer like Adam, SGD etc.
     5. dataset_size (size_t type) - Size of training dataset

     Returns
     ===========
     Nothing (void)
     */

    for(int i=0; i<15; i++) {
        float mse = 0;
        float Acc = 0.0;
        float batch_index = 0; // reset every epoch so the mean loss is per-epoch

        for(auto& batch: *data_loader) {
            auto data = batch.data;
            auto target = batch.target.squeeze();

            // Should be of length: batch_size
            data = data.to(torch::kF32);
            target = target.to(torch::kInt64);

            std::vector<torch::jit::IValue> input;
            input.push_back(data);
            optimizer.zero_grad();

            auto output = net.forward(input).toTensor();
            // For transfer learning: flatten the [N, 512, 1, 1] features before the FC layer
            output = output.view({output.size(0), -1});
            output = lin(output);

            auto loss = torch::nll_loss(torch::log_softmax(output, 1), target);

            loss.backward();
            optimizer.step();

            auto acc = output.argmax(1).eq(target).sum();

            Acc += acc.template item<float>();
            mse += loss.template item<float>();

            batch_index += 1;
        }

        mse = mse/float(batch_index); // Take mean of loss over the epoch
        std::cout << "Epoch: " << i << ", " << "Accuracy: " << Acc/dataset_size << ", " << "MSE: " << mse << std::endl;
        net.save("model.pt");
    }
}
template<typename Dataloader>
void test(torch::jit::script::Module network, torch::nn::Linear lin, Dataloader& loader, size_t data_size) {
    /*
     Function to test the network on test data

     Parameters
     ===========
     1. network (torch::jit::script::Module type) - Pre-trained model without last FC layer
     2. lin (torch::nn::Linear type) - last FC layer with revised out_features depending on the number of classes
     3. loader (Dataloader& type) - test data loader
     4. data_size (size_t type) - test data size

     Returns
     ===========
     Nothing (void)
     */
    network.eval();
    torch::NoGradGuard no_grad; // evaluation only, so no gradients are needed

    float Loss = 0, Acc = 0;

    for (const auto& batch : *loader) {
        auto data = batch.data;
        auto targets = batch.target.view({-1});

        data = data.to(torch::kF32);
        targets = targets.to(torch::kInt64);
        std::vector<torch::jit::IValue> input;
        input.push_back(data);
        auto output = network.forward(input).toTensor();

        output = output.view({output.size(0), -1});

        output = lin(output);

        auto loss = torch::nll_loss(torch::log_softmax(output, 1), targets);
        auto acc = output.argmax(1).eq(targets).sum();
        Loss += loss.template item<float>();
        Acc += acc.template item<float>();
    }

    std::cout << "Test Loss: " << Loss/data_size << ", Acc: " << Acc/data_size << std::endl;
}

int main(int argc, const char * argv[]) {
    if(argc < 2) {
        std::cout << "Usage: ./example <path-to-traced-model>" << std::endl;
        return 1;
    }

    // Set folder names for cat and dog images
    std::string cats_name = "/Users/krshrimali/Documents/krshrimali-blogs/dataset/train/cat_test";
    std::string dogs_name = "/Users/krshrimali/Documents/krshrimali-blogs/dataset/train/dog_test";

    std::vector<std::string> folders_name;
    folders_name.push_back(cats_name);
    folders_name.push_back(dogs_name);

    // Get paths of images and labels as int from the folder paths
    std::pair<std::vector<std::string>, std::vector<int>> pair_images_labels = load_data_from_folder(folders_name);

    std::vector<std::string> list_images = pair_images_labels.first;
    std::vector<int> list_labels = pair_images_labels.second;

    // Initialize CustomDataset class and read data
    auto custom_dataset = CustomDataset(list_images, list_labels).map(torch::data::transforms::Stack<>());

    // Load pre-trained model (the traced ScriptModule produced by scripts/convert.py)
    torch::jit::script::Module module;
    module = torch::jit::load(argv[1]);

    // Resource: https://discuss.pytorch.org/t/how-to-load-the-prebuilt-resnet-models-or-any-other-prebuilt-models/40269/8
    // For VGG: 512 * 14 * 14, 2

    torch::nn::Linear lin(512, 2); // ResNet18's last layer maps 512 features to 1000 classes; ours maps them to 2
    torch::optim::Adam opt(lin->parameters(), torch::optim::AdamOptions(1e-3 /*learning rate*/));

    auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(custom_dataset), 4);

    train(module, lin, data_loader, opt, custom_dataset.size().value());

    return 0;
}
--------------------------------------------------------------------------------

/Blog.md:
--------------------------------------------------------------------------------
## Transfer Learning

Before we discuss the **Why** of Transfer Learning, let's look at **what** Transfer Learning is. Here are the notes from CS231n on Transfer Learning:

> In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.

There are 3 scenarios possible:

1. When the data you have is similar to (but not enough compared with) the data the pre-trained model was trained on: take a pre-trained model trained on the ImageNet dataset (1000 classes), while our data has Dog and Cat classes. Fortunately, ImageNet contains several dog and cat breeds, so the model must have learned important features from such data. Since our data is similar to those breeds, we can simply use the ConvNet (except the last FC layer) to extract features from our dataset and train only the last Linear (FC) layer. We do this with the following code snippet in `Python`:

```python
import torch
from torchvision import models

# Download and load the pre-trained model
model = models.resnet18(pretrained=True)

# Freeze the gradients of the pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Replace the final FC layer and make its parameters trainable
model.fc = torch.nn.Linear(512, 2)
for param in model.fc.parameters():
    param.requires_grad = True
```

2. When you have enough data (and it is similar to the data the pre-trained model was trained on): then you might go for fine-tuning the weights of all the layers in the network, since with enough data we are unlikely to overfit.
3. Using the pre-trained model as-is might just be enough if your data matches the classes in the original data set.

Transfer Learning came into existence (the answer to **Why Transfer Learning?**) for some major reasons, which include:

1. Lack of resources or data to train a CNN. At times, we either don't have enough data or we don't have enough resources to train a CNN from scratch.
2. Random initialization of weights vs initialization of weights from the pre-trained model. Sometimes it's just better to initialize weights from the pre-trained model (it must have learned generic features from its data set) than to initialize them randomly.

## Setting up the data with PyTorch C++ API

At every stage, we will compare the Python and C++ code doing the same thing, to make the analogy easier to follow. Let's start with setting up the data. Note that we do have enough data, and it is also similar to the original ImageNet data set, but since I don't have enough resources to fine-tune the whole network, we perform Transfer Learning on the final FC layer only.

Starting with loading the dataset, as discussed in the previous blogs, here is a flow chart of the procedure:

![Steps to load data with the PyTorch C++ API](images/Steps-Loading-Data-PyTorch.png)
Once done, we can initialize the `CustomDataset` class:

**C++**

```cpp
std::vector<std::string> list_images; // list of paths to images of dogs and cats; use the load_data_from_folder function explained in the previous blogs
std::vector<int> list_labels;         // list of labels of the images
auto custom_dataset = CustomDataset(list_images, list_labels).map(torch::data::transforms::Stack<>());
```

**Python**

```python
import os
import torch
from torchvision import datasets, transforms

folder_path = "/Users/krshrimali/Documents/dataset/train/"
transform = transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor()])
data = datasets.ImageFolder(root=os.path.join(folder_path), transform=transform)
```

We then use `RandomSampler` to make our data loader. (Note: it's important to use `RandomSampler`, since we load the images sequentially and we want a mixture of classes in each batch passed to the network in an epoch.)

**C++**

```cpp
int batch_size = 4;
auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(custom_dataset), batch_size);
```

**Python**

```python
batch_size = 4
data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=True)
```

## Loading the pre-trained model

The steps to load the pre-trained model and perform Transfer Learning are listed below:

1. Download the pre-trained ResNet18 model.
2. Load the pre-trained model.
3. Change the output features of the final FC layer of the loaded model. (The number of classes changes from 1000, for ImageNet, to 2, for Dogs vs Cats.)
4. Define the optimizer on the parameters of the final FC layer only.
5. Train the FC layer on the Dogs vs Cats dataset.
6. Save the model (#TODO).

Let's go step by step.

**Step-1**: Download the pre-trained model of ResNet18

Pre-built C++ models are becoming available in torchvision (https://github.com/pytorch/vision/pull/728), but for this tutorial, transferring the pre-trained model from Python to C++ using `torch.jit` is a good idea: most PyTorch models in the wild are written in Python right now, and this shows how to trace a Python model and transfer it to C++.

First we download the pre-trained model, cut off its final FC layer, and save the traced `torch.jit.ScriptModule` to our local drive.

```python
# Reference: #TODO- Add Link
import torch
from torchvision import models

# Download and load the pre-trained model
model = models.resnet18(pretrained=True)

# Freeze the gradients of the pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Put the network in eval mode so tracing records inference behavior
model.eval()

# Keep the model except the final FC Layer
model = torch.nn.Sequential(*list(model.children())[:-1])

example_input = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, example_input)
script_module.save('resnet18_without_last_layer.pt')
```

We will be using the `resnet18_without_last_layer.pt` model file as our pre-trained model for transfer learning.
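
Before moving on to Step-2, it can help to sanity-check the exported file from the C++ side. A minimal sketch, assuming the `resnet18_without_last_layer.pt` file saved above: the `[1, 512, 1, 1]` feature map it emits is also why the training loop later flattens the output before the new FC layer.

```cpp
torch::jit::script::Module headless = torch::jit::load("resnet18_without_last_layer.pt");
std::vector<torch::jit::IValue> input;
input.push_back(torch::rand({1, 3, 224, 224}));
torch::Tensor features = headless.forward(input).toTensor();
std::cout << features.sizes() << std::endl; // expected: [1, 512, 1, 1]
```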
**Step-2**: Load the pre-trained model

Let's go ahead and load the pre-trained model using the `torch::jit` module. Note that the reason we converted the `torch.nn.Module` to a `torch.jit.ScriptModule` is that the C++ API currently does not support loading Python `torch.nn.Module` models directly.

**C++**:

```cpp
torch::jit::script::Module module;
module = torch::jit::load(argv[1]); // argv[1] should be the path to the model

// Replace the final FC layer: ResNet18's maps 512 features to 1000 classes, ours maps them to our 2 classes
torch::nn::Linear linear_layer(512, 2);

// Define the optimizer on parameters of linear_layer with learning_rate = 1e-3
torch::optim::Adam optimizer(linear_layer->parameters(), torch::optim::AdamOptions(1e-3));
```

**Python**:

```python
# We will directly load the torch.nn pre-trained model
model = models.resnet18(pretrained=True)

for param in model.parameters():
    param.requires_grad = False

model.fc = torch.nn.Linear(512, 2)
for param in model.fc.parameters():
    param.requires_grad = True

optimizer = torch.optim.Adam(model.fc.parameters())
cost = torch.nn.CrossEntropyLoss()
```
## Training the FC Layer

Let's first have a look at the ResNet18 network architecture:

![ResNet18 Architecture](images/ResNet18-Architecture.png)

Reference: https://www.researchgate.net/figure/ResNet-18-Architecture_tbl1_322476121

The final step is to train the Fully Connected layer that we inserted at the end of the network (`linear_layer`). This one should be pretty straightforward; let's see how to do it.

**C++**:

```cpp
template <typename Dataloader>
void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size) {
    /*
     This function trains the network on our data loader using optimizer for given number of epochs.

     Parameters
     ==================
     torch::jit::script::Module net: Pre-trained model
     torch::nn::Linear lin: Linear layer
     Dataloader& data_loader: Training data loader
     torch::optim::Optimizer& optimizer: Optimizer like Adam, SGD etc.
     size_t dataset_size: Size of training dataset
     */

    for(int i=0; i<15; i++) {
        float mse = 0;
        float Acc = 0.0;
        float batch_index = 0; // reset every epoch so the mean loss is per-epoch

        for(auto& batch: *data_loader) {
            auto data = batch.data;
            auto target = batch.target.squeeze();

            // Should be of length: batch_size
            data = data.to(torch::kF32);
            target = target.to(torch::kInt64);

            std::vector<torch::jit::IValue> input;
            input.push_back(data);
            optimizer.zero_grad();

            auto output = net.forward(input).toTensor();
            // For transfer learning: flatten the features before the FC layer
            output = output.view({output.size(0), -1});

            output = lin(output);
            // Explicitly calculate torch::log_softmax of output from the FC Layer
            auto loss = torch::nll_loss(torch::log_softmax(output, 1), target);

            loss.backward();
            optimizer.step();

            auto acc = output.argmax(1).eq(target).sum();

            Acc += acc.template item<float>();
            mse += loss.template item<float>();

            batch_index += 1;
        }

        mse = mse/float(batch_index); // Take mean of loss over the epoch
        std::cout << "Epoch: " << i << ", " << "Accuracy: " << Acc/dataset_size << ", " << "MSE: " << mse << std::endl;
        net.save("model.pt");
    }
}
```

**Python**:

```python
n_epochs = 15

for epoch in range(n_epochs):
    mse = 0.0
    acc = 0
    batch_index = 0

    for data_batch in data_loader:
        batch_index += 1
        image, label = data_batch

        optimizer.zero_grad()

        output = model(image)
        _, predicted_label = torch.max(output.data, 1)

        loss = cost(output, label)

        loss.backward()
        optimizer.step()

        mse += loss.item()
        acc += torch.sum(predicted_label == label.data).item()

    mse = mse / len(data)
    acc = 100 * acc / len(data)

    print("Epoch: {}/{}, Loss: {:.4f}, Accuracy: {:.4f}".format(epoch + 1, n_epochs, mse, acc))
```

The test code looks much the same; the main difference is that no optimizer (and no backward pass) is needed.
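
For reference, a minimal sketch of the test loop (it mirrors `test()` in `main.cpp`; `torch::NoGradGuard` disables gradient tracking since we never call `backward()` here):

```cpp
template <typename Dataloader>
void test(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& loader, size_t data_size) {
    torch::NoGradGuard no_grad; // evaluation only, so no gradients are needed
    float Loss = 0, Acc = 0;

    for (const auto& batch : *loader) {
        auto data = batch.data.to(torch::kF32);
        auto targets = batch.target.view({-1}).to(torch::kInt64);

        std::vector<torch::jit::IValue> input;
        input.push_back(data);
        auto output = lin(net.forward(input).toTensor().view({data.size(0), -1}));

        Loss += torch::nll_loss(torch::log_softmax(output, 1), targets).item<float>();
        Acc += output.argmax(1).eq(targets).sum().item<float>();
    }

    std::cout << "Test Loss: " << Loss / data_size << ", Acc: " << Acc / data_size << std::endl;
}
```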
## Results

![Results using PyTorch C++ API](images/Training-Results.png)

![Results using PyTorch in Python](images/Training-Results-Python.png)

On a set of 400 images for training data, the maximum training accuracy I could achieve was 91.25% in under 15 epochs using the PyTorch C++ API, versus 89.0% using Python. (Note that this doesn't conclude superiority in terms of accuracy between the two frontends, C++ or Python.)
--------------------------------------------------------------------------------

/Transfer-Learning-on-Dogs-vs-Cats.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pragma cling add_library_path(\"/Users/krshrimali/Downloads/libtorch/lib/\")\n",
    "#pragma cling add_include_path(\"/Users/krshrimali/Downloads/libtorch/include/\")\n",
    "#pragma cling add_include_path(\"/Users/krshrimali/Downloads/libtorch/include/torch/csrc/api/include/\")\n",
    "#pragma cling add_library_path(\"/usr/local/Cellar/opencv/4.1.0_2/lib\")\n",
    "#pragma cling add_include_path(\"/usr/local/Cellar/opencv/4.1.0_2/include/opencv4\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libiomp5.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libmklml.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libc10.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libtorch.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libcaffe2_detectron_ops.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libcaffe2_module_test_dynamic.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libcaffe2_observers.dylib\")\n",
    "#pragma cling load(\"/Users/krshrimali/Downloads/libtorch/lib/libshm.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_datasets.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_aruco.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_bgsegm.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_bioinspired.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_calib3d.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_ccalib.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_core.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_dnn_objdetect.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_dnn.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_dpm.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_face.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_features2d.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_flann.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_freetype.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_fuzzy.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_gapi.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_hfs.4.1.0.dylib\")\n",
    "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_highgui.4.1.0.dylib\")\n",
load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_img_hash.4.1.0.dylib\")\n", 41 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_imgcodecs.4.1.0.dylib\")\n", 42 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_imgproc.4.1.0.dylib\")\n", 43 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_line_descriptor.4.1.0.dylib\")\n", 44 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_ml.4.1.0.dylib\")\n", 45 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_objdetect.4.1.0.dylib\")\n", 46 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_optflow.4.1.0.dylib\")\n", 47 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_phase_unwrapping.4.1.0.dylib\")\n", 48 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_photo.4.1.0.dylib\")\n", 49 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_plot.4.1.0.dylib\")\n", 50 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_quality.4.1.0.dylib\")\n", 51 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_reg.4.1.0.dylib\")\n", 52 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_rgbd.4.1.0.dylib\")\n", 53 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_saliency.4.1.0.dylib\")\n", 54 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_sfm.4.1.0.dylib\")\n", 55 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_shape.4.1.0.dylib\")\n", 56 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_stereo.4.1.0.dylib\")\n", 57 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_stitching.4.1.0.dylib\")\n", 58 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_structured_light.4.1.0.dylib\")\n", 59 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_superres.4.1.0.dylib\")\n", 60 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_surface_matching.4.1.0.dylib\")\n", 61 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_text.4.1.0.dylib\")\n", 62 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_tracking.4.1.0.dylib\")\n", 63 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_video.4.1.0.dylib\")\n", 64 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_videoio.4.1.0.dylib\")\n", 65 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_videostab.4.1.0.dylib\")\n", 66 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_xfeatures2d.4.1.0.dylib\")\n", 67 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_ximgproc.4.1.0.dylib\")\n", 68 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_xobjdetect.4.1.0.dylib\")\n", 69 | "#pragma cling load(\"/usr/local/Cellar/opencv/4.1.0_2/lib/libopencv_xphoto.4.1.0.dylib\")" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "#include \n", 79 | "#include \n", 80 | "#include \n", 81 | "#include \n", 82 | "#include " 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## Transfer Learning\n", 90 | "\n", 91 | "Before we go ahead and discuss the **Why** question of Transfer Learning, let's have a look at **What is Transfer Learning?** Let's have a look at the 
    "\n",
    "> In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.\n",
    "\n",
    "There are 3 scenarios possible:\n",
    "\n",
    "1. When the data you have is similar to (but not enough compared with) the data the pre-trained model was trained on: take a pre-trained model trained on the ImageNet dataset (1000 classes), while our data has Dog and Cat classes. Fortunately, ImageNet contains several dog and cat breeds, so the model must have learned important features from such data. Since our data is similar to those breeds, we can simply use the ConvNet (except the last FC layer) to extract features from our dataset and train only the last Linear (FC) layer. We do this with the following code snippet in `Python`:\n",
    "\n",
    "```python\n",
    "from torchvision import models\n",
    "# Download and load the pre-trained model\n",
    "model = models.resnet18(pretrained=True)\n",
    "\n",
    "# Freeze the gradients of the pre-trained layers\n",
    "for param in model.parameters():\n",
    "\tparam.requires_grad = False\n",
    "\n",
    "# Replace the final FC layer and make its parameters trainable\n",
    "model.fc = torch.nn.Linear(512, 2)\n",
    "for param in model.fc.parameters():\n",
    "\tparam.requires_grad = True\n",
    "```\n",
    "\n",
    "2. When you have enough data (and it is similar to the data the pre-trained model was trained on): then you might go for fine-tuning the weights of all the layers in the network, since with enough data we are unlikely to overfit.\n",
    "3. Using the pre-trained model as-is might just be enough if your data matches the classes in the original data set.\n",
    "\n",
    "Transfer Learning came into existence (the answer to **Why Transfer Learning?**) for some major reasons, which include:\n",
    "\n",
    "1. Lack of resources or data to train a CNN. At times, we either don't have enough data or we don't have enough resources to train a CNN from scratch.\n",
    "2. Random initialization of weights vs initialization of weights from the pre-trained model. Sometimes it's just better to initialize weights from the pre-trained model (it must have learned generic features from its data set) than to initialize them randomly.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the data with PyTorch C++ API\n",
    "\n",
    "At every stage, we will compare the Python and C++ code doing the same thing, to make the analogy easier to follow. Let's start with setting up the data. Note that we do have enough data, and it is also similar to the original ImageNet data set, but since I don't have enough resources to fine-tune the whole network, we perform Transfer Learning on the final FC layer only.\n",
    "\n",
    "Starting with loading the dataset, as discussed in the previous blogs, here is a flow chart of the procedure:\n",
    "\n",
    "![Steps to load data with the PyTorch C++ API](images/Steps-Loading-Data-PyTorch.png)\n",
    "\n",
    "Let's go ahead and define the required utility functions to define the Custom Dataset class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility Function - 1: `read_data` and `read_label`\n",
    "\n",
    "**Documentation** of `read_data` function.\n",
    "\n",
    "```\n",
    "torch::Tensor read_data(std::string location)\n",
    "\n",
    "Function to return image read at location given as type torch::Tensor\n",
    "    Resizes image to (224, 224, 3)\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. location (std::string type) - required to load image from the location\n",
    "    \n",
    "    Returns\n",
    "    ===========\n",
    "    torch::Tensor type - image read as tensor\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch::Tensor read_data(std::string location) {\n",
    "    /*\n",
    "     Function to return image read at location given as type torch::Tensor\n",
    "     Resizes image to (224, 224, 3)\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. location (std::string type) - required to load image from the location\n",
    "     \n",
    "     Returns\n",
    "     ===========\n",
    "     torch::Tensor type - image read as tensor\n",
    "     */\n",
    "    cv::Mat img = cv::imread(location, 1);\n",
    "    cv::resize(img, img, cv::Size(224, 224), 0, 0, cv::INTER_CUBIC);\n",
    "    torch::Tensor img_tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);\n",
    "    img_tensor = img_tensor.permute({2, 0, 1}); // HWC -> CHW\n",
    "    return img_tensor.clone();\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Documentation** of `read_label` function.\n",
    "\n",
    "```\n",
    "torch::Tensor read_label(int label)\n",
    "\n",
    "Function to return label from int (0, 1 for binary and 0, 1, ..., n-1 for n-class classification) as type torch::Tensor\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. label (int type) - required to convert int to tensor\n",
    "    \n",
    "    Returns\n",
    "    ===========\n",
    "    torch::Tensor type - label read as tensor\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch::Tensor read_label(int label) {\n",
    "    /*\n",
    "     Function to return label from int (0, 1 for binary and 0, 1, ..., n-1 for n-class classification) as type torch::Tensor\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. label (int type) - required to convert int to tensor\n",
    "     \n",
    "     Returns\n",
    "     ===========\n",
    "     torch::Tensor type - label read as tensor\n",
    "     */\n",
    "    torch::Tensor label_tensor = torch::full({1}, label);\n",
    "    return label_tensor.clone();\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility Function 2: `process_images` and `process_labels`\n",
    "\n",
    "**Documentation** of `process_images` function.\n",
    "\n",
    "```\n",
    "Function returns vector of tensors (images) read from the list of images in a folder\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. list_images (std::vector<std::string> type) - list of image paths in a folder to be read\n",
    "    \n",
    "    Returns\n",
    "    ===========\n",
    "    std::vector<torch::Tensor> type - Images read as tensors\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "std::vector<torch::Tensor> process_images(std::vector<std::string> list_images) {\n",
    "    /*\n",
    "     Function returns vector of tensors (images) read from the list of images in a folder\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. list_images (std::vector<std::string> type) - list of image paths in a folder to be read\n",
    "     \n",
    "     Returns\n",
    "     ===========\n",
    "     std::vector<torch::Tensor> type - Images read as tensors\n",
    "     */\n",
    "    std::vector<torch::Tensor> states;\n",
    "    for(std::vector<std::string>::iterator it = list_images.begin(); it != list_images.end(); ++it) {\n",
    "        torch::Tensor img = read_data(*it);\n",
    "        states.push_back(img);\n",
    "    }\n",
    "    return states;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Documentation** of `process_labels` function.\n",
    "\n",
    "```\n",
    "Function returns vector of tensors (labels) read from the list of labels\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. list_labels (std::vector<int> type) - list of integer labels to be converted\n",
    "    \n",
    "    Returns\n",
    "    ===========\n",
    "    std::vector<torch::Tensor> type - returns vector of tensors (labels)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "std::vector<torch::Tensor> process_labels(std::vector<int> list_labels) {\n",
    "    /*\n",
    "     Function returns vector of tensors (labels) read from the list of labels\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. list_labels (std::vector<int> type) - required to convert int to tensor labels\n",
    "     \n",
    "     Returns\n",
    "     ===========\n",
    "     std::vector<torch::Tensor> type - returns vector of tensors (labels)\n",
    "     */\n",
    "    std::vector<torch::Tensor> labels;\n",
    "    for(std::vector<int>::iterator it = list_labels.begin(); it != list_labels.end(); ++it) {\n",
    "        torch::Tensor label = read_label(*it);\n",
    "        labels.push_back(label);\n",
    "    }\n",
    "    return labels;\n",
    "}\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility Function 3: `load_images` and `load_labels`:\n",
    "\n",
    "**Documentation** of `load_images` function.\n",
    "\n",
    "```\n",
    "Function returns vector of strings (image paths) read from the folder name given\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. folder_name (std::string type) - name of folder containing images\n",
    "    \n",
    "    Returns\n",
    "    ===========\n",
    "    std::vector<std::string> type - returns vector of image paths\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "std::vector<std::string> load_images(std::string folder_name) {\n",
    "    /*\n",
    "     Function returns vector of strings (image paths) read from the folder name given\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. folder_name (std::string type) - name of folder containing images\n",
    "     \n",
    "     Returns\n",
    "     ===========\n",
    "     std::vector<std::string> type - returns vector of image paths\n",
    "     */\n",
    "    std::vector<std::string> list_images;\n",
    "    \n",
    "    std::string base_name = folder_name;\n",
    "    \n",
    "    DIR* dir;\n",
    "    struct dirent *ent;\n",
    "    \n",
    "    if((dir = opendir(base_name.c_str())) != NULL) {\n",
    "        while((ent = readdir(dir)) != NULL) {\n",
    "            std::string filename = ent->d_name;\n",
    "            if(filename.length() > 4 && filename.substr(filename.length() - 3) == \"jpg\") {\n",
    "                std::string newf = base_name + filename;\n",
    "                list_images.push_back(newf);\n",
    "            }\n",
    "        }\n",
    "        closedir(dir); // release the directory handle\n",
    "    }\n",
    "    \n",
    "    return list_images;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Documentation** of `load_labels` function.\n",
    "\n",
    "```\n",
    "Function returns vector of int (labels) to each image in the folder (folder_name)\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. folder_name (std::string type) - name of folder containing images\n",
    "    2. label (int type) - label of the class (0 or 1 in case of binary, 0 1 ... n-1 in case of n-class classification)\n",
    "    Returns\n",
    "    ===========\n",
    "    std::vector<int> type - returns vector of labels assigned to each image of each class\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "std::vector<int> load_labels(std::string folder_name, int label) {\n",
    "    /*\n",
    "     Function returns vector of int (labels) to each image in the folder (folder_name)\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. folder_name (std::string type) - name of folder containing images\n",
    "     2. label (int type) - label of the class (0 or 1 in case of binary, 0 1 ... n-1 in case of n-class classification)\n",
    "     Returns\n",
    "     ===========\n",
    "     std::vector<int> type - returns vector of labels assigned to each image of each class\n",
    "     */\n",
    "    std::vector<int> list_labels;\n",
    "    DIR* dir;\n",
    "    \n",
    "    std::string base_name = folder_name;\n",
    "    \n",
    "    struct dirent *ent;\n",
    "    \n",
    "    if((dir = opendir(base_name.c_str())) != NULL) {\n",
    "        while((ent = readdir(dir)) != NULL) {\n",
    "            std::string filename = ent->d_name;\n",
    "            if(filename.length() > 4 && filename.substr(filename.length() - 3) == \"jpg\") {\n",
    "                list_labels.push_back(label);\n",
    "            }\n",
    "        }\n",
    "        closedir(dir); // release the directory handle\n",
    "    }\n",
    "    \n",
    "    return list_labels;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are done with all the utility functions, we can go ahead and define the `CustomDataset` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset : public torch::data::Dataset<CustomDataset> {\n",
    "private:\n",
    "    /* data */\n",
    "    // Two parallel vectors of tensors: one per image, one per label\n",
    "    std::vector<torch::Tensor> states, labels;\n",
    "    size_t ds_size;\n",
    "public:\n",
    "    CustomDataset(std::vector<std::string> list_images, std::vector<int> list_labels) {\n",
    "        states = process_images(list_images);\n",
    "        labels = process_labels(list_labels);\n",
    "        ds_size = states.size();\n",
    "    };\n",
    "    \n",
    "    torch::data::Example<> get(size_t index) override {\n",
    "        /* This should return {torch::Tensor, torch::Tensor} */\n",
    "        torch::Tensor sample_img = states.at(index);\n",
    "        torch::Tensor sample_label = labels.at(index);\n",
    "        return {sample_img.clone(), sample_label.clone()};\n",
    "    };\n",
    "    \n",
    "    torch::optional<size_t> size() const override {\n",
    "        return ds_size;\n",
    "    };\n",
    "};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the FC Layer\n",
    "\n",
    "Let's first have a look at the ResNet18 network architecture:\n",
    "\n",
    "![ResNet18 Architecture](images/ResNet18-Architecture.png)\n",
    "\n",
    "Reference: https://www.researchgate.net/figure/ResNet-18-Architecture_tbl1_322476121\n",
    "\n",
    "The next step is to train the Fully Connected layer that we inserted at the end of the network (`linear_layer`). This one should be pretty straightforward; let's see how to do it.\n",
    "\n",
    "**Documentation** of `train()` function.\n",
    "\n",
    "```\n",
    "This function trains the network on our data loader using optimizer.\n",
    "    \n",
    "    Also saves the model as model.pt after every epoch.\n",
    "    Parameters\n",
    "    ===========\n",
    "    1. net (torch::jit::script::Module type) - Pre-trained model without last FC layer\n",
    "    2. lin (torch::nn::Linear type) - last FC layer with revised out_features depending on the no. of classes\n",
    "    3. data_loader (Dataloader& type) - Training data loader\n",
    "    4. optimizer (torch::optim::Optimizer& type) - Optimizer like Adam, SGD etc.\n",
    "    5. dataset_size (size_t type) - Size of training dataset\n",
    "    \n",
    "    Returns\n",
    "    ===========\n",
    "    Nothing (void)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "template<typename Dataloader>\n",
    "void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size) {\n",
    "    /*\n",
    "     This function trains the network on our data loader using optimizer.\n",
    "     \n",
    "     Also saves the model as model.pt after every epoch.\n",
    "     Parameters\n",
    "     ===========\n",
    "     1. net (torch::jit::script::Module type) - Pre-trained model without last FC layer\n",
    "     2. lin (torch::nn::Linear type) - last FC layer with revised out_features depending on the number of classes\n",
    "     3. data_loader (Dataloader& type) - Training data loader\n",
    "     4. optimizer (torch::optim::Optimizer& type) - Optimizer like Adam, SGD etc.\n",
    "     5. dataset_size (size_t type) - Size of training dataset\n",
    "     \n",
    "     Returns\n",
    "     ===========\n",
    "     Nothing (void)\n",
    "     */\n",
    "    \n",
    "    for(int i=0; i<15; i++) {\n",
    "        float mse = 0;\n",
    "        float Acc = 0.0;\n",
    "        float batch_index = 0; // reset every epoch so the mean loss is per-epoch\n",
    "        \n",
    "        for(auto& batch: *data_loader) {\n",
    "            auto data = batch.data;\n",
    "            auto target = batch.target.squeeze();\n",
    "            \n",
    "            // Should be of length: batch_size\n",
    "            data = data.to(torch::kF32);\n",
    "            target = target.to(torch::kInt64);\n",
    "            \n",
    "            std::vector<torch::jit::IValue> input;\n",
    "            input.push_back(data);\n",
    "            optimizer.zero_grad();\n",
    "            \n",
    "            auto output = net.forward(input).toTensor();\n",
    "            // For transfer learning: flatten the features before the FC layer\n",
    "            output = output.view({output.size(0), -1});\n",
    "            output = lin(output);\n",
    "            \n",
    "            auto loss = torch::nll_loss(torch::log_softmax(output, 1), target);\n",
    "            \n",
    "            loss.backward();\n",
    "            optimizer.step();\n",
    "            \n",
    "            auto acc = output.argmax(1).eq(target).sum();\n",
    "            \n",
    "            Acc += acc.template item<float>();\n",
    "            mse += loss.template item<float>();\n",
    "            \n",
    "            batch_index += 1;\n",
    "        }\n",
    "        \n",
    "        mse = mse/float(batch_index); // Take mean of loss over the epoch\n",
    "        std::cout << \"Epoch: \" << i << \", \" << \"Accuracy: \" << Acc/dataset_size << \", \" << \"MSE: \" << mse << std::endl;\n",
    "        net.save(\"model.pt\");\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the Dataset for training\n",
    "\n",
    "Let's go ahead and load our dataset into the `DataLoader` class. \n",
    "\n",
    "The distribution of the dataset is: `Cat Images`: 200 and `Dog Images`: 200."
   ]
  },
572 | {
573 | "cell_type": "markdown",
574 | "metadata": {},
575 | "source": [
576 | "## Setting up the Dataset for training\n",
577 | "\n",
578 | "Let's go ahead and wrap our dataset in a `DataLoader`. (A sketch of the loader options you can tune follows the loader cell below.)\n",
579 | "\n",
580 | "The dataset contains 200 cat images and 200 dog images."
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": 11,
586 | "metadata": {},
587 | "outputs": [],
588 | "source": [
589 | "// Set folder names for cat and dog images\n",
590 | "std::string name_cats = \"/Users/krshrimali/Documents/krshrimali-blogs/dataset/train/cat_test/\";\n",
591 | "std::string name_dogs = \"/Users/krshrimali/Documents/krshrimali-blogs/dataset/train/dog_test/\";\n",
592 | "\n",
593 | "std::vector<std::string> images_cats = load_images(name_cats);\n",
594 | "std::vector<int> labels_cats = load_labels(name_cats, 0);\n",
595 | "\n",
596 | "std::vector<std::string> images_dogs = load_images(name_dogs);\n",
597 | "std::vector<int> labels_dogs = load_labels(name_dogs, 1);"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "execution_count": 12,
603 | "metadata": {},
604 | "outputs": [],
605 | "source": [
606 | "std::vector<std::string> images_total;\n",
607 | "\n",
608 | "for(auto const& value: images_cats) {\n",
609 | "    images_total.push_back(value);\n",
610 | "}\n",
611 | "\n",
612 | "for(auto const& value: images_dogs) {\n",
613 | "    images_total.push_back(value);\n",
614 | "}"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "execution_count": 13,
620 | "metadata": {},
621 | "outputs": [],
622 | "source": [
623 | "std::vector<int> labels_total;\n",
624 | "\n",
625 | "for(auto const& value: labels_cats) {\n",
626 | "    labels_total.push_back(value);\n",
627 | "}\n",
628 | "\n",
629 | "for(auto const& value: labels_dogs) {\n",
630 | "    labels_total.push_back(value);\n",
631 | "}"
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": 14,
637 | "metadata": {},
638 | "outputs": [],
639 | "source": [
640 | "auto custom_dataset = CustomDataset(images_total, labels_total).map(torch::data::transforms::Stack<>());\n",
641 | "auto data_loader = torch::data::make_data_loader(std::move(custom_dataset), 4); // batch size: 4"
642 | ]
643 | },
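 | {
 | "cell_type": "markdown",
 | "metadata": {},
 | "source": [
 | "Passing the integer `4` above implicitly constructs a `torch::data::DataLoaderOptions` with that batch size. Here is a minimal sketch of the same call spelled out explicitly, with worker threads added; this is an illustration of the options API, not something the original run used:\n",
 | "\n",
 | "```cpp\n",
 | "// Equivalent to the call above, but explicit: batch size 4, two worker threads.\n",
 | "// By default make_data_loader uses a RandomSampler, so samples are shuffled;\n",
 | "// instantiate make_data_loader<torch::data::samplers::SequentialSampler>(...)\n",
 | "// instead if you want deterministic order.\n",
 | "auto options = torch::data::DataLoaderOptions().batch_size(4).workers(2);\n",
 | "auto data_loader = torch::data::make_data_loader(std::move(custom_dataset), options);\n",
 | "```"
 | ]
 | },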
\n", 666 | "\n", 667 | "```python\n", 668 | "# Reference: #TODO- Add Link\n", 669 | "from torchvision import models\n", 670 | "# Download and load the pre-trained model\n", 671 | "model = models.resnet18(pretrained=True)\n", 672 | "\n", 673 | "# Set upgrading the gradients to False\n", 674 | "for param in model.parameters():\n", 675 | "\tparam.requires_grad = False\n", 676 | "\n", 677 | "# Save the model except the final FC Layer\n", 678 | "resnet18 = torch.nn.Sequential(*list(resnet18.children())[:-1])\n", 679 | "\n", 680 | "example_input = torch.rand(1, 3, 224, 224)\n", 681 | "script_module = torch.jit.trace(resnet18, example_input)\n", 682 | "script_module.save('resnet18_without_last_layer.pt')\n", 683 | "```\n", 684 | "\n", 685 | "We will be using `resnet18_without_last_layer.pt` model file as our pre-trained model for transfer learning. \n", 686 | "\n", 687 | "**Step-2**: Load the pre-trained model\n", 688 | "\n", 689 | "Let's go ahead and load the pre-trained model using `torch::jit` module. Note that the reason we have converted `torch.nn.Module` to `torch.jit.ScriptModule` type, is because C++ API currently does not support loading Python `torch.nn.Module` models directly." 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 15, 695 | "metadata": {}, 696 | "outputs": [], 697 | "source": [ 698 | "torch::jit::script::Module module;\n", 699 | "module = torch::jit::load(\"/Users/krshrimali/Documents/krshrimali-blogs/codes/transfer-learning/transfer-learning/build/resnet18_without_lastlayer.pt\");" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "metadata": {}, 705 | "source": [ 706 | "## Experimentation\n", 707 | "\n", 708 | "Since we are almost done with defining required functions, let's go ahead and define the optimizer on our last FC layer and train the FC layer on our dataset." 
709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 16, 714 | "metadata": {}, 715 | "outputs": [], 716 | "source": [ 717 | "torch::nn::Linear lin(512, 2); // the last layer of resnet, which we want to replace, has dimensions 512x1000\n", 718 | "torch::optim::Adam opt(lin->parameters(), torch::optim::AdamOptions(1e-3 /*learning rate*/));" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": 17, 724 | "metadata": { 725 | "scrolled": true 726 | }, 727 | "outputs": [ 728 | { 729 | "name": "stdout", 730 | "output_type": "stream", 731 | "text": [ 732 | "Epoch: 0, Accuracy: 0.745, MSE: 0.491316\n", 733 | "Epoch: 1, Accuracy: 0.8825, MSE: 0.151102\n", 734 | "Epoch: 2, Accuracy: 0.845, MSE: 0.111356\n", 735 | "Epoch: 3, Accuracy: 0.8025, MSE: 0.106991\n", 736 | "Epoch: 4, Accuracy: 0.885, MSE: 0.0578496\n", 737 | "Epoch: 5, Accuracy: 0.865, MSE: 0.0572935\n", 738 | "Epoch: 6, Accuracy: 0.9, MSE: 0.0337119\n", 739 | "Epoch: 7, Accuracy: 0.855, MSE: 0.0399212\n", 740 | "Epoch: 8, Accuracy: 0.865, MSE: 0.0347004\n", 741 | "Epoch: 9, Accuracy: 0.8425, MSE: 0.0341781\n", 742 | "Epoch: 10, Accuracy: 0.8825, MSE: 0.0239106\n", 743 | "Epoch: 11, Accuracy: 0.86, MSE: 0.026371\n", 744 | "Epoch: 12, Accuracy: 0.86, MSE: 0.0269321\n", 745 | "Epoch: 13, Accuracy: 0.88, MSE: 0.0224421\n", 746 | "Epoch: 14, Accuracy: 0.8875, MSE: 0.0172629\n" 747 | ] 748 | } 749 | ], 750 | "source": [ 751 | "train(module, lin, data_loader, opt, custom_dataset.size().value());" 752 | ] 753 | } 754 | ], 755 | "metadata": { 756 | "kernelspec": { 757 | "display_name": "C++17", 758 | "language": "C++17", 759 | "name": "xcpp17" 760 | }, 761 | "language_info": { 762 | "codemirror_mode": "text/x-c++src", 763 | "file_extension": ".cpp", 764 | "mimetype": "text/x-c++src", 765 | "name": "c++", 766 | "version": "-std=c++17" 767 | } 768 | }, 769 | "nbformat": 4, 770 | "nbformat_minor": 2 771 | } 772 | --------------------------------------------------------------------------------