├── README.md
├── color.gif
├── conf_mat.png
├── main.py
├── model.py
├── scatter_color.png
├── siamese_net.png
└── test.py
/README.md:
--------------------------------------------------------------------------------
1 | # Siamese-Networks-Tutorial
2 |
3 |
4 |
5 | ### This is a basic tutorial on Siamese networks. [Here](https://medium.com/@adityadutt/siamese-networks-introduction-and-implementation-2140e3443dee) is an accompanying article on Medium.
6 |
7 | Here is the dataset used for this project: [Multi Color and Shapes Dataset](https://github.com/AdityaDutt/MultiColor-Shapes-Database)
8 |
9 | ---
10 |
11 | ### Steps to run the file:
12 |
13 | 1. Download the dataset.
14 | 2. In main.py, adjust the path of the dataset; in model.py, adjust the path where the model will be saved. Then run main.py: it creates the image pairs and trains the model.
15 | 3. In test.py, adjust the path of the saved model, then run it to test the model and display a 3D scatter plot of the features.
16 |
17 | ---
18 |
19 | ### Output
20 |
21 |
22 |
23 | *3D plot of the learned features*
24 |
25 |
26 |
27 | ---
28 |
29 |
30 |
31 | *Confusion matrix for images of the same shape but different colors*
32 |
--------------------------------------------------------------------------------
/color.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/color.gif
--------------------------------------------------------------------------------
/conf_mat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/conf_mat.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, sys, cv2, matplotlib.pyplot as plt, numpy as np, shutil
2 | from random import random, randint, seed
3 | import random
4 | import pickle, itertools, sklearn, pandas as pd, seaborn as sn
5 | from scipy.spatial import distance
6 | from keras.models import Model, load_model, Sequential
7 | from keras import backend as K
8 | from keras.utils.vis_utils import plot_model
9 | from scipy import spatial
10 | from sklearn.metrics import confusion_matrix
11 |
12 | import warnings
13 | warnings.filterwarnings('ignore')
14 |
15 |
16 | # Import color encoder which uses siamese networks
17 | from model import train_color_encoder
18 |
19 |
20 |
21 |
## Prepare positive and negative pairs of data samples.
# The target for a pair is 1 if both images have the same color, else 0.

# Root directory of the Multi Color and Shapes dataset (adjust if needed).
data_dir = os.getcwd() + "/shapes/"

# Color class ids. Anything that is neither "red" nor "blue" falls into the
# third class (green), exactly as in the original labelling rule.
RED, BLUE, GREEN = 0, 1, 2


def _color_label(path):
    """Return the color class id (RED/BLUE/GREEN) for an image path.

    The color word is searched in the path *relative to data_dir*, so a
    color word in the working-directory prefix (e.g. "/home/fred/...")
    cannot mislabel every image.
    """
    rel = os.path.relpath(path, data_dir)
    if "red" in rel:
        return RED
    if "blue" in rel:
        return BLUE
    return GREEN


# Collect every .png under the dataset root. Sorting makes the train/test
# split below reproducible (os.walk order is filesystem-dependent).
image_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(data_dir, topdown=False)
    for name in files
    if name.endswith(".png")
)

images = np.array(image_paths)
y_col = np.array([_color_label(p) for p in image_paths])

# Split the paths by color. Each variable must hold the color it is named
# after: the pickled test set below is read back by test.py in the order
# [red, blue, green].
red_im = images[y_col == RED]
blue_im = images[y_col == BLUE]
green_im = images[y_col == GREEN]

# Held-out test images: index 50 onward of each color.
test_red_im = red_im[50:]
test_green_im = green_im[50:]
test_blue_im = blue_im[50:]

# Only the first 20 images of each color are used for training.
red_im = red_im[:20]
green_im = green_im[:20]
blue_im = blue_im[:20]

# Positive pairs: every unordered pair of images sharing a color.
positive_samples = (
    list(itertools.combinations(blue_im, 2))
    + list(itertools.combinations(green_im, 2))
    + list(itertools.combinations(red_im, 2))
)

