├── README.md
├── environment.yml
├── eval.py
├── eval_tensorrt.py
├── infer_tensorrt.py
├── libs
│   └── dataset.py
├── onnx_to_trt.py
├── resources
│   └── player_similarity.gif
├── siamese
│   ├── __init__.py
│   ├── siamese_network.py
│   └── siamese_network_trt.py
├── torch_to_onnx.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # Siamese Network
2 | 
3 | A simple but pragmatic implementation of Siamese Networks in PyTorch, using the pre-trained feature extraction networks provided in `torchvision.models`.
4 | 
5 | ![Results](resources/player_similarity.gif)
6 | 
7 | ## Design Choices:
8 | - The siamese network provided in this repository uses a sigmoid at its output, making it a binary classification task (positive=same, negative=different) trained with binary cross-entropy loss, as opposed to the more commonly used triplet loss.
9 | - I have added dropout to the final classification head along with BatchNorm. Online discussions suggest that dropout combined with BatchNorm is ineffective; however, I found it improved the results on my specific private dataset.
10 | - Instead of concatenating the feature vectors of the two images, I opted to multiply them element-wise, which increased the validation accuracy for my specific dataset.
11 | 
12 | ## Setting up the environment.
13 | The provided setup instructions assume that Anaconda is already installed on the system. To set up the environment for this repository, run the following commands to create and activate an environment named 'pytorch_siamese' (the first command takes a while to run, so please be patient):
14 | ```
15 | conda env create -f environment.yml
16 | conda activate pytorch_siamese
17 | ```
18 | The environment also contains the required TensorRT packages, so the arduous task of installing TensorRT into the global CUDA setup is not required.
19 | 
20 | ## Setting up the dataset.
21 | The expected format for both the training and validation datasets is the same. Images belonging to a single entity/class should be placed in a folder with the name of that class. The folders for all classes are then placed within a common root directory (which is passed to the training and evaluation scripts). The folder structure is illustrated below:
22 | ```
23 | |--Train or Validation dataset root directory
24 |     |--Class1
25 |         |-Image1
26 |         |-Image2
27 |         .
28 |         .
29 |         .
30 |         |-ImageN
31 |     |--Class2
32 |     |--Class3
33 |     .
34 |     .
35 |     .
36 |     |--ClassN
37 | ```
38 | 
39 | ## Training the model:
40 | To train the model, run the following command along with the required command line arguments (an example invocation is included at the end of the usage block):
41 | ```
42 | python train.py [-h] --train_path TRAIN_PATH --val_path VAL_PATH -o OUT_PATH
43 |                 [-b BACKBONE] [-lr LEARNING_RATE] [-e EPOCHS] [-s SAVE_AFTER]
44 | 
45 | optional arguments:
46 |   -h, --help            show this help message and exit
47 |   --train_path TRAIN_PATH
48 |                         Path to directory containing training dataset.
49 |   --val_path VAL_PATH   Path to directory containing validation dataset.
50 |   -o OUT_PATH, --out_path OUT_PATH
51 |                         Path for outputting model weights and tensorboard
52 |                         summary.
53 |   -b BACKBONE, --backbone BACKBONE
54 |                         Network backbone from torchvision.models to be used in
55 |                         the siamese network.
56 |   -lr LEARNING_RATE, --learning_rate LEARNING_RATE
57 |                         Learning Rate
58 |   -e EPOCHS, --epochs EPOCHS
59 |                         Number of epochs to train
60 |   -s SAVE_AFTER, --save_after SAVE_AFTER
61 |                         Model checkpoint is saved after each specified number
62 |                         of epochs.
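# Example invocation (the paths below are hypothetical placeholders): train with a
# ResNet-18 backbone for 100 epochs, saving checkpoints and tensorboard logs under ./runs/exp1:
#   python train.py --train_path ./data/train --val_path ./data/val -o ./runs/exp1 -b resnet18 -lr 1e-4 -e 100 -s 25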
63 | ```
64 | The backbone can be chosen from any of the networks listed in [torchvision.models](https://pytorch.org/vision/stable/models.html).
65 | 
66 | ## Evaluating the model:
67 | The following command can be used to evaluate the model on a validation set. Output images containing each pair and their corresponding similarity confidence will be saved to `{OUT_PATH}`.
68 | 
69 | Note: During evaluation, the pairs are generated with a deterministic seed for the numpy random module, so that multiple evaluations can be compared against each other.
70 | 
71 | ```
72 | python eval.py [-h] -v VAL_PATH -o OUT_PATH -c CHECKPOINT
73 | 
74 | optional arguments:
75 |   -h, --help            show this help message and exit
76 |   -v VAL_PATH, --val_path VAL_PATH
77 |                         Path to directory containing validation dataset.
78 |   -o OUT_PATH, --out_path OUT_PATH
79 |                         Path for saving prediction images.
80 |   -c CHECKPOINT, --checkpoint CHECKPOINT
81 |                         Path of model checkpoint to be used for inference.
82 | ```
83 | 
84 | ## Converting from Torch to ONNX
85 | To convert the PyTorch model (.pth extension) produced by `train.py` into ONNX format, use `torch_to_onnx.py`:
86 | ```
87 | python torch_to_onnx.py [-h] -c CHECKPOINT -o OUT_PATH
88 | 
89 | optional arguments:
90 |   -h, --help            show this help message and exit
91 |   -c CHECKPOINT, --checkpoint CHECKPOINT
92 |                         Path of model checkpoint to be used for inference.
93 |   -o OUT_PATH, --out_path OUT_PATH
94 |                         Path for saving tensorrt model.
95 | ```
96 | 
97 | ## Converting from ONNX to TensorRT Engine
98 | To generate a TensorRT engine file from the ONNX model produced by `torch_to_onnx.py`, use `onnx_to_trt.py`:
99 | ```
100 | python onnx_to_trt.py [-h] --onnx ONNX --engine ENGINE
101 | 
102 | optional arguments:
103 |   -h, --help       show this help message and exit
104 |   --onnx ONNX      Path of onnx model generated by 'torch_to_onnx.py'.
105 |   --engine ENGINE  Path for saving tensorrt engine.
106 | ```
107 | 
108 | ## Inference and Evaluation using TensorRT Engine:
109 | To run inference and evaluation with the TensorRT engine produced by `onnx_to_trt.py`, use `infer_tensorrt.py` and `eval_tensorrt.py`. Usage for both scripts is shown below (an example invocation is included at the end of the `infer_tensorrt.py` usage block):
110 | ```
111 | python eval_tensorrt.py [-h] -v VAL_PATH -o OUT_PATH --engine ENGINE
112 | 
113 | optional arguments:
114 |   -h, --help            show this help message and exit
115 |   -v VAL_PATH, --val_path VAL_PATH
116 |                         Path to directory containing validation dataset.
117 |   -o OUT_PATH, --out_path OUT_PATH
118 |                         Path for saving prediction images.
119 |   --engine ENGINE       Path to tensorrt engine generated by 'onnx_to_trt.py'.
120 | ```
121 | 
122 | ```
123 | python infer_tensorrt.py [-h] --image1 IMAGE1 --image2 IMAGE2 --engine ENGINE
124 | 
125 | optional arguments:
126 |   -h, --help       show this help message and exit
127 |   --image1 IMAGE1  Path to first image of the pair.
128 |   --image2 IMAGE2  Path to second image of the pair.
129 |   --engine ENGINE  Path to tensorrt engine generated by 'onnx_to_trt.py'.
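# Example invocation (the paths below are hypothetical placeholders): compare two
# image crops using a previously built engine and print their similarity score:
#   python infer_tensorrt.py --image1 ./samples/player_a.png --image2 ./samples/player_b.png --engine ./models/siamese.engine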
130 | ``` 131 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_siamese 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - absl-py=0.12.0=py36h06a4308_0 8 | - aiohttp=3.6.3=py36h7b6447c_0 9 | - async-timeout=3.0.1=py36h06a4308_0 10 | - attrs=20.3.0=pyhd3eb1b0_0 11 | - blas=1.0=mkl 12 | - blinker=1.4=py36h06a4308_0 13 | - brotlipy=0.7.0=py36h27cfd23_1003 14 | - c-ares=1.17.1=h27cfd23_0 15 | - ca-certificates=2021.4.13=h06a4308_1 16 | - cachetools=4.2.1=pyhd3eb1b0_0 17 | - certifi=2020.12.5=py36h06a4308_0 18 | - cffi=1.14.5=py36h261ae71_0 19 | - chardet=3.0.4=py36h06a4308_1003 20 | - click=7.1.2=pyhd3eb1b0_0 21 | - coverage=5.5=py36h27cfd23_2 22 | - cryptography=3.4.7=py36hd23ed53_0 23 | - cudatoolkit=10.2.89=hfd86e86_1 24 | - cython=0.29.23=py36h2531618_0 25 | - dataclasses=0.8=pyh4f3eec9_6 26 | - freetype=2.10.4=h5ab3b9f_0 27 | - google-auth=1.29.0=pyhd3eb1b0_0 28 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 29 | - grpcio=1.36.1=py36h2157cd5_1 30 | - idna=2.10=pyhd3eb1b0_0 31 | - idna_ssl=1.1.0=py36h06a4308_0 32 | - importlib-metadata=3.10.0=py36h06a4308_0 33 | - intel-openmp=2021.2.0=h06a4308_610 34 | - jpeg=9b=h024ee3a_2 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.33.1=h53a641e_7 37 | - libffi=3.3=he6710b0_2 38 | - libgcc-ng=9.1.0=hdf63c60_0 39 | - libpng=1.6.37=hbc83047_0 40 | - libprotobuf=3.14.0=h8c45485_0 41 | - libstdcxx-ng=9.1.0=hdf63c60_0 42 | - libtiff=4.1.0=h2733197_1 43 | - libuv=1.40.0=h7b6447c_0 44 | - lz4-c=1.9.3=h2531618_0 45 | - markdown=3.3.4=py36h06a4308_0 46 | - mkl=2020.2=256 47 | - mkl-service=2.3.0=py36he8ac12f_0 48 | - mkl_fft=1.3.0=py36h54f3939_0 49 | - mkl_random=1.1.1=py36h0573a6f_0 50 | - multidict=4.7.6=py36h7b6447c_1 51 | - ncurses=6.2=he6710b0_1 52 | - ninja=1.10.2=hff7bd54_1 53 | - numpy=1.19.2=py36h54aff64_0 54 | - numpy-base=1.19.2=py36hfa32c7d_0 55 | - oauthlib=3.1.0=py_0 56 | - olefile=0.46=py36_0 57 | - openssl=1.1.1k=h27cfd23_0 58 | - pillow=8.2.0=py36he98fc37_0 59 | - pip=21.0.1=py36h06a4308_0 60 | - pyasn1=0.4.8=py_0 61 | - pyasn1-modules=0.2.8=py_0 62 | - pycparser=2.20=py_2 63 | - pyjwt=1.7.1=py36_0 64 | - pyopenssl=20.0.1=pyhd3eb1b0_1 65 | - pysocks=1.7.1=py36h06a4308_0 66 | - python=3.6.13=hdb3f193_0 67 | - pytorch=1.7.1=py3.6_cuda10.2.89_cudnn7.6.5_0 68 | - readline=8.1=h27cfd23_0 69 | - requests=2.25.1=pyhd3eb1b0_0 70 | - requests-oauthlib=1.3.0=py_0 71 | - rsa=4.7.2=pyhd3eb1b0_1 72 | - setuptools=52.0.0=py36h06a4308_0 73 | - six=1.15.0=py36h06a4308_0 74 | - sqlite=3.35.4=hdfb4753_0 75 | - tensorboard=2.4.0=pyhc547734_0 76 | - tensorboard-plugin-wit=1.6.0=py_0 77 | - tk=8.6.10=hbc83047_0 78 | - torchaudio=0.7.2=py36 79 | - torchvision=0.8.2=py36_cu102 80 | - typing_extensions=3.7.4.3=pyha847dfd_0 81 | - urllib3=1.26.4=pyhd3eb1b0_0 82 | - werkzeug=1.0.1=pyhd3eb1b0_0 83 | - wheel=0.36.2=pyhd3eb1b0_0 84 | - xz=5.2.5=h7b6447c_0 85 | - yarl=1.6.3=py36h27cfd23_0 86 | - zipp=3.4.1=pyhd3eb1b0_0 87 | - zlib=1.2.11=h7b6447c_3 88 | - zstd=1.4.9=haebb681_0 89 | - pip: 90 | - albumentations==0.5.2 91 | - appdirs==1.4.4 92 | - cycler==0.10.0 93 | - decorator==4.4.2 94 | - imageio==2.9.0 95 | - imgaug==0.4.0 96 | - kiwisolver==1.3.1 97 | - mako==1.1.4 98 | - markupsafe==1.1.1 99 | - matplotlib==3.3.4 100 | - networkx==2.5.1 101 | - nvidia-cublas==11.4.1.1026 102 | - nvidia-cuda-nvrtc==11.1.105 103 | - nvidia-cuda-runtime==11.2.146 104 | - 
nvidia-cudnn==8.1.1.33 105 | - nvidia-pyindex==1.0.8 106 | - nvidia-tensorrt==7.2.3.4 107 | - onnx==1.9.0 108 | - opencv-python==4.5.1.48 109 | - opencv-python-headless==4.5.1.48 110 | - protobuf==3.15.8 111 | - pycuda==2021.1 112 | - pyparsing==2.4.7 113 | - python-dateutil==2.8.1 114 | - pytools==2021.2.5 115 | - pywavelets==1.1.1 116 | - pyyaml==5.4.1 117 | - scikit-image==0.17.2 118 | - scipy==1.5.4 119 | - shapely==1.7.1 120 | - tifffile==2020.9.3 121 | - tqdm==4.60.0 122 | prefix: /home/sohaib/anaconda3/envs/pytorch_siamese 123 | 124 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import cv2 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | 12 | from siamese import SiameseNetwork 13 | from libs.dataset import Dataset 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument( 19 | '-v', 20 | '--val_path', 21 | type=str, 22 | help="Path to directory containing validation dataset.", 23 | required=True 24 | ) 25 | parser.add_argument( 26 | '-o', 27 | '--out_path', 28 | type=str, 29 | help="Path for saving prediction images.", 30 | required=True 31 | ) 32 | parser.add_argument( 33 | '-c', 34 | '--checkpoint', 35 | type=str, 36 | help="Path of model checkpoint to be used for inference.", 37 | required=True 38 | ) 39 | 40 | args = parser.parse_args() 41 | 42 | os.makedirs(args.out_path, exist_ok=True) 43 | 44 | # Set device to CUDA if a CUDA device is available, else CPU 45 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 46 | 47 | val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False) 48 | val_dataloader = DataLoader(val_dataset, batch_size=1) 49 | 50 | criterion = torch.nn.BCELoss() 51 | 52 | checkpoint = torch.load(args.checkpoint) 53 | model = SiameseNetwork(backbone=checkpoint['backbone']) 54 | model.to(device) 55 | model.load_state_dict(checkpoint['model_state_dict']) 56 | model.eval() 57 | 58 | losses = [] 59 | correct = 0 60 | total = 0 61 | 62 | inv_transform = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], 63 | std = [ 1/0.229, 1/0.224, 1/0.225 ]), 64 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], 65 | std = [ 1., 1., 1. ]), 66 | ]) 67 | 68 | for i, ((img1, img2), y, (class1, class2)) in enumerate(val_dataloader): 69 | print("[{} / {}]".format(i, len(val_dataloader))) 70 | 71 | img1, img2, y = map(lambda x: x.to(device), [img1, img2, y]) 72 | class1 = class1[0] 73 | class2 = class2[0] 74 | 75 | prob = model(img1, img2) 76 | loss = criterion(prob, y) 77 | 78 | losses.append(loss.item()) 79 | correct += torch.count_nonzero(y == (prob > 0.5)).item() 80 | total += len(y) 81 | 82 | fig = plt.figure("class1={}\tclass2={}".format(class1, class2), figsize=(4, 2)) 83 | plt.suptitle("cls1={} conf={:.2f} cls2={}".format(class1, prob[0][0].item(), class2)) 84 | 85 | # Apply inverse transform (denormalization) on the images to retrieve original images. 
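        # The two chained Normalize transforms above undo the dataset's
        # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        # the first (std=1/s) multiplies each channel back by its std, and the
        # second (mean=-m) adds the channel mean back, recovering values in [0, 1].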
86 | img1 = inv_transform(img1).cpu().numpy()[0] 87 | img2 = inv_transform(img2).cpu().numpy()[0] 88 | # show first image 89 | ax = fig.add_subplot(1, 2, 1) 90 | plt.imshow(img1[0], cmap=plt.cm.gray) 91 | plt.axis("off") 92 | 93 | # show the second image 94 | ax = fig.add_subplot(1, 2, 2) 95 | plt.imshow(img2[0], cmap=plt.cm.gray) 96 | plt.axis("off") 97 | 98 | # show the plot 99 | plt.savefig(os.path.join(args.out_path, '{}.png').format(i)) 100 | 101 | print("Validation: Loss={:.2f}\t Accuracy={:.2f}\t".format(sum(losses)/len(losses), correct / total)) -------------------------------------------------------------------------------- /eval_tensorrt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | 11 | from libs.dataset import Dataset 12 | from siamese.siamese_network_trt import SiameseNetworkTRT 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument( 18 | '-v', 19 | '--val_path', 20 | type=str, 21 | help="Path to directory containing validation dataset.", 22 | required=True 23 | ) 24 | parser.add_argument( 25 | '-o', 26 | '--out_path', 27 | type=str, 28 | help="Path for saving prediction images.", 29 | required=True 30 | ) 31 | parser.add_argument( 32 | '--engine', 33 | type=str, 34 | help="Path to tensorrt engine generated by 'onnx_to_trt.py'.", 35 | required=True 36 | ) 37 | 38 | args = parser.parse_args() 39 | 40 | os.makedirs(args.out_path, exist_ok=True) 41 | 42 | val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False) 43 | val_dataloader = DataLoader(val_dataset, batch_size=1) 44 | 45 | criterion = torch.nn.BCELoss() 46 | 47 | model = SiameseNetworkTRT() 48 | model.load_model(args.engine) 49 | 50 | losses = [] 51 | correct = 0 52 | total = 0 53 | 54 | inv_transform = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], 55 | std = [ 1/0.229, 1/0.224, 1/0.225 ]), 56 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], 57 | std = [ 1., 1., 1. 
]), 58 | ]) 59 | 60 | for i, ((img1, img2), y, (class1, class2)) in enumerate(val_dataloader): 61 | print("[{} / {}]".format(i, len(val_dataloader))) 62 | 63 | class1 = class1[0] 64 | class2 = class2[0] 65 | 66 | prob = model.predict(img1.cpu().numpy(), img2.cpu().numpy(), preprocess=False) 67 | 68 | loss = criterion(torch.Tensor([[prob]]), y) 69 | 70 | losses.append(loss.item()) 71 | correct += y[0, 0].item() == (prob > 0.5) 72 | total += 1 73 | 74 | fig = plt.figure("class1={}\tclass2={}".format(class1, class2), figsize=(4, 2)) 75 | plt.suptitle("cls1={} conf={:.2f} cls2={}".format(class1, prob, class2)) 76 | 77 | img1 = inv_transform(img1).cpu().numpy()[0] 78 | img2 = inv_transform(img2).cpu().numpy()[0] 79 | # show first image 80 | ax = fig.add_subplot(1, 2, 1) 81 | plt.imshow(img1[0], cmap=plt.cm.gray) 82 | plt.axis("off") 83 | 84 | # show the second image 85 | ax = fig.add_subplot(1, 2, 2) 86 | plt.imshow(img2[0], cmap=plt.cm.gray) 87 | plt.axis("off") 88 | 89 | # show the plot 90 | plt.savefig(os.path.join(args.out_path, '{}.png').format(i)) 91 | 92 | print("Validation: Loss={:.2f}\t Accuracy={:.2f}\t".format(sum(losses)/len(losses), correct / total)) -------------------------------------------------------------------------------- /infer_tensorrt.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | from siamese.siamese_network_trt import SiameseNetworkTRT 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument( 9 | '--image1', 10 | type=str, 11 | help="Path to first image of the pair.", 12 | required=True 13 | ) 14 | parser.add_argument( 15 | '--image2', 16 | type=str, 17 | help="Path to second image of the pair.", 18 | required=True 19 | ) 20 | parser.add_argument( 21 | '--engine', 22 | type=str, 23 | help="Path to tensorrt engine generated by 'onnx_to_trt.py'.", 24 | required=True 25 | ) 26 | 27 | 28 | args = parser.parse_args() 29 | 30 | model = SiameseNetworkTRT() 31 | model.load_model(args.engine) 32 | 33 | image1 = cv2.imread(args.image1) 34 | image2 = cv2.imread(args.image2) 35 | similarity = model.predict(image1, image2) 36 | 37 | print(F"Similarity between the two images = {round(similarity, 2)}") -------------------------------------------------------------------------------- /libs/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | 14 | class Dataset(torch.utils.data.IterableDataset): 15 | def __init__(self, path, shuffle_pairs=True, augment=False): 16 | ''' 17 | Create an iterable dataset from a directory containing sub-directories of 18 | entities with their images contained inside each sub-directory. 19 | 20 | Parameters: 21 | path (str): Path to directory containing the dataset. 22 | shuffle_pairs (boolean): Pass True when training, False otherwise. When set to false, the image pair generation will be deterministic 23 | augment (boolean): When True, images will be augmented using a standard set of transformations. 
24 | 
25 |         Yields (per iteration):
26 |             (image1, image2) (torch.Tensor, torch.Tensor): the transformed image pair, each of shape [3, 224, 224]
27 |             target (torch.FloatTensor): shape=[1], 1.0 if both images belong to the same class, 0.0 otherwise
28 |             (class1, class2) (str, str): class names of the two images
29 |         '''
30 |         self.path = path
31 | 
32 |         self.feed_shape = [3, 224, 224]
33 |         self.shuffle_pairs = shuffle_pairs
34 | 
35 |         self.augment = augment
36 | 
37 |         if self.augment:
38 |             # If images are to be augmented, add extra operations for it (first two).
39 |             self.transform = transforms.Compose([
40 |                 transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=0.2),
41 |                 transforms.RandomHorizontalFlip(p=0.5),
42 |                 transforms.ToTensor(),
43 |                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
44 |                 transforms.Resize(self.feed_shape[1:])
45 |             ])
46 |         else:
47 |             # If no augmentation is needed then apply only the normalization and resizing operations.
48 |             self.transform = transforms.Compose([
49 |                 transforms.ToTensor(),
50 |                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
51 |                 transforms.Resize(self.feed_shape[1:])
52 |             ])
53 | 
54 |         self.create_pairs()
55 | 
56 |     def create_pairs(self):
57 |         '''
58 |         Creates two lists of indices that will form the pairs, to be fed for training or evaluation.
59 |         '''
60 | 
61 |         self.image_paths = glob.glob(os.path.join(self.path, "*/*.png"))
62 |         self.image_classes = []
63 |         self.class_indices = {}
64 | 
65 |         for image_path in self.image_paths:
66 |             image_class = image_path.split(os.path.sep)[-2]
67 |             self.image_classes.append(image_class)
68 | 
69 |             if image_class not in self.class_indices:
70 |                 self.class_indices[image_class] = []
71 |             self.class_indices[image_class].append(self.image_paths.index(image_path))
72 | 
73 |         self.indices1 = np.arange(len(self.image_paths))
74 | 
75 |         if self.shuffle_pairs:
76 |             np.random.seed(int(time.time()))
77 |             np.random.shuffle(self.indices1)
78 |         else:
79 |             # If shuffling is set to off, set the random seed to 1, to make it deterministic.
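            # With a fixed seed, the same positive/negative choices and pair partners
            # are drawn on every pass, so separate evaluation runs stay comparable.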
80 | np.random.seed(1) 81 | 82 | select_pos_pair = np.random.rand(len(self.image_paths)) < 0.5 83 | 84 | self.indices2 = [] 85 | 86 | for i, pos in zip(self.indices1, select_pos_pair): 87 | class1 = self.image_classes[i] 88 | if pos: 89 | class2 = class1 90 | else: 91 | class2 = np.random.choice(list(set(self.class_indices.keys()) - {class1})) 92 | idx2 = np.random.choice(self.class_indices[class2]) 93 | self.indices2.append(idx2) 94 | self.indices2 = np.array(self.indices2) 95 | 96 | def __iter__(self): 97 | self.create_pairs() 98 | 99 | for idx, idx2 in zip(self.indices1, self.indices2): 100 | 101 | image_path1 = self.image_paths[idx] 102 | image_path2 = self.image_paths[idx2] 103 | 104 | class1 = self.image_classes[idx] 105 | class2 = self.image_classes[idx2] 106 | 107 | image1 = Image.open(image_path1).convert("RGB") 108 | image2 = Image.open(image_path2).convert("RGB") 109 | 110 | if self.transform: 111 | image1 = self.transform(image1).float() 112 | image2 = self.transform(image2).float() 113 | 114 | yield (image1, image2), torch.FloatTensor([class1==class2]), (class1, class2) 115 | 116 | def __len__(self): 117 | return len(self.image_paths) 118 | -------------------------------------------------------------------------------- /onnx_to_trt.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import tensorrt as trt 4 | 5 | # logger to capture errors, warnings, and other information during the build and inference phases 6 | TRT_LOGGER = trt.Logger() 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument( 12 | '--onnx', 13 | type=str, 14 | help="Path of onnx model generated by 'torch_to_onnx.py'.", 15 | required=True 16 | ) 17 | parser.add_argument( 18 | '--engine', 19 | type=str, 20 | help="Path for saving tensorrt engine.", 21 | required=True 22 | ) 23 | 24 | args = parser.parse_args() 25 | 26 | onnx_file_path = args.onnx 27 | # initialize TensorRT engine and parse ONNX model 28 | EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 29 | 30 | builder = trt.Builder(TRT_LOGGER) 31 | network = builder.create_network(EXPLICIT_BATCH) 32 | parser = trt.OnnxParser(network, TRT_LOGGER) 33 | 34 | # parse ONNX 35 | with open(onnx_file_path, 'rb') as model: 36 | print('Beginning ONNX file parsing') 37 | parser.parse(model.read()) 38 | print('Completed parsing of ONNX file') 39 | 40 | # allow TensorRT to use up to 1GB of GPU memory for tactic selection 41 | builder.max_workspace_size = 1 << 30 42 | # we have only one image in batch 43 | builder.max_batch_size = 1 44 | # use FP16 mode if possible 45 | if builder.platform_has_fast_fp16: 46 | builder.fp16_mode = True 47 | 48 | # generate TensorRT engine optimized for the target platform 49 | print('Building an engine...') 50 | engine = builder.build_cuda_engine(network) 51 | print("Completed creating Engine") 52 | 53 | with open(args.engine, 'wb') as f: 54 | f.write(engine.serialize()) 55 | -------------------------------------------------------------------------------- /resources/player_similarity.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohaib023/siamese-pytorch/a65ec83b7294ecefea6541ef2321a3bcbd41f745/resources/player_similarity.gif -------------------------------------------------------------------------------- /siamese/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.siamese_network import SiameseNetwork -------------------------------------------------------------------------------- /siamese/siamese_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision import models 6 | 7 | class SiameseNetwork(nn.Module): 8 | def __init__(self, backbone="resnet18"): 9 | ''' 10 | Creates a siamese network with a network from torchvision.models as backbone. 11 | 12 | Parameters: 13 | backbone (str): Options of the backbone networks can be found at https://pytorch.org/vision/stable/models.html 14 | ''' 15 | 16 | super().__init__() 17 | 18 | if backbone not in models.__dict__: 19 | raise Exception("No model named {} exists in torchvision.models.".format(backbone)) 20 | 21 | # Create a backbone network from the pretrained models provided in torchvision.models 22 | self.backbone = models.__dict__[backbone](pretrained=True, progress=True) 23 | 24 | # Get the number of features that are outputted by the last layer of backbone network. 25 | out_features = list(self.backbone.modules())[-1].out_features 26 | 27 | # Create an MLP (multi-layer perceptron) as the classification head. 28 | # Classifies if provided combined feature vector of the 2 images represent same player or different. 29 | self.cls_head = nn.Sequential( 30 | nn.Dropout(p=0.5), 31 | nn.Linear(out_features, 512), 32 | nn.BatchNorm1d(512), 33 | nn.ReLU(), 34 | 35 | nn.Dropout(p=0.5), 36 | nn.Linear(512, 64), 37 | nn.BatchNorm1d(64), 38 | nn.Sigmoid(), 39 | nn.Dropout(p=0.5), 40 | 41 | nn.Linear(64, 1), 42 | nn.Sigmoid(), 43 | ) 44 | 45 | def forward(self, img1, img2): 46 | ''' 47 | Returns the similarity value between two images. 48 | 49 | Parameters: 50 | img1 (torch.Tensor): shape=[b, 3, 224, 224] 51 | img2 (torch.Tensor): shape=[b, 3, 224, 224] 52 | 53 | where b = batch size 54 | 55 | Returns: 56 | output (torch.Tensor): shape=[b, 1], Similarity of each pair of images 57 | ''' 58 | 59 | # Pass the both images through the backbone network to get their seperate feature vectors 60 | feat1 = self.backbone(img1) 61 | feat2 = self.backbone(img2) 62 | 63 | # Multiply (element-wise) the feature vectors of the two images together, 64 | # to generate a combined feature vector representing the similarity between the two. 65 | combined_features = feat1 * feat2 66 | 67 | # Pass the combined feature vector through classification head to get similarity value in the range of 0 to 1. 
68 | output = self.cls_head(combined_features) 69 | return output -------------------------------------------------------------------------------- /siamese/siamese_network_trt.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import tensorrt as trt 7 | import pycuda.autoinit 8 | import pycuda.driver as cuda 9 | 10 | from torchvision import transforms 11 | 12 | # logger to capture errors, warnings, and other information during the build and inference phases 13 | TRT_LOGGER = trt.Logger() 14 | 15 | class SiameseNetworkTRT: 16 | def __init__(self, backbone="resnet18", feed_shape=(224, 224)): 17 | self.context = None 18 | 19 | self.transform = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 22 | transforms.Resize(feed_shape) 23 | ]) 24 | 25 | def load_model(self, engine_path): 26 | with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: 27 | self.engine = runtime.deserialize_cuda_engine(f.read()) 28 | self.context = self.engine.create_execution_context() 29 | 30 | self.device_input1, self.device_input2 = [None] * 2 31 | for binding in self.engine: 32 | if self.engine.binding_is_input(binding): # we expect only one input 33 | input_shape = self.engine.get_binding_shape(binding) 34 | input_size = trt.volume(input_shape) * self.engine.max_batch_size * np.dtype(np.float32).itemsize # in bytes 35 | if self.device_input1 is None: 36 | self.device_input1 = cuda.mem_alloc(input_size) 37 | elif self.device_input2 is None: 38 | self.device_input2 = cuda.mem_alloc(input_size) 39 | else: 40 | raise Exception("Network expects more than 2 inputs.") 41 | else: # and one output 42 | self.output_shape = self.engine.get_binding_shape(binding) 43 | # create page-locked memory buffers (i.e. won't be swapped to disk) 44 | self.host_output = cuda.pagelocked_empty(trt.volume(self.output_shape) * self.engine.max_batch_size, dtype=np.float32) 45 | self.device_output = cuda.mem_alloc(self.host_output.nbytes) 46 | 47 | # Create a stream in which to copy inputs/outputs and run inference. 48 | self.stream = cuda.Stream() 49 | 50 | def predict(self, image1, image2, preprocess=True): 51 | ''' 52 | Returns the similarity value between two images. 53 | 54 | Parameters: 55 | image1 (np.array): Raw first image that is read using cv2.imread 56 | image2 (np.array): Raw second image that is read using cv2.imread 57 | preprocess (bool): Only provided for "eval_tensorrt.py". Otherwise always true when providing raw images. 58 | Returns: 59 | output (float): Similarity of the passed pair of images in range (0, 1) 60 | ''' 61 | if self.context is None: 62 | raise Exception(F"Context not found! 
Please load model first using 'load_model' function on this object.") 63 | 64 | if preprocess: 65 | image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB) 66 | image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB) 67 | 68 | image1 = Image.fromarray(image1).convert("RGB") 69 | image2 = Image.fromarray(image2).convert("RGB") 70 | 71 | image1 = self.transform(image1).float().numpy() 72 | image2 = self.transform(image2).float().numpy() 73 | 74 | image1 = image1.astype(np.float32) 75 | image2 = image2.astype(np.float32) 76 | 77 | cuda.memcpy_htod_async(self.device_input1, image1, self.stream) 78 | cuda.memcpy_htod_async(self.device_input2, image2, self.stream) 79 | 80 | self.context.execute_async(bindings=[int(self.device_input1), int(self.device_input2), int(self.device_output)], stream_handle=self.stream.handle) 81 | cuda.memcpy_dtoh_async(self.host_output, self.device_output, self.stream) 82 | self.stream.synchronize() 83 | 84 | output_data = torch.Tensor(self.host_output).reshape(self.engine.max_batch_size, self.output_shape[0]) 85 | return output_data[0][0].item() -------------------------------------------------------------------------------- /torch_to_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import onnx 5 | import torch 6 | 7 | from siamese import SiameseNetwork 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument( 13 | '-c', 14 | '--checkpoint', 15 | type=str, 16 | help="Path of model checkpoint to be used for inference.", 17 | required=True 18 | ) 19 | parser.add_argument( 20 | '-o', 21 | '--out_path', 22 | type=str, 23 | help="Path for saving tensorrt model.", 24 | required=True 25 | ) 26 | 27 | args = parser.parse_args() 28 | 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | checkpoint = torch.load(args.checkpoint) 32 | model = SiameseNetwork(backbone=checkpoint['backbone']) 33 | model.to(device) 34 | model.load_state_dict(checkpoint['model_state_dict']) 35 | model.eval() 36 | 37 | torch.onnx.export(model, (torch.rand(1, 3, 224, 224).to(device), torch.rand(1, 3, 224, 224).to(device)), args.out_path, input_names=['input'], 38 | output_names=['output'], export_params=True) 39 | 40 | onnx_model = onnx.load(args.out_path) 41 | onnx.checker.check_model(onnx_model) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import cv2 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from siamese import SiameseNetwork 13 | from libs.dataset import Dataset 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument( 19 | '--train_path', 20 | type=str, 21 | help="Path to directory containing training dataset.", 22 | required=True 23 | ) 24 | parser.add_argument( 25 | '--val_path', 26 | type=str, 27 | help="Path to directory containing validation dataset.", 28 | required=True 29 | ) 30 | parser.add_argument( 31 | '-o', 32 | '--out_path', 33 | type=str, 34 | help="Path for outputting model weights and tensorboard summary.", 35 | required=True 36 | ) 37 | parser.add_argument( 38 | '-b', 39 | '--backbone', 40 | type=str, 41 | help="Network backbone from torchvision.models to be used in the siamese 
network.", 42 | default="resnet18" 43 | ) 44 | parser.add_argument( 45 | '-lr', 46 | '--learning_rate', 47 | type=float, 48 | help="Learning Rate", 49 | default=1e-4 50 | ) 51 | parser.add_argument( 52 | '-e', 53 | '--epochs', 54 | type=int, 55 | help="Number of epochs to train", 56 | default=1000 57 | ) 58 | parser.add_argument( 59 | '-s', 60 | '--save_after', 61 | type=int, 62 | help="Model checkpoint is saved after each specified number of epochs.", 63 | default=25 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | os.makedirs(args.out_path, exist_ok=True) 69 | 70 | # Set device to CUDA if a CUDA device is available, else CPU 71 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 72 | 73 | train_dataset = Dataset(args.train_path, shuffle_pairs=True, augment=True) 74 | val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False) 75 | 76 | train_dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True) 77 | val_dataloader = DataLoader(val_dataset, batch_size=8) 78 | 79 | model = SiameseNetwork(backbone=args.backbone) 80 | model.to(device) 81 | 82 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 83 | criterion = torch.nn.BCELoss() 84 | 85 | writer = SummaryWriter(os.path.join(args.out_path, "summary")) 86 | 87 | best_val = 10000000000 88 | 89 | for epoch in range(args.epochs): 90 | print("[{} / {}]".format(epoch, args.epochs)) 91 | model.train() 92 | 93 | losses = [] 94 | correct = 0 95 | total = 0 96 | 97 | # Training Loop Start 98 | for (img1, img2), y, (class1, class2) in train_dataloader: 99 | img1, img2, y = map(lambda x: x.to(device), [img1, img2, y]) 100 | 101 | prob = model(img1, img2) 102 | loss = criterion(prob, y) 103 | 104 | optimizer.zero_grad() 105 | loss.backward() 106 | optimizer.step() 107 | 108 | losses.append(loss.item()) 109 | correct += torch.count_nonzero(y == (prob > 0.5)).item() 110 | total += len(y) 111 | 112 | writer.add_scalar('train_loss', sum(losses)/len(losses), epoch) 113 | writer.add_scalar('train_acc', correct / total, epoch) 114 | 115 | print("\tTraining: Loss={:.2f}\t Accuracy={:.2f}\t".format(sum(losses)/len(losses), correct / total)) 116 | # Training Loop End 117 | 118 | # Evaluation Loop Start 119 | model.eval() 120 | 121 | losses = [] 122 | correct = 0 123 | total = 0 124 | 125 | for (img1, img2), y, (class1, class2) in val_dataloader: 126 | img1, img2, y = map(lambda x: x.to(device), [img1, img2, y]) 127 | 128 | prob = model(img1, img2) 129 | loss = criterion(prob, y) 130 | 131 | losses.append(loss.item()) 132 | correct += torch.count_nonzero(y == (prob > 0.5)).item() 133 | total += len(y) 134 | 135 | val_loss = sum(losses)/max(1, len(losses)) 136 | writer.add_scalar('val_loss', val_loss, epoch) 137 | writer.add_scalar('val_acc', correct / total, epoch) 138 | 139 | print("\tValidation: Loss={:.2f}\t Accuracy={:.2f}\t".format(val_loss, correct / total)) 140 | # Evaluation Loop End 141 | 142 | # Update "best.pth" model if val_loss in current epoch is lower than the best validation loss 143 | if val_loss < best_val: 144 | best_val = val_loss 145 | torch.save( 146 | { 147 | "epoch": epoch + 1, 148 | "model_state_dict": model.state_dict(), 149 | "backbone": args.backbone, 150 | "optimizer_state_dict": optimizer.state_dict() 151 | }, 152 | os.path.join(args.out_path, "best.pth") 153 | ) 154 | 155 | # Save model based on the frequency defined by "args.save_after" 156 | if (epoch + 1) % args.save_after == 0: 157 | torch.save( 158 | { 159 | "epoch": epoch + 1, 160 | "model_state_dict": 
model.state_dict(), 161 | "backbone": args.backbone, 162 | "optimizer_state_dict": optimizer.state_dict() 163 | }, 164 | os.path.join(args.out_path, "epoch_{}.pth".format(epoch + 1)) 165 | ) --------------------------------------------------------------------------------
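Note: `eval.py` runs over a whole validation directory. For completeness, below is a minimal sketch (not part of the repository; the checkpoint and image paths are placeholders) of single-pair inference with a trained PyTorch checkpoint, using the same preprocessing as `libs/dataset.py`:

```python
import torch
from PIL import Image
from torchvision import transforms

from siamese import SiameseNetwork

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load a checkpoint produced by train.py (path is a placeholder).
checkpoint = torch.load("best.pth", map_location=device)
model = SiameseNetwork(backbone=checkpoint['backbone'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Same preprocessing as the non-augmented branch in libs/dataset.py.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((224, 224)),
])

# Image paths are placeholders; each becomes a [1, 3, 224, 224] batch.
img1 = transform(Image.open("a.png").convert("RGB")).unsqueeze(0).to(device)
img2 = transform(Image.open("b.png").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    similarity = model(img1, img2)[0][0].item()
print(f"Similarity = {similarity:.2f}")
```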