├── README.md ├── color.gif ├── conf_mat.png ├── main.py ├── model.py ├── scatter_color.png ├── siamese_net.png └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # Siamese-Networks-Tutorial 2 | 3 | siamnet 4 | 5 | ### This is a basic tutorial of Siamese networks. [Here](https://medium.com/@adityadutt/siamese-networks-introduction-and-implementation-2140e3443dee) is an accompanying article on Medium. 6 | 7 | Here is the dataset used for this project: [Multi Color and Shapes Dataset](https://github.com/AdityaDutt/MultiColor-Shapes-Database) 8 | 9 | --- 10 | 11 | ### Steps to run the code: 12 | 13 | 1. Download the dataset. 14 | 2. In the file main.py, adjust the path of the dataset. In the file model.py, adjust the path where the model will be saved. Then run main.py: it creates the image pairs and trains the model. 15 | 3. Run test.py to test the model and display a 3D scatter plot. Adjust the saved model path. 16 | 17 | --- 18 | 19 | ### Output 20 | 21 | gif 22 | 23 | ```3D plot of features``` 24 | 25 |
# ----- (repository dump) tail of /README.md -----
# 26 |
# 27 | ---
# 28 |
# 29 | confmat
# 30 |
# 31 | ```Confusion matrix of same shape but different color```
# 32 |
# ----- /color.gif -----
# https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/color.gif
# ----- /conf_mat.png -----
# https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/conf_mat.png
# ----- /main.py -----
import os, sys, cv2, matplotlib.pyplot as plt, numpy as np, shutil
from random import random, randint, seed
import random
import pickle, itertools, sklearn, pandas as pd, seaborn as sn
from scipy.spatial import distance
from keras.models import Model, load_model, Sequential
from keras import backend as K
try:
    from keras.utils.vis_utils import plot_model
except ImportError:
    # Newer Keras releases dropped keras.utils.vis_utils; plot_model is
    # exported directly from keras.utils there.
    from keras.utils import plot_model
from scipy import spatial
from sklearn.metrics import confusion_matrix

import warnings
warnings.filterwarnings('ignore')


# Import color encoder which uses siamese networks
from model import train_color_encoder


## Prepare positive and negative pairs of data samples

# Scan the dataset directory for .png images and label each by colour:
#   0 = red, 1 = blue, 2 = anything else (expected to be green).
data_dir = os.path.join(os.getcwd(), "shapes")

images = []
y_col = []

# Sort so the train/test split below is reproducible; os.walk order
# depends on the filesystem.
png_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(data_dir)
    for name in files
    if name.endswith(".png")
)

for fullname in png_paths:
    # Look for the colour only in the part of the path below data_dir: the
    # absolute path also contains the working directory, which may itself
    # contain "red"/"blue" (e.g. /home/fred/...) and mislabel every image.
    rel_path = os.path.relpath(fullname, data_dir)
    images.append(fullname)
    if "red" in rel_path:
        y_col.append(0)
    elif "blue" in rel_path:
        y_col.append(1)
    else:
        y_col.append(2)

y_col = np.array(y_col)
images = np.array(images)


# Split image paths by colour label (label 1 is blue, label 2 is green).
red_im = images[y_col == 0]
blue_im = images[y_col == 1]
green_im = images[y_col == 2]

# Test images: everything from index 50 on in each class (indices 20-49 are
# used for neither training nor testing).
test_red_im = red_im[50:]
test_green_im = green_im[50:]
test_blue_im = blue_im[50:]

# Read only 20 images from each class for training
red_im = red_im[:20]
green_im = green_im[:20]
blue_im = blue_im[:20]


# Positive samples: every unordered pair of images within one colour class.
positive_red = list(itertools.combinations(red_im, 2))
positive_blue = list(itertools.combinations(blue_im, 2))
positive_green = list(itertools.combinations(green_im, 2))

# Negative samples: every (a, b) pair across two different colour classes.
negative1 = list(itertools.product(red_im, green_im))
negative2 = list(itertools.product(green_im, blue_im))
negative3 = list(itertools.product(red_im, blue_im))


# Create pairs of images and set target label for them. Target output is 1 if pair of images have same color else it is 0.
_image_cache = {}


def _read_image(path):
    """Return cv2.imread(path), reading each file only once; raise if unreadable."""
    img = _image_cache.get(path)
    if img is None:
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise IOError(f"Could not read image: {path}")
        _image_cache[path] = img
    return img


positive_samples = positive_blue + positive_green + positive_red
negative_samples = negative1 + negative2 + negative3
pairs = positive_samples + negative_samples

# Left/right images of every pair; positives first, then negatives.
# Each image appears in many pairs, so reads go through the cache above.
color_X1 = np.array([_read_image(left) for left, _ in pairs])
color_X2 = np.array([_read_image(right) for _, right in pairs])
# Target is 1 for same-colour pairs, 0 for different-colour pairs.
color_y = np.array([1] * len(positive_samples) + [0] * len(negative_samples))

color_X1 = color_X1.reshape((len(pairs), 28, 28, 3))
color_X2 = color_X2.reshape((len(pairs), 28, 28, 3))

# Scale to [0, 1] and invert intensities.
color_X1 = 1 - color_X1/255
color_X2 = 1 - color_X2/255

print("Color data : ", color_X1.shape, color_X2.shape, color_y.shape)

# Save test data (order: red, blue, green) for test.py
with open(os.getcwd()+"/test_images.pkl", 'wb') as f:
    pickle.dump([test_red_im, test_blue_im, test_green_im], f)