# Negative pairs: every cross-color combination.
negative_samples = (
    list(itertools.product(red_im, green_im))
    + list(itertools.product(green_im, blue_im))
    + list(itertools.product(red_im, blue_im))
)

# Read each distinct training image once, rather than once per pair, and
# fail loudly on unreadable files (cv2.imread returns None instead of
# raising).
pixels = {}
for path in itertools.chain(red_im, green_im, blue_im):
    im = cv2.imread(path)
    if im is None:
        raise OSError(f"Could not read image: {path}")
    pixels[path] = im

pairs = positive_samples + negative_samples
color_X1 = np.array([pixels[first] for first, _ in pairs])
color_X2 = np.array([pixels[second] for _, second in pairs])
color_y = np.array([1] * len(positive_samples) + [0] * len(negative_samples))

color_X1 = color_X1.reshape((len(pairs), 28, 28, 3))
color_X2 = color_X2.reshape((len(pairs), 28, 28, 3))

# Scale to [0, 1] and invert the intensities.
color_X1 = 1 - color_X1 / 255
color_X2 = 1 - color_X2 / 255

print("Color data : ", color_X1.shape, color_X2.shape, color_y.shape)

# Save the held-out test paths in the order [red, blue, green].
with open(os.getcwd() + "/test_images.pkl", 'wb') as f:
    pickle.dump([test_red_im, test_blue_im, test_green_im], f)

# Train the model.
train_color_encoder(color_X1, color_X2, color_y)
126 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Input, Dense, InputLayer, Conv2D, MaxPooling2D, UpSampling2D, InputLayer, Concatenate, Flatten, Reshape, Lambda, Embedding, dot
2 | from keras.models import Model, load_model, Sequential
3 | import matplotlib.pyplot as plt
4 | import keras.backend as K
5 | from sklearn.model_selection import train_test_split
6 | import os, sys
7 | import tensorflow as tf
8 | from keras.utils.vis_utils import plot_model
9 |
10 |
11 |
12 |
13 |
14 |
def train_color_encoder(X1, X2, y, epochs=100, batch_size=5,
                        save_dir=None, input_shape=(28, 28, 3)):
    """Train a Siamese color encoder on image pairs and save both models.

    A shared CNN encoder maps one image to a 16-d L2-normalized embedding.
    The Siamese model runs both images of a pair through that same encoder
    and outputs the dot product of the two embeddings. Because both vectors
    are unit length, this is their cosine similarity. It is fitted to `y`
    with an MSE loss.

    Args:
        X1: first image of every pair, shaped (n_pairs, *input_shape).
        X2: second image of every pair, same shape as X1.
        y: per-pair target (main.py uses 1 = same color, 0 = different).
        epochs: number of training epochs passed to `fit`.
        batch_size: batch size passed to `fit`.
        save_dir: directory for the model diagram and the two .h5 files;
            defaults to the current working directory.
        input_shape: shape of a single input image.

    Returns:
        Tuple `(model, siamese_model)`: the trained encoder and the
        two-input Siamese model.
    """
    if save_dir is None:
        save_dir = os.getcwd()

    # Color encoder (shared by both branches of the Siamese network).
    input_layer = Input(input_shape)
    layer1 = Conv2D(16, (3, 3), activation='relu', padding='same')(input_layer)
    layer2 = MaxPooling2D((2, 2), padding='same')(layer1)
    layer3 = Conv2D(8, (3, 3), activation='relu', padding='same')(layer2)
    layer4 = MaxPooling2D((2, 2), padding='same')(layer3)
    layer5 = Flatten()(layer4)
    embeddings = Dense(16, activation=None)(layer5)
    # NOTE(review): raw TF op on a Keras tensor; this relies on the installed
    # keras wrapping it as a layer so Model/save/load_model work - confirm.
    norm_embeddings = tf.nn.l2_normalize(embeddings, axis=-1)

    model = Model(inputs=input_layer, outputs=norm_embeddings)

    # Siamese model: the same `model` instance is applied to both inputs,
    # so the two branches share weights.
    input1 = Input(input_shape)
    input2 = Input(input_shape)
    left_model = model(input1)
    right_model = model(input2)

    # Dot product of the (already unit-length) embeddings.
    dot_product = dot([left_model, right_model], axes=1, normalize=False)

    siamese_model = Model(inputs=[input1, input2], outputs=dot_product)

    # summary() prints by itself and returns None; wrapping it in print()
    # only added a stray "None" line.
    siamese_model.summary()

    siamese_model.compile(optimizer='adam', loss='mse')

    # The diagram is optional: plot_model needs pydot + graphviz and, depending
    # on the keras version, raises when they are missing. That should not
    # abort training.
    try:
        plot_model(siamese_model,
                   to_file=os.path.join(save_dir, 'siamese_model_mnist.png'),
                   show_shapes=True, show_layer_names=True)
    except (ImportError, OSError) as err:
        print("Skipping model plot:", err)

    siamese_model.fit([X1, X2], y, epochs=epochs, batch_size=batch_size,
                      shuffle=True, verbose=True)

    model.save(os.path.join(save_dir, "color_encoder.h5"))
    siamese_model.save(os.path.join(save_dir, "color_siamese_model.h5"))

    return model, siamese_model
