├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── dataset.py
├── images
│   └── multi_view_car_plug.png
├── models
│   ├── kmeans.pkl
│   ├── logistic_regression.pkl
│   ├── mvcnn_stage_1.pkl
│   └── mvcnn_stage_2.pkl
├── network.py
├── notebooks
│   ├── Car Plugs - Multi-View Image Classification.html
│   └── Car Plugs - Multi-View Image Classification.ipynb
├── resources
│   ├── scaler_mean.npy
│   ├── scaler_std.npy
│   └── vocabulary_surf.npy
└── trainer.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
mvcnn_*.pkl filter=lfs diff=lfs merge=lfs -text

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
notebooks/.ipynb_checkpoints/
data/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Samy Tafasca

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multi-View Image Classification

## Introduction

This project solves a multi-view image classification problem using two different approaches:

1. White-box feature extractors (e.g. SIFT) and clustering for image quantization, combined with a classical machine learning algorithm (e.g. SVM) for prediction.
2. A neural network architecture (MVCNN) that inherently handles the multi-view aspect by taking multiple images at once as input and pooling their feature maps before classification.

There is also a [Medium article](https://towardsdatascience.com/multi-view-image-classification-427c69720f30) that goes into detail about the problem and these two approaches.

This project is implemented in **Python**. It uses **Scikit-Learn** and **PyTorch** for model building and training, and **OpenCV** and **Pillow** for image processing.
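The first approach is a classical bag-of-visual-words pipeline: extract local descriptors from every image, cluster them into a visual vocabulary, represent each image as a histogram over that vocabulary, and classify the histograms. The sketch below illustrates the idea; it is not the exact notebook code (the notebook uses SURF, while this sketch uses SIFT, which ships with recent OpenCV builds as `cv2.SIFT_create`), and `image_paths`, `labels` and the vocabulary size of 100 are illustrative assumptions:

```python
import cv2
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression

def extract_descriptors(image_paths):
    """Collect local SIFT descriptors for each image."""
    sift = cv2.SIFT_create()
    descriptors = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        _, desc = sift.detectAndCompute(image, None)
        # Images with no keypoints get an empty descriptor array.
        descriptors.append(desc if desc is not None else np.empty((0, 128), np.float32))
    return descriptors

# `image_paths` and `labels` are assumed to be prepared by the caller
# (one label per image here; the notebook works per plug, across its 8 views).
descriptors = extract_descriptors(image_paths)

# Build a visual vocabulary by clustering all descriptors.
kmeans = KMeans(n_clusters=100, random_state=0).fit(np.vstack(descriptors))

# Quantize each image into a histogram over the 100 visual words.
histograms = np.array([
    np.bincount(kmeans.predict(d), minlength=100) if len(d) else np.zeros(100)
    for d in descriptors
])

# Classify the histograms with a classical model.
clf = LogisticRegression(max_iter=1000).fit(histograms, labels)
```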
## Project Structure

`models`: contains the trained models. Specifically, the k-means and logistic regression models for the first approach, and two MVCNNs for the second (one from feature extraction, one from fine-tuning).

`notebooks`: contains the Jupyter notebook of this project as well as its HTML export for easy reading.

`resources`: contains extra components required for testing the model (e.g. normalization constants and the vocabulary features).

`dataset.py`: the PyTorch custom dataset class that reads all the view images of each data sample and returns a tensor of shape *(Views, Channels, Height, Width)*.

`network.py`: the Multi-View Convolutional Neural Network (MVCNN) class that takes inputs of shape *(Samples, Views, Channels, Height, Width)* and returns the logits of the 6 classes.

`trainer.py`: a utility function to train the model while keeping track of some metrics of interest.

## Data

The data used in this problem consists of images of 833 car plugs. Each car plug is an item with 8 images: 6 correspond to orthographic projections, while the other 2 are random isometric projections of the car plug. The colors found in the images are meaningless. Finally, each car plug has a unique codename and maps to a single label among a set of 6 predefined classes. This information is contained in a `train.csv` file. It is worth noting that the classes are imbalanced.

Below is an example of the 8 raw images of a car plug.

![Multi-view Car Plug](/images/multi_view_car_plug.png)

Unfortunately, I cannot include the data as it was privately provided by a company during a competition. However, to follow along with certain parts of the notebook, you may need to know the structure of the data folder, as detailed below.

    Multi-View-Image-Classification/
    ├── data/
    │   ├── raw/
    │   │   ├── codename1_x1.png
    │   │   ├── codename1_x2.png
    │   │   ├── ...
    │   │   ├── codename1_x8.png
    │   │   ├── ...
    │   │   ├── codename833_x8.png
    │   │   └── train.csv
    ├── models
    ├── notebooks
    ├── resources

## Results

The performance of the models is measured by weighted accuracy to account for the class imbalance.

Approach 1 (SURF + LR) | Approach 2 (MVCNN)
---------------------- | ------------------
92%                    | 98%
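For reference, the pieces of this repository fit together roughly as follows. This is a minimal sketch rather than the exact notebook code: it assumes the `data/` layout above and that every plug has exactly 8 views (so the default collate can stack them), and the 80/20 split and hyperparameters are illustrative. Note that `trainer.train_model` reads `sampler.indices`, so the loaders are built with `SubsetRandomSampler`:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

from dataset import CarPlugDataset
from network import MVCNN
from trainer import train_model, device

# Assumes the data/ layout described above.
dataset = CarPlugDataset(root='data/raw', plug_label_filename='data/raw/train.csv')

# Illustrative 80/20 split; train_model expects SubsetRandomSampler-based loaders.
indices = torch.randperm(len(dataset)).tolist()
split = int(0.8 * len(dataset))
train_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(indices[:split]))
val_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(indices[split:]))

model = MVCNN(num_classes=6, pretrained=True).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model, history = train_model(model, {'train': train_loader, 'val': val_loader},
                             criterion, optimizer, num_epochs=10)
```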
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import csv
import glob
import os

from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class CarPlugDataset(Dataset):
    """
    Custom Car Plug Dataset.
    Returns tensor data items of shape (Views, Channels, Height, Width) with
    the corresponding encoded label.
    """

    def __init__(self, root, plug_label_filename):
        self.root = root
        self.plug_label_map = self._get_plug_label_map(plug_label_filename)
        self.plug_names = self._get_plug_names(root)
        self.label_encoder = {'CPA Stecker': 0,
                              'Schraubstecker': 1,
                              'Bügelstecker': 2,
                              'Schieberstecker': 3,
                              'Kraftschlüssig mit Schraubsicherung': 4,
                              'Bajonette': 5}

    def _get_plug_names(self, root):
        # Image files are named <codename>_<view>.png; deduplicate the codenames.
        plug_names = [fname.split('_')[0] for fname in os.listdir(root) if fname.endswith('.png')]
        plug_names = list(set(plug_names))
        return plug_names

    def _get_plug_label_map(self, filename):
        reader = csv.DictReader(open(filename))
        plug_label_map = {}
        for row in reader:
            plug_label_map[row['part_no']] = row['label']
        return plug_label_map

    def __len__(self):
        return len(self.plug_names)

    def _transform(self, image):
        transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        return transform(image)

    def __getitem__(self, index):
        plug_name = self.plug_names[index]
        # Get all view images of the plug and stack them into a single tensor.
        plug_fnames = glob.glob(self.root + f'/{plug_name}_*.png')
        plug = torch.stack([self._transform(Image.open(fname).convert('RGB')) for fname in plug_fnames])
        label = self.label_encoder[self.plug_label_map[plug_name]]
        return plug, label
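
# Example usage (added note; the paths are hypothetical):
#   ds = CarPlugDataset(root='data/raw', plug_label_filename='data/raw/train.csv')
#   views, label = ds[0]
#   views.shape  # torch.Size([8, 3, 224, 224]) -> 8 views, RGB, resized to 224x224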
--------------------------------------------------------------------------------
/images/multi_view_car_plug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAMY-ER/Multi-View-Image-Classification/05ea80e74d00e35b7cfb9d5d8102607810432257/images/multi_view_car_plug.png

--------------------------------------------------------------------------------
/models/kmeans.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAMY-ER/Multi-View-Image-Classification/05ea80e74d00e35b7cfb9d5d8102607810432257/models/kmeans.pkl

--------------------------------------------------------------------------------
/models/logistic_regression.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAMY-ER/Multi-View-Image-Classification/05ea80e74d00e35b7cfb9d5d8102607810432257/models/logistic_regression.pkl

--------------------------------------------------------------------------------
/models/mvcnn_stage_1.pkl:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:f0d36b4ceb18b9e9877a66fda28218d7edc5a34b2087f392767b6c4c6d41875e
size 106282816

--------------------------------------------------------------------------------
/models/mvcnn_stage_2.pkl:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:a58c09f77410c92ba0dfffb68ed208f6fe5f3d225f05c07031dd340b54a18b68
size 106282816

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models


class MVCNN(nn.Module):
    """
    Multi-View Convolutional Neural Network (MVCNN).
    Initializes a model with the architecture of an MVCNN with a ResNet34 base.
    Takes inputs of shape (Samples, Views, Channels, Height, Width) and returns
    logits of shape (Samples, num_classes).
    """
    def __init__(self, num_classes=6, pretrained=True):
        super(MVCNN, self).__init__()
        resnet = models.resnet34(pretrained=pretrained)
        fc_in_features = resnet.fc.in_features
        # Keep everything up to and including the global average pool; drop the fc layer.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(fc_in_features, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )

    def forward(self, inputs):  # inputs.shape = samples x views x channels x height x width
        inputs = inputs.transpose(0, 1)
        view_features = []
        for view_batch in inputs:
            view_batch = self.features(view_batch)
            view_batch = view_batch.view(view_batch.shape[0], view_batch.shape[1:].numel())
            view_features.append(view_batch)

        # Element-wise max pooling across views merges the per-view features.
        pooled_views, _ = torch.max(torch.stack(view_features), 0)
        outputs = self.classifier(pooled_views)
        return outputs
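
# Shape walkthrough (added note; assumes the ResNet34 base, whose final
# feature size is 512, and 224x224 RGB inputs):
#   inputs: (N, V, 3, 224, 224) -> transpose -> (V, N, 3, 224, 224)
#   per view: self.features -> (N, 512, 1, 1) -> flatten -> (N, 512)
#   stack views: (V, N, 512) -> max over dim 0 -> (N, 512)
#   classifier: (N, 512) -> (N, num_classes)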
--------------------------------------------------------------------------------
/resources/scaler_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAMY-ER/Multi-View-Image-Classification/05ea80e74d00e35b7cfb9d5d8102607810432257/resources/scaler_mean.npy

--------------------------------------------------------------------------------
/resources/scaler_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAMY-ER/Multi-View-Image-Classification/05ea80e74d00e35b7cfb9d5d8102607810432257/resources/scaler_std.npy

--------------------------------------------------------------------------------
/resources/vocabulary_surf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAMY-ER/Multi-View-Image-Classification/05ea80e74d00e35b7cfb9d5d8102607810432257/resources/vocabulary_surf.npy

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import copy
import time

import torch
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_sample_weight

# `device` and `class_weights` were globals in the original notebook; they are
# defined here so the module is self-contained. 'balanced' weights each class
# inversely to its frequency.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_weights = 'balanced'


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    """
    Train a model for a number of epochs.

    Args:
        model: The neural network model to train.
        dataloaders: A dictionary of data loaders of the shape
            {'train': train_loader, 'val': val_loader}. The loaders are expected
            to use a SubsetRandomSampler (epoch metrics divide by
            len(sampler.indices)).
        criterion: The loss function. Takes model outputs and labels as input
            and produces a loss value.
        optimizer: The optimizer object used to train the model.
        num_epochs: The number of epochs to train the model for.

    Returns:
        (model, val_acc_history): The trained model (with the weights of the
        best validation epoch) and the history of validation weighted accuracy.
    """
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # Each epoch has a training and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0
            all_preds = []
            all_labels = []
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # Get model predictions
                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                all_preds.append(preds)
                all_labels.append(labels)

            epoch_loss = running_loss / len(dataloaders[phase].sampler.indices)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].sampler.indices)
            all_labels = torch.cat(all_labels, 0)
            all_preds = torch.cat(all_preds, 0)
            # Weighted accuracy: each sample is weighted by its class weight to
            # compensate for the class imbalance.
            epoch_weighted_acc = accuracy_score(
                all_labels.cpu().numpy(),
                all_preds.cpu().numpy(),
                sample_weight=compute_sample_weight(class_weights, all_labels.cpu().numpy()))

            print('{} Loss: {:.4f} - Acc: {:.4f} - Weighted Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_weighted_acc))

            # deep copy the model if it is the best so far
            if phase == 'val' and epoch_weighted_acc > best_acc:
                best_acc = epoch_weighted_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_weighted_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val weighted Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
--------------------------------------------------------------------------------