# train model
train_color_encoder(color_X1, color_X2, color_y)

# ----- /model.py -----
from keras.layers import Input, Dense, InputLayer, Conv2D, MaxPooling2D, UpSampling2D, InputLayer, Concatenate, Flatten, Reshape, Lambda, Embedding, dot
from keras.models import Model, load_model, Sequential
import matplotlib.pyplot as plt
import keras.backend as K
from sklearn.model_selection import train_test_split
import os, sys
import tensorflow as tf
try:
    from keras.utils.vis_utils import plot_model
except ImportError:
    # Newer Keras releases dropped keras.utils.vis_utils.
    from keras.utils import plot_model


def train_color_encoder(X1, X2, y, epochs=100, batch_size=5) :
    """Train a siamese network on image pairs and save the shared encoder.

    Both branches share one small CNN encoder that maps a 28x28x3 image to a
    16-d L2-normalised embedding. The siamese output is the dot product of the
    two embeddings (their cosine similarity, since both have unit length),
    regressed with MSE onto the pair label.

    Args:
        X1, X2: arrays of shape (n_pairs, 28, 28, 3); left/right image of each pair.
        y: array of shape (n_pairs,); 1 for same-colour pairs, 0 otherwise.
        epochs, batch_size: forwarded to ``siamese_model.fit``.

    Returns:
        (model, siamese_model): the encoder and the full siamese network. Both
        are also saved as .h5 files in the current working directory.
    """
    # Colour encoder: two conv/pool stages, then a 16-d linear embedding.
    input_layer = Input((28, 28, 3))
    layer1 = Conv2D(16, (3, 3), activation='relu', padding='same')(input_layer)
    layer2 = MaxPooling2D((2, 2), padding='same')(layer1)
    layer3 = Conv2D(8, (3, 3), activation='relu', padding='same')(layer2)
    layer4 = MaxPooling2D((2, 2), padding='same')(layer3)
    layer5 = Flatten()(layer4)
    embeddings = Dense(16, activation=None)(layer5)
    # Unit-length embeddings make the dot product below a cosine similarity.
    norm_embeddings = tf.nn.l2_normalize(embeddings, axis=-1)

    model = Model(inputs=input_layer, outputs=norm_embeddings)

    # Siamese model: the same encoder instance (shared weights) on both inputs.
    input1 = Input((28,28,3))
    input2 = Input((28,28,3))
    left_model = model(input1)
    right_model = model(input2)

    dot_product = dot([left_model, right_model], axes=1, normalize=False)

    siamese_model = Model(inputs=[input1, input2], outputs=dot_product)

    # summary() prints the table itself and returns None; wrapping it in
    # print() would also emit a stray "None".
    siamese_model.summary()

    siamese_model.compile(optimizer='adam', loss= 'mse')

    # The diagram is best-effort: plot_model needs pydot + graphviz, and a
    # missing optional dependency should not abort training.
    try:
        plot_model(siamese_model, to_file=os.getcwd()+'/siamese_model_mnist.png', show_shapes=True, show_layer_names=True)
    except ImportError as err:
        print("Skipping model plot:", err)

    siamese_model.fit([X1, X2], y, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=True)

    model.save(os.getcwd()+"/color_encoder.h5")
    siamese_model.save(os.getcwd()+"/color_siamese_model.h5")

    return model, siamese_model

# ----- /scatter_color.png -----
# https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/scatter_color.png
# ----- /siamese_net.png -----
# https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/siamese_net.png
# ----- /test.py -----
from mpl_toolkits.mplot3d import Axes3D
import os, sys, cv2, matplotlib.pyplot as plt, numpy as np, pickle
import sklearn, pandas as pd, seaborn as sn
from keras.models import Model, load_model, Sequential
from keras import backend as K
from sklearn.metrics import confusion_matrix

import warnings
warnings.filterwarnings('ignore')


model = load_model(os.getcwd()+"/color_encoder.h5")
siamese_model = load_model(os.getcwd()+"/color_siamese_model.h5")


# Load test data: three lists of image paths, unpacked as red, blue, green.
# NOTE: pickle.load can execute arbitrary code; only load a file this
# project wrote itself.
with open(os.getcwd()+"/test_images.pkl", 'rb') as f:
    test_red_im, test_blue_im, test_green_im = pickle.load(f)

colors = ['red', 'blue', 'green']
groups = [list(test_red_im), list(test_blue_im), list(test_green_im)]

names = [path for group in groups for path in group]
# One plot colour per image, taken from the group it came from, so labels
# stay aligned with `names` even when the groups differ in size.
y = [color for color, group in zip(colors, groups) for _ in group]

if not names:
    raise SystemExit("test_images.pkl contains no test images")

test_im = []
for path in names:
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals failure by returning None
        raise IOError(f"Could not read image: {path}")
    test_im.append(img)

test_im = np.stack(test_im)
# Same preprocessing as training: scale to [0, 1] and invert.
test_im = 1 - test_im/255

pred = model.predict(test_im)

# Plot only the first three embedding dimensions.
feat1 = pred[:,0]
feat2 = pred[:,1]
feat3 = pred[:,2]

fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
# (deprecated in 3.4), which leaves an empty plot; request a 3D subplot instead.
ax = fig.add_subplot(111, projection='3d')
ax.scatter(feat1, feat2, feat3, c=y, marker='.')
plt.show()