66 |
--------------------------------------------------------------------------------
/scatter_color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/scatter_color.png
--------------------------------------------------------------------------------
/siamese_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Siamese-Networks-Tutorial/faef0e9a399649fb59a1782886e59f9122aa6e48/siamese_net.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from mpl_toolkits.mplot3d import Axes3D
2 | import os, sys, cv2, matplotlib.pyplot as plt, numpy as np, pickle
3 | import sklearn, pandas as pd, seaborn as sn
4 | from keras.models import Model, load_model, Sequential
5 | from keras import backend as K
6 | from sklearn.metrics import confusion_matrix
7 |
8 | import warnings
9 | warnings.filterwarnings('ignore')
10 |
11 |
# Embed the held-out test images with the trained color encoder and show
# their first three embedding components as a 3-D scatter plot, colored by
# the group each image came from.

model = load_model(os.getcwd() + "/color_encoder.h5")
siamese_model = load_model(os.getcwd() + "/color_siamese_model.h5")

# Load the test paths written by main.py. That file is produced locally by
# main.py; never point pickle.load at untrusted data.
with open(os.getcwd() + "/test_images.pkl", 'rb') as f:
    test_red_im, test_blue_im, test_green_im = pickle.load(f)

# Label every image by the group it came from, using each group's actual
# size rather than assuming all three groups are the same length.
colors = ['red', 'blue', 'green']
groups = [list(test_red_im), list(test_blue_im), list(test_green_im)]
names = [path for group in groups for path in group]
y = [color for color, group in zip(colors, groups) for _ in group]

if not names:
    sys.exit("No test images found in test_images.pkl")

test_im = []
for path in names:
    im = cv2.imread(path)
    if im is None:  # cv2.imread signals failure by returning None
        raise OSError(f"Could not read test image: {path}")
    test_im.append(im)

r, c, _ = test_im[0].shape
test_im = np.array(test_im).reshape((len(test_im), r, c, 3))

# Same preprocessing as in main.py: scale to [0, 1] and invert.
test_im = 1 - test_im / 255

pred = model.predict(test_im)

# Plot only the first three embedding dimensions.
feat1 = pred[:, 0]
feat2 = pred[:, 1]
feat3 = pred[:, 2]

fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
# (an empty window is shown); add_subplot(projection='3d') works everywhere.
ax = fig.add_subplot(111, projection='3d')
ax.scatter(feat1, feat2, feat3, c=y, marker='.')
plt.show()
56 |
--------------------------------------------------------------------------------