├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── data
│   ├── class_vectors.npy
│   └── zeroshot_data.pkl
├── model
│   ├── model.h5
│   └── model.json
├── src
│   ├── detect_object.py
│   ├── feature_extractor.py
│   ├── train.py
│   ├── train_classes.txt
│   └── zsl_classes.txt
└── test.jpg

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.idea
__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Samet Çetin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# zero-shot-learning
Implementation of a zero-shot learning algorithm.

Zero-shot learning aims to solve a task without receiving any examples of that task during the training phase.
It simply allows us to recognize objects *we have not seen before*.

Check the Medium story I wrote for details: https://medium.com/@cetinsamet/zero-shot-learning-53080995d45f

### Classes
**Train Classes:**
arm, boy, bread, chicken, child, computer, ear, house, leg, sandwich, television, truck, vehicle, watch, woman
**Zero-Shot Classes:**
car, food, hand, man, neck

## Usage
$**python3** detect_object.py input-image-path

### Example
$**cd** src
$**python3** detect_object.py ../test.jpg
**->** --- Top-5 Prediction ---
**->** 1- vehicle
**->** 2- truck
**->** 3- car
**->** 4- house
**->** 5- chicken

![Example Image](https://github.com/cetinsamet/zero-shot-learning/blob/master/test.jpg)
*The test image is a beautiful green Jaguar E-Type.*
*All related classes are ranked in the top three predictions.*

P.S. Remember, predictions can only come from the classes listed above (train and zero-shot classes).
The algorithm will fail (although it will do its best to predict the most related class) if you try to detect an object from any other class.
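
## How it works
The snippet below is only a condensed sketch of what `src/detect_object.py` already does, with no new functionality; it assumes you run it from the `src/` directory so the relative paths resolve.

```python
import numpy as np
from PIL import Image
from sklearn.neighbors import KDTree
from sklearn.preprocessing import normalize

from feature_extractor import get_model, get_features   # 4096-d VGG16 image features
from train import load_keras_model                      # the saved zero-shot network

# 1) extract and L2-normalize a 4096-d image feature
img     = Image.open("../test.jpg").resize((224, 224))
feature = normalize(get_features(get_model(), img), norm='l2')

# 2) map the feature to a 300-d point in the word2vec space
pred = load_keras_model(model_path="../model/").predict(feature)

# 3) rank classes by distance to their word2vec embeddings
classnames, vectors = zip(*sorted(np.load("../data/class_vectors.npy"), key=lambda x: x[0]))
_, idx = KDTree(np.asarray(vectors, dtype=float)).query(pred, k=5)
print([classnames[i] for i in idx[0]])
```

Because the final decision is a nearest-neighbour search in the word2vec space rather than a softmax over the training classes, the five zero-shot classes can be ranked even though the network never saw a single image of them.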

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-minimal
--------------------------------------------------------------------------------
/data/class_vectors.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cetinsamet/zero-shot-learning/8a62c44f2db6c5d9344d0f38852c6e8d8a13a286/data/class_vectors.npy
--------------------------------------------------------------------------------
/data/zeroshot_data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cetinsamet/zero-shot-learning/8a62c44f2db6c5d9344d0f38852c6e8d8a13a286/data/zeroshot_data.pkl
--------------------------------------------------------------------------------
/model/model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cetinsamet/zero-shot-learning/8a62c44f2db6c5d9344d0f38852c6e8d8a13a286/model/model.h5
--------------------------------------------------------------------------------
/model/model.json:
--------------------------------------------------------------------------------
{"class_name": "Model", "config": {"name": "model_6", "layers": [{"name": "dense_72_input", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 4096], "dtype": "float32", "sparse": false, "name": "dense_72_input"}, "inbound_nodes": []}, {"name": "dense_72", "class_name": "Dense", "config": {"name": "dense_72", "trainable": true, "batch_input_shape": [null, 4096], "dtype": "float32", "units": 1024, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense_72_input", 0, 0, {}]]]}, {"name": "batch_normalization_6", "class_name": "BatchNormalization", "config": {"name": "batch_normalization_6", "trainable": true, "axis": -1, "momentum": 0.99, "epsilon": 0.001, "center": true, "scale": true, "beta_initializer": {"class_name": "Zeros", "config": {}}, "gamma_initializer": {"class_name": "Ones", "config": {}}, "moving_mean_initializer": {"class_name": "Zeros", "config": {}}, "moving_variance_initializer": {"class_name": "Ones", "config": {}}, "beta_regularizer": null, "gamma_regularizer": null, "beta_constraint": null, "gamma_constraint": null}, "inbound_nodes": [[["dense_72", 0, 0, {}]]]}, {"name": "dropout_25", "class_name": "Dropout", "config": {"name": "dropout_25", "trainable": true, "rate": 0.8, "noise_shape": null, "seed": null}, "inbound_nodes": [[["batch_normalization_6", 0, 0, {}]]]}, {"name": "dense_73", "class_name": "Dense", "config": {"name": "dense_73", "trainable": true, "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dropout_25", 0, 0, {}]]]}, {"name": "dropout_26", "class_name": "Dropout", "config": {"name": "dropout_26", "trainable": true, "rate": 0.5, "noise_shape": null, "seed": null}, "inbound_nodes": [[["dense_73", 0, 0, {}]]]}, {"name": "dense_74", "class_name": "Dense", "config": {"name": "dense_74", "trainable": true, "units": 256, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dropout_26", 0, 0, {}]]]}, {"name": "dense_75", "class_name": "Dense", "config": {"name": "dense_75", "trainable": true, "units": 300, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense_74", 0, 0, {}]]]}], "input_layers": [["dense_72_input", 0, 0]], "output_layers": [["dense_75", 0, 0]]}, "keras_version": "2.1.6", "backend": "tensorflow"}
--------------------------------------------------------------------------------
/src/detect_object.py:
--------------------------------------------------------------------------------
#
# detect_object.py
#
# Created by Samet Cetin.
# Contact: cetin.samet@outlook.com
#

import sys
import numpy as np
from PIL import Image

from sklearn.neighbors import KDTree
from sklearn.preprocessing import normalize

from feature_extractor import get_model, get_features
from train import load_keras_model


WORD2VECPATH = "../data/class_vectors.npy"
MODELPATH    = "../model/"

def main(argv):

    if len(argv) != 1:
        print("Usage: python3 detect_object.py input-image-path")
        sys.exit()

    # READ IMAGE
    IMAGEPATH = argv[0]
    img = Image.open(IMAGEPATH).resize((224, 224))

    # LOAD PRETRAINED VGG16 MODEL FOR FEATURE EXTRACTION
    vgg_model = get_model()
    # EXTRACT IMAGE FEATURE
    img_feature = get_features(vgg_model, img)
    # L2 NORMALIZE FEATURE
    img_feature = normalize(img_feature, norm='l2')

    # LOAD ZERO-SHOT MODEL
    model = load_keras_model(model_path=MODELPATH)
    # MAKE PREDICTION
    pred = model.predict(img_feature)

    # LOAD CLASS WORD2VECS
    class_vectors = sorted(np.load(WORD2VECPATH), key=lambda x: x[0])
    classnames, vectors = zip(*class_vectors)
    classnames = list(classnames)
    vectors = np.asarray(vectors, dtype=float)

    # PLACE WORD2VECS IN KDTREE
    tree = KDTree(vectors)
    # FIND CLOSEST WORD2VEC and GET PREDICTION RESULT
    dist, index = tree.query(pred, k=5)
    pred_labels = [classnames[idx] for idx in index[0]]

    # PRINT RESULT
    print()
    print("--- Top-5 Prediction ---")
    for i, classname in enumerate(pred_labels):
        print("%d- %s" % (i + 1, classname))
    print()
    return

if __name__ == '__main__':
    main(sys.argv[1:])
--------------------------------------------------------------------------------
/src/feature_extractor.py:
--------------------------------------------------------------------------------
#
# feature_extractor.py
#
# Created by Samet Cetin.
# Contact: cetin.samet@outlook.com
#

import keras
from keras import backend as K
from keras.preprocessing import image
from keras.models import Model

import tensorflow as tf
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib

import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'



def get_model():
    # VGG16 with the classification head removed: after popping the last two layers
    # the model outputs a 4096-d fully-connected feature
    vgg_model = keras.applications.VGG16(include_top=True, weights='imagenet')
    vgg_model.layers.pop()
    vgg_model.layers.pop()

    inp = vgg_model.input
    out = vgg_model.layers[-1].output

    model = Model(inp, out)
    return model

def get_features(model, cropped_image):
    # preprocess the image for VGG16 and run a forward pass to obtain its feature vector
    x = image.img_to_array(cropped_image)
    x = np.expand_dims(x, axis=0)
    x = keras.applications.vgg16.preprocess_input(x)
    features = model.predict(x)
    return features
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
#
# train.py
#
# Created by Samet Cetin.
# Contact: cetin.samet@outlook.com
#

import numpy as np
np.random.seed(123)
import gzip
import _pickle as cPickle
import os
from collections import Counter

from sklearn.preprocessing import LabelEncoder, normalize
from sklearn.neighbors import KDTree

from keras.models import Sequential, Model, model_from_json
from keras.layers import Dense, Dropout
from keras.layers import BatchNormalization
from keras.optimizers import Adam
from keras.utils import to_categorical


WORD2VECPATH = "../data/class_vectors.npy"
DATAPATH     = "../data/zeroshot_data.pkl"
MODELPATH    = "../model/"

def load_keras_model(model_path):
    with open(model_path + "model.json", 'r') as json_file:
        loaded_model_json = json_file.read()

    loaded_model = model_from_json(loaded_model_json)
    # load weights into new model
    loaded_model.load_weights(model_path + "model.h5")
    return loaded_model

def save_keras_model(model, model_path):
    """save Keras model and its weights"""
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    model_json = model.to_json()
    with open(model_path + "model.json", "w") as json_file:
        json_file.write(model_json)

    # serialize weights to HDF5
    model.save_weights(model_path + "model.h5")
    print("-> zsl model is saved.")
    return

def load_data():
    """read data, create datasets"""
    # READ DATA
    with gzip.GzipFile(DATAPATH, 'rb') as infile:
        data = cPickle.load(infile)

    # ONE-HOT-ENCODE DATA
    label_encoder = LabelEncoder()
    label_encoder.fit(train_classes)

    training_data = [instance for instance in data if instance[0] in train_classes]
    zero_shot_data = [instance for instance in data if instance[0] not in train_classes]
    # SHUFFLE TRAINING DATA
    np.random.shuffle(training_data)

    ### SPLIT DATA FOR TRAINING
    train_size = 300
    train_data = list()
    valid_data = list()
    for class_label in train_classes:
        ct = 0
        for instance in training_data:
            if instance[0] == class_label:
                if ct < train_size:
                    train_data.append(instance)
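                    # NOTE: only the first `train_size` (300) shuffled instances of each
                    # class are collected here; once the quota is reached, the remaining
                    # instances of that class fall through to valid_data below.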
                    ct += 1
                    continue
                valid_data.append(instance)

    # SHUFFLE TRAINING AND VALIDATION DATA
    np.random.shuffle(train_data)
    np.random.shuffle(valid_data)

    train_data = [(instance[1], to_categorical(label_encoder.transform([instance[0]]), num_classes=15)) for instance in train_data]
    valid_data = [(instance[1], to_categorical(label_encoder.transform([instance[0]]), num_classes=15)) for instance in valid_data]

    # FORM X_TRAIN AND Y_TRAIN
    x_train, y_train = zip(*train_data)
    x_train, y_train = np.squeeze(np.asarray(x_train)), np.squeeze(np.asarray(y_train))
    # L2 NORMALIZE X_TRAIN
    x_train = normalize(x_train, norm='l2')

    # FORM X_VALID AND Y_VALID
    x_valid, y_valid = zip(*valid_data)
    x_valid, y_valid = np.squeeze(np.asarray(x_valid)), np.squeeze(np.asarray(y_valid))
    # L2 NORMALIZE X_VALID
    x_valid = normalize(x_valid, norm='l2')


    # FORM X_ZSL AND Y_ZSL
    y_zsl, x_zsl = zip(*zero_shot_data)
    x_zsl, y_zsl = np.squeeze(np.asarray(x_zsl)), np.squeeze(np.asarray(y_zsl))
    # L2 NORMALIZE X_ZSL
    x_zsl = normalize(x_zsl, norm='l2')

    print("-> data loading is completed.")
    return (x_train, x_valid, x_zsl), (y_train, y_valid, y_zsl)


def custom_kernel_init(shape):
    # kernel of shape (NUM_ATTR, NUM_CLASS): each column is the word2vec embedding of a training class
    class_vectors = np.load(WORD2VECPATH)
    training_vectors = sorted([(label, vec) for (label, vec) in class_vectors if label in train_classes], key=lambda x: x[0])
    classnames, vectors = zip(*training_vectors)
    vectors = np.asarray(vectors, dtype=float)
    vectors = vectors.T
    return vectors

def build_model():
    model = Sequential()
    model.add(Dense(1024, input_shape=(4096,), activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.8))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(NUM_ATTR, activation='relu'))
    # the frozen softmax layer is initialized with the class word2vec embeddings, which
    # pushes the preceding NUM_ATTR-dim layer to predict a point in the word2vec space
    model.add(Dense(NUM_CLASS, activation='softmax', trainable=False, kernel_initializer=custom_kernel_init))

    print("-> model building is completed.")
    return model


def train_model(model, train_data, valid_data):
    x_train, y_train = train_data
    x_valid, y_valid = valid_data
    adam = Adam(lr=5e-5)
    model.compile(loss      = 'categorical_crossentropy',
                  optimizer = adam,
                  metrics   = ['categorical_accuracy', 'top_k_categorical_accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data = (x_valid, y_valid),
                        verbose         = 2,
                        epochs          = EPOCH,
                        batch_size      = BATCH_SIZE,
                        shuffle         = True)

    print("-> model training is completed.")
    return history

def main():

    global train_classes
    with open('train_classes.txt', 'r') as infile:
        train_classes = [str.strip(line) for line in infile]

    global zsl_classes
    with open('zsl_classes.txt', 'r') as infile:
        zsl_classes = [str.strip(line) for line in infile]

    # ---------------------------------------------------------------------------------------------------------------- #
    # ---------------------------------------------------------------------------------------------------------------- #
    # SET HYPERPARAMETERS

    global NUM_CLASS, NUM_ATTR, EPOCH, BATCH_SIZE
    NUM_CLASS  = 15
    NUM_ATTR   = 300
    BATCH_SIZE = 128
    EPOCH      = 65
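    # Note: NUM_CLASS matches the 15 training classes and NUM_ATTR the 300-d word2vec
    # embeddings, so the penultimate Dense layer of the network predicts a word vector.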

    # ---------------------------------------------------------------------------------------------------------------- #
    # ---------------------------------------------------------------------------------------------------------------- #
    # TRAINING PHASE

    (x_train, x_valid, x_zsl), (y_train, y_valid, y_zsl) = load_data()
    model = build_model()
    train_model(model, (x_train, y_train), (x_valid, y_valid))
    print(model.summary())

    # ---------------------------------------------------------------------------------------------------------------- #
    # ---------------------------------------------------------------------------------------------------------------- #
    # CREATE AND SAVE ZSL MODEL

    # drop the frozen softmax layer; the saved model maps 4096-d image features to 300-d word vectors
    inp = model.input
    out = model.layers[-2].output
    zsl_model = Model(inp, out)
    print(zsl_model.summary())
    save_keras_model(zsl_model, model_path=MODELPATH)

    # ---------------------------------------------------------------------------------------------------------------- #
    # ---------------------------------------------------------------------------------------------------------------- #
    # EVALUATION OF ZERO-SHOT LEARNING PERFORMANCE
    #(x_train, x_valid, x_zsl), (y_train, y_valid, y_zsl) = load_data()
    #zsl_model = load_keras_model(model_path=MODELPATH)

    class_vectors = sorted(np.load(WORD2VECPATH), key=lambda x: x[0])
    classnames, vectors = zip(*class_vectors)
    classnames = list(classnames)
    vectors = np.asarray(vectors, dtype=float)

    tree = KDTree(vectors)
    pred_zsl = zsl_model.predict(x_zsl)

    top5, top3, top1 = 0, 0, 0
    for i, pred in enumerate(pred_zsl):
        pred = np.expand_dims(pred, axis=0)
        dist_5, index_5 = tree.query(pred, k=5)
        pred_labels = [classnames[index] for index in index_5[0]]
        true_label = y_zsl[i]
        if true_label in pred_labels:
            top5 += 1
        if true_label in pred_labels[:3]:
            top3 += 1
        if true_label == pred_labels[0]:
            top1 += 1

    print()
    print("ZERO SHOT LEARNING SCORE")
    print("-> Top-5 Accuracy: %.2f" % (top5 / float(len(x_zsl))))
    print("-> Top-3 Accuracy: %.2f" % (top3 / float(len(x_zsl))))
    print("-> Top-1 Accuracy: %.2f" % (top1 / float(len(x_zsl))))
    return

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/train_classes.txt:
--------------------------------------------------------------------------------
arm
boy
bread
chicken
child
computer
ear
house
leg
sandwich
television
truck
vehicle
watch
woman
--------------------------------------------------------------------------------
/src/zsl_classes.txt:
--------------------------------------------------------------------------------
car
food
hand
man
neck
--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cetinsamet/zero-shot-learning/8a62c44f2db6c5d9344d0f38852c6e8d8a13a286/test.jpg
--------------------------------------------------------------------------------