├── lib ├── bt_net │ ├── __init__.py │ ├── nms.py │ ├── config.py │ ├── generate_boxes.py │ ├── visualize_graph.py │ ├── faster_rcnn_boxes.py │ ├── test_bts.py │ ├── train_bts.py │ └── evaluate_results.py ├── bt_datasets │ ├── __init__.py │ ├── imdb.py │ ├── factory.py │ ├── places365.py │ ├── open_images.py │ ├── places205.py │ ├── nus_wide.py │ ├── imagenet.py │ └── imagenet1k.py ├── preprocessing │ ├── __init__.py │ ├── preprocess_images.py │ ├── fix_results.py │ ├── preprocess_pkl_proposals.py │ ├── prepare_data.py │ ├── train_places_file.py │ ├── clean_imagenet.py │ ├── load_images.py │ ├── imagenet_imagelist_cleaning.py │ ├── load_langmod.py │ ├── clean_openimages.py │ └── clean_nus_wide.py └── language_models │ ├── __init__.py │ ├── lmdb.py │ ├── language_factory.py │ └── glove_factory.py ├── tools ├── README.md ├── __init__.py ├── _init_paths.py ├── visualize_graph.py ├── evaluate_results.py ├── train_brute_force.py ├── visualize_space.py ├── train_ml_brute_force.py ├── test_brute_force.py └── downloader.py ├── data ├── lm_data │ └── README.md └── image_data │ └── README.md ├── scripts ├── train_spaces_bts.sh ├── predict_spaces.sh ├── train_losses_bts.sh.save ├── train_losses_bts.sh └── train_ml_bruteforce.sh ├── .gitignore └── README.md /lib/bt_net/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | Tools for training and generating... 
2 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Made by Bjoernar Remmen 3 | # Short is better -------------------------------------------------------------------------------- /lib/bt_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # /usr/local/bin/python3.5 2 | # Made by Bjoernar Remmen 3 | # Short is better -------------------------------------------------------------------------------- /lib/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # /usr/local/bin/python3.5 2 | # Made by Bjoernar Remmen 3 | # Short is better -------------------------------------------------------------------------------- /lib/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | # /usr/local/bin/python3.5 2 | # Made by Bjoernar Remmen 3 | # Short is better -------------------------------------------------------------------------------- /data/lm_data/README.md: -------------------------------------------------------------------------------- 1 | Create symlinks to populate this folder with word vectors 2 | 3 | Example of structure: 4 | ``` 5 | lm_data 6 | └─── glove 7 | │ glove_wiki_50.txt (word vectors trained on wikipedia using glove) 8 | └─── word2vec 9 | | word2vec_wiki_50.txt (word vectors trained on wikipedia using word2vec) 10 | ``` 11 | -------------------------------------------------------------------------------- /lib/language_models/lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import scipy.sparse 5 | import bt_net.config as cfg 6 | 7 | class lmdb(object): 8 | """ 9 | Language model database 10 | """ 11 | 12 | def 
__init__(self, name): 13 | self._name = name 14 | 15 | @property 16 | def name(self): 17 | return self._name 18 | -------------------------------------------------------------------------------- /tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | ### Set up paths 2 | 3 | import os.path as osp 4 | import sys 5 | 6 | 7 | def add_path(path): 8 | if path not in sys.path: 9 | sys.path.insert(0, path) 10 | 11 | 12 | this_dir = osp.dirname(__file__) 13 | 14 | ## Including lib folder 15 | lib_path = osp.join(this_dir, '..', 'lib') 16 | add_path(lib_path) 17 | 18 | zsl_root = osp.join(this_dir, '..') 19 | add_path(zsl_root) 20 | -------------------------------------------------------------------------------- /data/image_data/README.md: -------------------------------------------------------------------------------- 1 | Folder to populate with image data for training and testing, using a structure similar to py-faster-rcnn: 2 | 3 | ``` 4 | ILSVRC13 5 | └─── Images 6 | │ *.JPEG (Image files, ex:ILSVRC2013_val_00000565.JPEG) 7 | └─── Annotations 8 | | *.xml 9 | └─── ImageSets 10 | │ train.txt 11 | └─── Generated_proposals 12 | │ vgg_cnn_m_1024_rpn_stage1_iter_90000_proposals.pkl 13 | ``` 14 | 15 | Follow [this](https://github.com/deboc/py-faster-rcnn/tree/master/help) documentation on how to train on your own dataset with [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn). Remember, when training for zero-shot classification, to add an imageset that contains the untrained classes. 
def fix_file(filename):
    """Strip list wrappers from the fields of a comma-separated results file.

    Reads ``<filename>.txt`` line by line and writes the cleaned lines to
    ``<filename>fixed.txt``. Each line is split on ``,``. A field that starts
    with `` ['`` and ends with ``]`` loses its first three and last two
    characters and gets a single leading space back, so `` ['cat']`` becomes
    `` cat``. Every other field is copied unchanged.

    Args:
        filename: Path of the input file without its ``.txt`` extension.
    """
    with open(filename + ".txt") as read, open(filename + "fixed.txt", "w") as wrt:
        for i, line in enumerate(read):
            if i % 10000 == 0:
                # Single-argument call form: valid on Python 2 and Python 3.
                # The two spaces reproduce the output of the old statement
                # ``print "lines checked: ", i``.
                print("lines checked:  {}".format(i))
            processed_line = []
            for item in line.rstrip().split(","):
                if item.startswith(" ['") and item.endswith("]"):
                    # Drop the leading " ['" and the trailing "']".
                    processed_line.append(" " + item[3:-2])
                else:
                    processed_line.append(item)
            wrt.write(",".join(processed_line) + "\n")


if __name__ == "__main__":
    fix_file('results_nus_wide_Test_zs_us_img_lbl_w2v_wiki_300D_ml_yolo_squared_hinge_2_l2_imagenet')
Usage: 3 | # ./scripts/train_spaces_bts.sh DATASET LOSS function 4 | # DATASET is either imagenet or imagenet1k. 5 | # 6 | # Example: 7 | # ./scripts/train_spaces_bts.sh imagenet1k squared_hinge 8 | 9 | set -x 10 | set -e 11 | 12 | export PYTHONUNBUFFERED='True' 13 | 14 | DATASET=$1 15 | LOSS=$2 16 | 17 | case $DATASET in 18 | imagenet) 19 | TRAIN_IMDB="imagenet_train_bts" 20 | TEST_IMDB="imagenet_test_zsl_bts" 21 | ;; 22 | imagenet1k) 23 | TRAIN_IMDB="imagenet1k_train_bts" 24 | TEST_IMDB="imagenet1k_test_zsl_bts" 25 | ;; 26 | *) 27 | echo "No dataset given" 28 | exit 29 | ;; 30 | esac 31 | 32 | for LM in 'w2v_wiki_150D' 'w2v_wiki_300D' 33 | do 34 | # Train model 35 | time python tools/train_brute_force.py \ 36 | --imdb ${TRAIN_IMDB} \ 37 | --lm ${LM} \ 38 | --loss ${LOSS} \ 39 | --iters 10000 40 | done 41 | -------------------------------------------------------------------------------- /lib/preprocessing/preprocess_pkl_proposals.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | 4 | Loading the pickle proposals generated by py-faster-rcnn and generates hdf5 files ready to train the network. 
#!/bin/sh
# Usage:
# ./scripts/predict_spaces.sh DATASET LOSS
# DATASET is either imagenet or imagenet1k.
# LOSS is the loss function name that is part of the checkpoint file name.
#
# Example:
# ./scripts/predict_spaces.sh imagenet1k squared_hinge

set -x
set -e

export PYTHONUNBUFFERED='True'

DATASET=$1
LOSS=$2

case $DATASET in
  imagenet)
    TRAIN_IMDB="imagenet_train_bts"
    TEST_IMDB="imagenet_test_zsl_bts"
    ;;
  imagenet1k)
    TRAIN_IMDB="imagenet1k_train_bts"
    TEST_IMDB="imagenet1k_test_zsl_bts"
    ;;
  *)
    echo "No dataset given"
    # Report the failure to the caller; a bare `exit` would return the
    # status of the echo above, i.e. success.
    exit 1
    ;;
esac

for LM in 'glove_wiki_300D' 'w2v_wiki_300D'
do
  # Predict with the checkpoint that belongs to this language model.
  #
  # To predict a single label, add `--singlelabel_predict \` to the command
  # below. Do NOT leave it there as a commented-out line: a `#` comment in
  # the middle of a backslash-continued command ends the command, and the
  # remaining options would be executed as a separate command.
  time python tools/test_brute_force.py \
    --ckpt output/bts_ckpt/${TRAIN_IMDB}/model_${TRAIN_IMDB}_${LM}_${LOSS}_l2.hdf5 \
    --imdb ${TEST_IMDB} \
    --lm ${LM} \
    --space 1
done
def parse_args():
    """Parse the command-line options of the graph visualisation tool.

    When the script is started without any argument the usage text is
    printed and the process exits with status 1.

    Returns:
        argparse.Namespace with the attributes ``lang_name``, ``imdb_name``
        and ``vis``.
    """
    # The description used to read 'Evaluate results', copied from the
    # evaluation tool; it now describes this script.
    parser = argparse.ArgumentParser(
        description='Visualize a graph for a language model and a dataset')
    # NOTE(review): 'glove_wiki' is not among the names registered in
    # language_models.language_factory -- confirm that vis_graph accepts it.
    parser.add_argument('--lm', dest='lang_name',
                        help='language model to use',
                        default='glove_wiki', type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to use',
                        default='imagenet_train', type=str)
    parser.add_argument('--vis', dest='vis',
                        help='select graph to display',
                        default='vis_distance', type=str)

    # No arguments at all: show the help instead of silently using defaults.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args()
def get_ids(label_array):
    """Give every distinct label a consecutive integer id.

    Ids are handed out in order of first appearance, starting at 0.

    Args:
        label_array: Iterable of iterables of hashable labels.

    Returns:
        A ``(mapping, size)`` tuple. ``mapping`` maps label to id, and
        ``size`` is the number of distinct labels plus one.
    """
    # NOTE(review): the extra +1 in the returned size is kept as-is; confirm
    # with callers whether a spare slot is really wanted.
    ids = {}
    for row in label_array:
        for lbl in row:
            # The next free id always equals the number of labels seen so far.
            ids.setdefault(lbl, len(ids))
    return ids, len(ids) + 1
def py_area_fix(image, dets, thresh):
    """Return the indices of detections covering more than ``thresh`` of the image.

    The box area is computed with inclusive pixel coordinates,
    ``(x2 - x1 + 1) * (y2 - y1 + 1)``, and divided by
    ``image.shape[0] * image.shape[1]``.

    Args:
        image: Array-like object whose ``shape`` starts with two image
            dimensions.
        dets: 2-D array with one detection per row; the first four columns
            are ``x1, y1, x2, y2``. Further columns (e.g. a score) are
            ignored, so a plain ``(N, 4)`` box array works as well.
        thresh: Minimum fraction of the image area (exclusive) a box must
            cover to be kept.

    Returns:
        List of row indices into ``dets``, in ascending order.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    image_area = float(image.shape[0] * image.shape[1])

    # The score column and its argsort were computed here but never used;
    # the selection depends on the area ratio only.
    return np.flatnonzero(areas / image_area > thresh).tolist()
\ 60 | --euc_loss \ 61 | --singlelabel_predict 62 | -------------------------------------------------------------------------------- /tools/evaluate_results.py: -------------------------------------------------------------------------------- 1 | import _init_paths 2 | import argparse 3 | import sys 4 | from bt_net.evaluate_results import * 5 | from language_models.language_factory import get_language_model 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Evaluate results') 9 | parser.add_argument('--results', dest='fn', 10 | help='location of result file', type=str) 11 | parser.add_argument('--k', dest='k', 12 | help='top-k value', type=str) 13 | parser.add_argument('--method', dest='method', 14 | help='select method: flat&MAP, pclass or avgeuc', 15 | default=None, type=str) 16 | parser.add_argument('--lm', dest='lang_mod', 17 | help='Initiate which language model to use with euc_distance', 18 | default='w2v_wiki_300D', type=str) 19 | 20 | if len(sys.argv) == 1: 21 | parser.print_help() 22 | sys.exit(1) 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | if __name__=='__main__': 28 | args = parse_args() 29 | print('Called with args: ') 30 | print args 31 | 32 | if args.method == 'pclass': 33 | evaluate_pr_class(args.fn) 34 | elif args.method == 'apredict': 35 | show_actual_predicted(args.fn) 36 | elif args.method == 'avg_class': 37 | average_gt_classes(args.fn) 38 | elif args.method == 'avgeuc': 39 | average_cosine(args.fn, get_language_model(args.lang_mod)) 40 | else: 41 | evaluate_flat_map(args.fn) 42 | -------------------------------------------------------------------------------- /scripts/train_losses_bts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Usage: 3 | # ./scripts/train_losses.sh DATASET LANGUAGE_MODEL 4 | # DATASET is either imagenet or imagenet1k. 
5 | # 6 | # Example: 7 | # ./scripts/train_losses_bts.sh imagenet1k glove_wiki_300 8 | 9 | set -x 10 | set -e 11 | 12 | export PYTHONUNBUFFERED='True' 13 | 14 | DATASET=$1 15 | LM=$2 16 | 17 | case $DATASET in 18 | imagenet) 19 | TRAIN_IMDB="imagenet_train_bts" 20 | TEST_IMDB="imagenet_test_zsl_bts" 21 | ;; 22 | imagenet1k) 23 | TRAIN_IMDB="imagenet1k_train_bts" 24 | TEST_IMDB="imagenet1k_test_zsl_bts" 25 | ;; 26 | *) 27 | echo "No dataset given" 28 | exit 29 | ;; 30 | esac 31 | 32 | for loss in 'hinge' 'cosine_proximity' 'squared_hinge' 33 | do 34 | # Train model 35 | time python tools/train_brute_force.py \ 36 | --imdb ${TRAIN_IMDB} \ 37 | --lm ${LM} \ 38 | --loss ${loss} \ 39 | --iters 10000 40 | 41 | # Make prediction 42 | time python tools/test_brute_force.py \ 43 | --ckpt output/bts_ckpt/${TRAIN_IMDB}/model_${TRAIN_IMDB}_${LM}_${loss}_l2_.hdf5 \ 44 | --imdb ${TEST_IMDB} \ 45 | --lm ${LM} \ 46 | --singlelabel_predict \ 47 | --space 1 48 | 49 | done 50 | 51 | 52 | # Also special case euclidean distance 53 | loss='euclidean' 54 | time python tools/test_brute_force.py \ 55 | --ckpt output/bts_ckpt/${TRAIN_IMDB}/model_${TRAIN_IMDB}_${LM}_${loss}_l2_adam.hdf5 \ 56 | --imdb ${TEST_IMDB} \ 57 | --lm ${LM} \ 58 | --singlelabel_predict 59 | 60 | time python tools/test_brute_force.py \ 61 | --ckpt output/bts_ckpt/${TRAIN_IMDB}/model_${TRAIN_IMDB}_${LM}_${loss}_l2_.hdf5 \ 62 | --imdb ${TEST_IMDB} \ 63 | --lm ${LM} \ 64 | --euc_loss \ 65 | --singlelabel_predict \ 66 | --space 1 67 | -------------------------------------------------------------------------------- /tools/train_brute_force.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Train brute force Zero Shot Classification 4 | 5 | Example: python tools/train_brute_force.py --imdb imagenet1k_train_bts --lm glove_wiki_300D --loss squared_hinge --iters 10000 6 | 7 | """ 8 | import _init_paths 9 | import argparse 10 | import sys 11 | import numpy as np 12 | from 
def parse_args():
    """Parse the command-line options for brute-force BTnet training.

    When the script is started without any argument the usage text is
    printed and the process exits with status 1.

    Returns:
        argparse.Namespace with the attributes ``max_iters``, ``lang_name``,
        ``imdb_name``, ``randomize`` and ``loss``.
    """
    parser = argparse.ArgumentParser(description='Train the Brute-Force BTnet model')
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=40000, type=int)
    # The previous default, 'glove_wiki', is not a name registered in
    # language_models.language_factory, so omitting --lm made
    # get_language_model() raise KeyError. Use the registered model from the
    # usage example in this module's docstring instead.
    parser.add_argument('--lm', dest='lang_name',
                        help='language model to use',
                        default='glove_wiki_300D', type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='imagenet_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--loss', dest='loss',
                        help='loss function to run',
                        default='hinge', type=str)

    # No arguments at all: show the help instead of silently using defaults.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args()
def main():
    """Plot a 2-D t-SNE embedding of the word vectors of a prediction space.

    Loads the language model and image database named on the command line,
    fetches the labels of the requested space via ``imdb.get_labels``,
    projects their word vectors to two dimensions with t-SNE and shows an
    annotated scatter plot.
    """
    args = parse_args()

    print('Called with args:')
    print(args)
    lang_db = get_language_model(args.lang_name)
    imdb = get_imdb(args.imdb_name)

    # Get words in space
    vocabulary = imdb.get_labels(args.space)

    # Get features for words. Other code in this repository checks
    # word_vector() against None, so labels without a vector are skipped
    # here rather than handed to t-SNE; labels and vectors stay aligned.
    labels = []
    vectors = []
    for word in vocabulary:
        vector = lang_db.word_vector(word)
        if vector is None:
            continue
        labels.append(word)
        vectors.append(vector)

    if not vectors:
        print('No word vectors found for the selected space; nothing to plot.')
        return

    tsne = TSNE(n_components=2, random_state=0)
    np.set_printoptions(suppress=True)
    Y = tsne.fit_transform(np.asarray(vectors))

    plt.scatter(Y[:, 0], Y[:, 1])
    for label, x, y in zip(labels, Y[:, 0], Y[:, 1]):
        plt.annotate(label, xy=(x, y), xytext=(0, 0), textcoords='offset points')
    plt.show()
def get_imdb(name):
    """Get an imdb (image database) by name.

    Args:
        name: Key of a dataset factory registered in the module-level
            ``__sets`` table, e.g. ``'imagenet1k_train_bts'``.

    Returns:
        The object produced by calling the registered factory.

    Raises:
        KeyError: If no dataset is registered under ``name``.
    """
    # ``in`` instead of dict.has_key(): has_key() does not exist on
    # Python 3, while the membership test works on Python 2 and 3 alike.
    if name not in __sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
def generate_imagenet_single(lang_db, imdb):
    """Write one ``<first> <second>`` line per labelled image to ``val_bts.txt``.

    For every image index the label is fetched with ``imdb.get_val_label``
    when ``imdb._image_folder()`` equals ``'val'`` and with
    ``imdb.get_image_label`` otherwise. A lookup that returns ``False`` is
    counted as skipped and produces no output line. Progress is printed
    every 1000 images.

    Args:
        lang_db: Language model handed through to the label lookups.
        imdb: Image database providing ``num_images``, ``_image_folder()``,
            ``get_val_label(i, lang_db)`` and ``get_image_label(i, lang_db)``.
    """
    with open('val_bts.txt', 'w+') as wf:
        skipped = 0
        for i in range(imdb.num_images):
            if i % 1000 == 0:
                # Single-argument call form: valid on Python 2 and Python 3.
                print("{} / {}, skipped {}".format(i, imdb.num_images, skipped))
            # Compare by value. ``is 'val'`` only held when the returned
            # string happened to be the same interned object as the literal.
            if imdb._image_folder() == 'val':
                y = imdb.get_val_label(i, lang_db)
            else:
                y = imdb.get_image_label(i, lang_db)
            if y is False:
                skipped += 1
                continue
            wf.write('%s' % y[0])
            wf.write(' %s\n' % y[1])
#generate_splitted_classes(lang_db) 62 | 63 | -------------------------------------------------------------------------------- /lib/preprocessing/load_images.py: -------------------------------------------------------------------------------- 1 | # /usr/local/bin/python3.5 2 | # Made by Bjoernar Remmen 3 | # Short is better 4 | import caffe 5 | import lmdb 6 | 7 | import numpy as np 8 | import lmdb 9 | import caffe 10 | 11 | N = 1000 12 | 13 | # Let's pretend this is interesting data 14 | X = np.zeros((N, 3, 32, 32), dtype=np.uint8) 15 | y = np.zeros(N, dtype=np.int64) 16 | 17 | # We need to prepare the database for the size. We'll set it 10 times 18 | # greater than what we theoretically need. There is little drawback to 19 | # setting this too big. If you still run into problem after raising 20 | # this, you might want to try saving fewer entries in a single 21 | # transaction. 22 | map_size = X.nbytes * 10 23 | 24 | env = lmdb.open('mylmdb', map_size=map_size) 25 | 26 | with env.begin(write=True) as txn: 27 | # txn is a Transaction object 28 | for i in range(N): 29 | datum = caffe.proto.caffe_pb2.Datum() 30 | datum.channels = X.shape[1] 31 | datum.height = X.shape[2]import numpy as np 32 | import lmdb 33 | import caffe 34 | 35 | N = 1000 36 | 37 | # Let's pretend this is interesting data 38 | X = np.zeros((N, 3, 32, 32), dtype=np.uint8) 39 | y = np.zeros(N, dtype=np.int64) 40 | 41 | # We need to prepare the database for the size. We'll set it 10 times 42 | # greater than what we theoretically need. There is little drawback to 43 | # setting this too big. If you still run into problem after raising 44 | # this, you might want to try saving fewer entries in a single 45 | # transaction. 
46 | map_size = X.nbytes * 10 47 | 48 | env = lmdb.open('mylmdb', map_size=map_size) 49 | 50 | with env.begin(write=True) as txn: 51 | # txn is a Transaction object 52 | for i in range(N): 53 | datum = caffe.proto.caffe_pb2.Datum() 54 | datum.channels = X.shape[1] 55 | datum.height = X.shape[2] 56 | datum.width = X.shape[3] 57 | datum.data = X[i].tobytes() # or .tostring() if numpy < 1.9 58 | datum.label = int(y[i]) 59 | str_id = '{:08}'.format(i) 60 | 61 | # The encode is only essential in Python 3 62 | txn.put(str_id.encode('ascii'), datum.SerializeToString()) 63 | datum.width = X.shape[3] 64 | datum.data = X[i].tobytes() # or .tostring() if numpy < 1.9 65 | datum.label = int(y[i]) 66 | str_id = '{:08}'.format(i) 67 | 68 | # The encode is only essential in Python 3 69 | txn.put(str_id.encode('ascii'), datum.SerializeToString()) -------------------------------------------------------------------------------- /lib/bt_net/config.py: -------------------------------------------------------------------------------- 1 | # /usr/local/bin/python3.5 2 | # Made by Bjoernar Remmen 3 | # Short is better 4 | import os.path as osp 5 | from easydict import EasyDict as edict 6 | __C = edict() 7 | 8 | 9 | cfg = __C 10 | 11 | # Root directory of project 12 | __C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) 13 | __C.DATA_DIR = osp.join(__C.ROOT_DIR, 'data') 14 | __C.LM_DATA_DIR = osp.join(__C.DATA_DIR, 'lm_data') 15 | __C.IMAGE_DATA_DIR = osp.join(__C.DATA_DIR, 'image_data') 16 | __C.SNAPSHOT_DIR = osp.join(__C.ROOT_DIR, "snapshots") 17 | __C.MODEL_DIR = osp.join(__C.ROOT_DIR, "models") 18 | __C.FASTER_RCNN_DIR = osp.join(__C.ROOT_DIR, '..', 'py-faster-rcnn') # Set to py-faster-rcnn root 19 | __C.FASTER_RCNN_LIB = osp.join(__C.FASTER_RCNN_DIR, 'lib') 20 | 21 | # Set fixed seed for reproducibility 22 | __C.RNG_SEED = 3 23 | 24 | 25 | 26 | def _merge_a_into_b(a, b): 27 | """Merge config dictionary a into config dictionary b, clobbering the 28 | options in b 
whenever they are also specified in a. 29 | """ 30 | if type(a) is not edict: 31 | return 32 | 33 | for k, v in a.iteritems(): 34 | # a must specify keys that are in b 35 | if not b.has_key(k): 36 | raise KeyError('{} is not a valid config key'.format(k)) 37 | 38 | # the types must match, too 39 | old_type = type(b[k]) 40 | if old_type is not type(v): 41 | if isinstance(b[k], np.ndarray): 42 | v = np.array(v, dtype=b[k].dtype) 43 | else: 44 | raise ValueError(('Type mismatch ({} vs. {}) ' 45 | 'for config key: {}').format(type(b[k]), 46 | type(v), k)) 47 | 48 | # recursively merge dicts 49 | if type(v) is edict: 50 | try: 51 | _merge_a_into_b(a[k], b[k]) 52 | except: 53 | print('Error under config key: {}'.format(k)) 54 | raise 55 | else: 56 | b[k] = v 57 | 58 | def cfg_from_file(filename): 59 | """Load a config file and merge it into the default options.""" 60 | import yaml 61 | with open(filename, 'r') as f: 62 | yaml_cfg = edict(yaml.load(f)) 63 | 64 | _merge_a_into_b(yaml_cfg, __C) 65 | 66 | def cfg_from_list(cfg_list): 67 | """Set config keys via list (e.g., from command line).""" 68 | from ast import literal_eval 69 | assert len(cfg_list) % 2 == 0 70 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 71 | key_list = k.split('.') 72 | d = __C 73 | for subkey in key_list[:-1]: 74 | assert d.has_key(subkey) 75 | d = d[subkey] 76 | subkey = key_list[-1] 77 | assert d.has_key(subkey) 78 | try: 79 | value = literal_eval(v) 80 | except: 81 | # handle the case when v is a string literal 82 | value = v 83 | assert type(value) == type(d[subkey]), \ 84 | 'type {} does not match original type {}'.format( 85 | type(value), type(d[subkey])) 86 | d[subkey] = value 87 | -------------------------------------------------------------------------------- /lib/bt_datasets/places365.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Made by Bjoernar Remmen 3 | # Short is better 4 | 5 | import os, sys 6 | import 
os.path as osp 7 | from copy import deepcopy 8 | from random import shuffle 9 | from bt_datasets.imdb import imdb 10 | from bt_net.config import cfg 11 | import numpy as np 12 | import scipy.io as sio 13 | import cPickle 14 | import xml.etree.ElementTree as ET 15 | import xml.dom.minidom as minidom 16 | from preprocessing.preprocess_images import read_img 17 | 18 | from random import shuffle 19 | 20 | 21 | class places365(imdb): 22 | def __init__(self, image_set): 23 | imdb.__init__(self, 'places365_' + image_set) 24 | self._image_set = image_set 25 | self._devkit_path = osp.join(cfg.IMAGE_DATA_DIR, 'places365') 26 | self._data_path = osp.join(self._devkit_path, 'data_256') 27 | 28 | 29 | self._classes = ('__background__',) 30 | with open(osp.join(self._devkit_path, 'labels.txt')) as lbls: 31 | for lbl_line in lbls: 32 | self._classes = self._classes + (lbl_line.rstrip(),) 33 | print "classes", len(self._classes) 34 | 35 | self._image_ext = ['.JPG'] 36 | self._image_index = self._load_image_set_index() 37 | print "image_index", len(self._image_index) 38 | 39 | 40 | 41 | # Specific config options 42 | self.config = {'cleanup': True, 43 | 'use_salt': True, 44 | 'top_k': 2000} 45 | 46 | assert os.path.exists(self._devkit_path), \ 47 | 'Devkit path does not exist: {}'.format(self._devkit_path) 48 | assert os.path.exists(self._data_path), \ 49 | 'Path does not exist: {}'.format(self._data_path) 50 | 51 | 52 | def _load_image_set_index(self): 53 | """ 54 | Load the indexes listed in this dataset's image set file. 
55 | """ 56 | # Example path to image set file: 57 | # self._data_path + /ImageSets/val.txt or self._data_path + /ImageSets/zero_shot.txt 58 | image_set_file = os.path.join(self._devkit_path, self._image_set + '.txt') 59 | print "img_file", image_set_file 60 | assert os.path.exists(image_set_file), \ 61 | 'Path does not exist: {}'.format(image_set_file) 62 | 63 | with open(image_set_file) as f: 64 | image_index = [] 65 | for line in f: 66 | line = line.split(",") 67 | image_index.append(line) 68 | #image_index = [line.split() for line in f] 69 | return image_index 70 | 71 | 72 | 73 | 74 | 75 | 76 | def get_image_data_clean(self, index, lang_db): 77 | data = self._image_index[index] 78 | image_data = data[0] 79 | lbls = [x.strip() for x in data[1:]] 80 | img_path = osp.join(self._data_path, image_data) 81 | word_vectors = [] 82 | for lbl in lbls: 83 | word_vec = lang_db.word_vector(lbl) 84 | word_vectors.append(word_vec) 85 | return img_path , word_vectors 86 | 87 | 88 | def get_classes(self): 89 | return self._classes[1:] 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /tools/train_ml_brute_force.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | 4 | Train brute force Zero Shot Classification 5 | 1. 
Train a small dense network that learns to train inception features of 1024 down to 300d, 6 | and inputs images in same semantic space as language model 7 | 8 | """ 9 | import _init_paths 10 | import argparse 11 | import sys 12 | import numpy as np 13 | from bt_net.train_bts import train_multilabel_bts, euclidean_distance 14 | from language_models.language_factory import get_language_model 15 | from bt_datasets.factory import get_imdb 16 | from keras.models import load_model 17 | import os.path as osp 18 | from bt_net.config import cfg 19 | import tensorflow as tf 20 | from keras.backend.tensorflow_backend import set_session 21 | 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='Train the Brute-Force BTnet model') 26 | parser.add_argument('--iters', dest='max_iters', 27 | help='number of iterations to train', 28 | default=40000, type=int) 29 | parser.add_argument('--lm', dest='lang_name', 30 | help='language model to use', 31 | default='glove_wiki', type=str) 32 | parser.add_argument('--imdb', dest='imdb_name', 33 | help='dataset to train on', 34 | default='imagenet_train', type=str) 35 | parser.add_argument('--rand', dest='randomize', 36 | help='randomize (do not use a fixed seed)', 37 | action='store_true') 38 | parser.add_argument('--loss', dest='loss_func', 39 | help='name of loss in pretrained', 40 | default='squared_hinge', type=str) 41 | parser.add_argument('--model', dest='model', 42 | help='pretrained model', type=str) 43 | parser.add_argument('--boxes', dest='boxes', 44 | help='method to generate boxes (random, frcnn, yolo)', 45 | default='random', type=str) 46 | 47 | if len(sys.argv) == 1: 48 | parser.print_help() 49 | sys.exit(1) 50 | 51 | args = parser.parse_args() 52 | return args 53 | 54 | 55 | if __name__ =='__main__': 56 | args = parse_args() 57 | 58 | print('Called with args:') 59 | print(args) 60 | 61 | if not args.randomize: 62 | # fix the random seeds (numpy and caffe) for reproducibility 63 | 
np.random.seed(42) 64 | 65 | lang_db = get_language_model(args.lang_name) 66 | imdb = get_imdb(args.imdb_name) 67 | ''' 68 | lang_db.build_tree("nus_wide_train_1k",imdb.load_classes()) 69 | #pretrained = load_model(osp.join(cfg.SNAPSHOT_DIR, "model_train_bts_glove_wiki_300.hdf5"), custom_objects={'euclidean_distance': euclidean_distance}) 70 | #pretrained = load_model(osp.join(cfg.SNAPSHOT_DIR, args.model)) 71 | pretrained = load_model(args.model) 72 | 73 | print("test random",(pretrained.predict_on_batch(np.random.rand(30,299,299,3))).shape) 74 | train_multilabel_bts(lang_db, imdb, pretrained, max_iters=args.max_iters, loss_func=args.loss_func) 75 | ''' 76 | pretrained = load_model(args.model) 77 | print("test random",(pretrained.predict_on_batch(np.random.rand(30,299,299,3))).shape) 78 | train_multilabel_bts(lang_db, imdb, pretrained, max_iters=args.max_iters, loss_func=args.loss_func, box_method=args.boxes) 79 | -------------------------------------------------------------------------------- /tools/test_brute_force.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the brute-force zero shot. 3 | 1. The testing part splits an image into many boxes, and runs each box through the trained model. 4 | 2. Each box gets labeled by the model, using euclidean distance to find closest label in the space. 5 | 3. After each box is labeled, returns the top-5 predicted labels (if there is less than 5 predicted in total, returned the labels predicted) 6 | 4. This way the model predicts in a zero-shot manner, as it will predict labels that exist in the word semantic space, but was not in the trained dataset. 
7 | 8 | Example run: 9 | python tools/test_brute_force.py --lm glove_wiki_300 --imdb imagenet_zs --ckpt output/train_bts/model_glove_wiki_300.hdf5 --boxes faster_rcnn 10 | 11 | python tools/test_brute_force.py --ckpt output/bts_ckpt/${TRAIN_IMDB}/model_${TRAIN_IMDB}_${LM}_${loss}.hdf5 --imdb ${TEST_IMDB} --lm glove_wiki_300 --singlelabel_predict 12 | """ 13 | import _init_paths 14 | import argparse 15 | import sys 16 | import numpy as np 17 | from bt_net.test_bts import test_bts 18 | from language_models.language_factory import get_language_model 19 | from bt_datasets.factory import get_imdb 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train the Brute-Force BTnet model') 23 | parser.add_argument('--lm', dest='lang_name', 24 | help='language model to use (default: glove_wiki_300)', 25 | default='glove_wiki_300', type=str) 26 | parser.add_argument('--imdb', dest='imdb_name', 27 | help='dataset to train on', 28 | default='imagenet_train', type=str) 29 | parser.add_argument('--ckpt', dest='ckpt', 30 | help='trained model to perform prediction on', type=str) 31 | parser.add_argument('--rand', dest='randomize', 32 | help='randomize (do not use a fixed seed)', 33 | action='store_true') 34 | parser.add_argument('--euc_loss', dest='euc_loss', 35 | help='if euclidean_loss is used', action='store_true') 36 | parser.add_argument('--singlelabel_predict', dest='singlelabel_predict', 37 | help='if you want to predict singlelabel', action='store_true') 38 | parser.add_argument('--space', dest='space', 39 | help='select word vector space to predict on: 0: all of wikipedia, 1: only unseen labels, 2: seen + unseen (default: unseen+seen).', default=2, type=int) 40 | parser.add_argument('--boxes', dest='boxes', 41 | help='predict using generated boxes faster_rcnn or yolo (default: random)', default='random', type=str) 42 | 43 | 44 | 45 | if len(sys.argv) == 1: 46 | parser.print_help() 47 | sys.exit(1) 48 | 49 | args = parser.parse_args() 50 | return 
args 51 | 52 | if __name__ == '__main__': 53 | args = parse_args() 54 | 55 | print('Called with args:') 56 | print(args) 57 | 58 | if not args.randomize: 59 | # fix the random seeds (numpy) for reproducibility 60 | np.random.seed(42) 61 | 62 | lang_db = get_language_model(args.lang_name) 63 | imdb = get_imdb(args.imdb_name) 64 | 65 | assert 0 <= args.space <= 2 , \ 66 | 'Space has to be either 0, 1 or 2' 67 | 68 | if args.space != 0: 69 | words = imdb.get_labels(args.space) 70 | else: 71 | space = 'all' 72 | words = [] 73 | 74 | lang_db.build_tree(args.space, words, args.imdb_name) 75 | test_bts(lang_db, imdb, args.ckpt, args.euc_loss, args.singlelabel_predict, args.space, args.boxes) 76 | -------------------------------------------------------------------------------- /scripts/train_ml_bruteforce.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | v='w2v_wiki_300D glove_pretrained glove_pretrained' 6 | if [ 1 -eq 0 ]; then 7 | 8 | 9 | for img_mod in 'nus_wide_Train_zs_920_img_lbl' 10 | do 11 | for lang_mod in 'w2v_pretrained' 'fast_eng' 'w2v_wiki_300D' 'glove_wiki_300D' 12 | do 13 | python tools/train_ml_brute_force.py \ 14 | --lm ${lang_mod}\ 15 | --imdb ${img_mod}\ 16 | --iters 100000\ 17 | --loss sq_hinge 18 | done 19 | done 20 | 21 | 22 | for img_mod in 'nus_wide_train_1k' 23 | do 24 | for lang_mod in 'fast_eng' 'w2v_wiki_300D' 'glove_wiki_300D' 25 | do 26 | python tools/train_ml_brute_force.py \ 27 | --lm ${lang_mod}\ 28 | --imdb ${img_mod}\ 29 | --iters 100000 30 | done 31 | done 32 | 33 | fi 34 | if [ 1 -eq 0 ]; then 35 | 36 | for img_mod in 'nus_wide_Train_zs_920_img_lbl' 'nus_wide_train_1k' 37 | do 38 | for lang_mod in 'glove_wiki_300D' 39 | do 40 | python tools/train_ml_brute_force.py\ 41 | --lm ${lang_mod}\ 42 | --imdb ${img_mod}\ 43 | --iters 100000\ 44 | --loss sq_hinge_\ 45 | --model model_imagenet1k_train_bts_glove_wiki_300D_squared_hinge_l2.hdf5 46 | done 47 | 48 | done 49 
| 50 | 51 | for img_mod in 'nus_wide_train_1k' 52 | do 53 | for lang_mod in 'w2v_wiki_300D' 54 | do 55 | python tools/train_ml_brute_force.py\ 56 | --lm ${lang_mod}\ 57 | --imdb ${img_mod}\ 58 | --iters 100000\ 59 | --loss sq_hinge_${img_mod}\ 60 | --model /media/bjotta/13f2cffb-0a7d-41b9-946f-36d679d1e9f6/home/fast-rcnn/Zero_Shot_Multi_Label/snapshots/model_imagenet1k_train_bts_w2v_wiki_300D_squared_hinge_l2.hdf5 61 | done 62 | 63 | done 64 | 65 | 66 | fi 67 | for img_mod in 'nus_wide_Train_zs_920_img_lbl' 68 | do 69 | for lang_mod in 'w2v_wiki_300D' 70 | do 71 | python tools/train_ml_brute_force.py\ 72 | --lm ${lang_mod}\ 73 | --imdb ${img_mod}\ 74 | --iters 100000\ 75 | --loss sq_hinge\ 76 | --boxes yolo \ 77 | --model /media/bjotta/13f2cffb-0a7d-41b9-946f-36d679d1e9f6/home/fast-rcnn/Zero_Shot_Multi_Label/snapshots/model_imagenet1k_train_bts_w2v_wiki_300D_squared_hinge_l2.hdf5 78 | done 79 | 80 | done 81 | 82 | 83 | parser = argparse.ArgumentParser(description='Train the Brute-Force BTnet model') 84 | parser.add_argument('--iters', dest='max_iters', 85 | help='number of iterations to train', 86 | default=40000, type=int) 87 | parser.add_argument('--lm', dest='lang_name', 88 | help='language model to use', 89 | default='glove_wiki', type=str) 90 | parser.add_argument('--imdb', dest='imdb_name', 91 | help='dataset to train on', 92 | default='imagenet_train', type=str) 93 | parser.add_argument('--rand', dest='randomize', 94 | help='randomize (do not use a fixed seed)', 95 | action='store_true') 96 | parser.add_argument('--loss', dest='loss_func', 97 | help='name of loss in pretrained', 98 | default='squared_hinge', type=str) 99 | parser.add_argument('--model', dest='model', 100 | help='pretrained model', type=str) 101 | parser.add_argument('--boxes', dest='boxes', 102 | help='method to generate boxes (random, frcnn, yolo)', type=str) 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- 
/tools/downloader.py: -------------------------------------------------------------------------------- 1 | # Made by Bjoernar Remmen 2 | # Short is better 3 | import urllib2 4 | import os.path as osp 5 | import os 6 | from parallel_sync import wget 7 | 8 | file_path = os.path.dirname(__file__) 9 | print file_path 10 | root = osp.join(file_path,"..") 11 | print root 12 | glove_dir = osp.join(root,"data","lm_data","glove") 13 | if not os.path.exists(glove_dir): 14 | os.makedirs(glove_dir) 15 | word2vec_dir = osp.join(root,"data", "lm_data", "word2vec") 16 | if not os.path.exists(word2vec_dir): 17 | os.makedirs(word2vec_dir) 18 | print glove_dir 19 | # glove [50D,150D,300D] 20 | 21 | glove_urls = ["https://www.dropbox.com/s/cgy4mdstgpmtnw8/vectors_50.txt?dl=1", "https://www.dropbox.com/s/hpjosjxrgbnsyqs/vectors.txt?dl=1", "https://www.dropbox.com/s/8r4524eaus1irdp/vectors.txt?dl=1"] 22 | w2v_50 = ["https://www.dropbox.com/s/bngbeepc3ec3r8h/w2v_wiki.model?dl=1", "https://www.dropbox.com/s/1mx2gcv315cej24/w2v_wiki.model.syn1neg.npy?dl=1","https://www.dropbox.com/s/w5si5ihg73zrbvz/w2v_wiki.model.wv.syn0.npy?dl=1"] 23 | w2v_150 = ["https://www.dropbox.com/s/fh8ng75kehog3jh/w2v_wiki.model?dl=1", "https://www.dropbox.com/s/ddeji4x4t6ds3x1/w2v_wiki.model.syn1neg.npy?dl=1", "https://www.dropbox.com/s/ydmemky269fuoeh/w2v_wiki.model.wv.syn0.npy?dl=1"] 24 | w2v_300 = ["https://www.dropbox.com/s/hhaw2zj81bvuatn/w2v_wiki.model?dl=1", "https://www.dropbox.com/s/2lwafcxqvi2xlbv/w2v_wiki.model.syn1neg.npy?dl=1", "https://www.dropbox.com/s/vpmqrm8q78en38h/w2v_wiki.model.wv.syn0.npy?dl=1"] 25 | 26 | w2v_names_50 = ["w2v_wiki_50.model","w2v_wiki_50.model.syn1neg.npy", "w2v_wiki_50.model.wv.syn0.npy"] 27 | w2v_names_150 = ["w2v_wiki_150.model","w2v_wiki_150.model.syn1neg.npy", "w2v_wiki_150.model.wv.syn0.npy"] 28 | w2v_names_300 = ["w2v_wiki_300.model","w2v_wiki_300.model.syn1neg.npy", "w2v_wiki_300.model.wv.syn0.npy"] 29 | 30 | 31 | 32 | '''def download_file(name, url): 33 | f = 
urllib2.urlopen(url) 34 | if not(os.path.exists(glove_dir)): 35 | os.makedirs(glove_dir) 36 | 37 | with open(osp.join(glove_dir,name ), "wb") as code: 38 | code.write(f.read()) 39 | ''' 40 | 41 | 42 | if __name__ == "__main__": 43 | vector_size = [50,150,300] 44 | vector_names_glove = ["glove_wiki_50.txt", "glove_wiki_150.txt", "glove_wiki_300.txt"] 45 | vector_names_word2vec_50 = ["w2v_wiki_50.txt"] 46 | print "Starting download" 47 | wget.download(glove_dir, urls=glove_urls, filenames=vector_names_glove, parallelism=3) 48 | 49 | # print "Downloaded", vector_names_glove 50 | #wget.download(word2vec_dir, urls=w2v_50, filenames=w2v_names_50, parallelism=3) 51 | #print "Downloaded", w2v_names_50 52 | #wget.download(word2vec_dir, urls=w2v_150, filenames=w2v_names_150, parallelism=3) 53 | # print "Downloaded", w2v_names_150 54 | #wget.download(word2vec_dir, urls=w2v_300, filenames=w2v_names_300, parallelism=3) 55 | # print "Downloaded", w2v_names_300 56 | # wget.download(glove_dir, urls=[glove_urls[0]],filenames=[vector_names[0]],parallelism=3) 57 | # wget.download(glove_dir, urls=[glove_urls[1]],filenames=[vector_names[1]],parallelism=3) 58 | # wget.download(glove_dir, urls=[glove_urls[2]],filenames=[vector_names[2]],parallelism=3) 59 | 60 | 61 | 62 | ''' 63 | 64 | This takes some time. The files are quite big. 65 | I suggest you try out a new hobby while it is downloading. 66 | 67 | 68 | The files are created at startup and will grow until the program stops. 
69 | To monitor size: watch ls -sh file1 file2 file3ve 70 | vectors_50.txt : 629 mb 71 | vectors_150.txt : 1.9 gb 72 | vectors_300.txt : 3.62 gb 73 | 74 | 75 | Do not stop the program, if you do you could end up with a corrupt file 76 | ''' -------------------------------------------------------------------------------- /lib/bt_net/generate_boxes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Library to generate boxes 3 | """ 4 | from multiprocessing import Pool 5 | from skimage.transform import resize 6 | import sys, random, math 7 | import itertools 8 | import numpy as np 9 | 10 | def sliding_window(num_boxes, image, windowSize, workers=3): 11 | image_batch = None 12 | #'''Guarantee that middle of picture is always covered''' 13 | #data = get_middle_box(image, windowSize) 14 | #img = image[data[1]:data[3],data[0]:data[2]] 15 | #image_batch = img[np.newaxis] 16 | 17 | for i in range((num_boxes / workers)): 18 | y = random.randrange(0, image.shape[0] - windowSize[0]) 19 | x = random.randrange(0, image.shape[1] - windowSize[1]) 20 | img = image[y:y + windowSize[1], x:x + windowSize[0]] 21 | img = resize(img, (299, 299)) 22 | 23 | if image_batch is None: 24 | image_batch = img[np.newaxis] 25 | else: 26 | image_batch = np.append(image_batch, img[np.newaxis], axis=0) 27 | return image_batch 28 | 29 | def sliding_window2(num_boxes, image, windowSize, workers=3): 30 | image_batch = [] 31 | for i in range(num_boxes / workers): 32 | y = random.randrange(0, image.shape[0] - windowSize[0]) 33 | x = random.randrange(0, image.shape[1] - windowSize[1]) 34 | img = image[y:y + windowSize[1], x:x + windowSize[0]] 35 | img = resize(img, (299, 299)) 36 | image_batch.append(img[np.newaxis]) 37 | 38 | return image_batch 39 | 40 | 41 | def func_star(a_b): 42 | """Convert `f([1,2])` to `f(1,2)` call.""" 43 | return sliding_window(*a_b) 44 | 45 | 46 | def get_middle_box(image, windowSize): 47 | x = int((image.shape[0] - windowSize[0]) / 
2) 48 | y = int((image.shape[1] - windowSize[1]) / 2) 49 | return (x, y, x + windowSize[0], y + windowSize[1]) 50 | 51 | def generate_n_boxes(image, num_boxes=30, temp_size = (500, 700), pool=None): 52 | height = image.shape[0] 53 | width = image.shape[1] 54 | if width > height: 55 | #image = resize(image, temp_size) 56 | windowSize_1 = (int(math.floor(image.shape[0] / 2)), int(math.floor(image.shape[1] / 2))) 57 | windowSize_2 = (int(math.floor(image.shape[0] / 3)), int(math.floor(image.shape[1] / 3))) 58 | windowSize_3 = (int(math.floor(image.shape[0] / 4)), int(math.floor(image.shape[1] / 4))) 59 | else: 60 | #image = resize(image, (temp_size[1], temp_size[0])) 61 | windowSize_1 = (int(math.floor(image.shape[0] / 2)), int(math.floor(image.shape[1] / 2))) 62 | windowSize_2 = (int(math.floor(image.shape[0] / 3)), int(math.floor(image.shape[1] / 3))) 63 | windowSize_3 = (int(math.floor(image.shape[0] / 4)), int(math.floor(image.shape[1] / 4))) 64 | 65 | # windowSize_1 = (int(math.floor(image.shape[0] / 2)), int(math.floor(image.shape[1] / 2))) 66 | # windowSize_2 = (int(math.floor(image.shape[0] / 4)), int(math.floor(image.shape[1] / 4))) 67 | # windowSize_3 = (int(math.floor(image.shape[0] / 6)), int(math.floor(image.shape[1] / 6))) 68 | windowSizes = [windowSize_1, windowSize_2, windowSize_3] 69 | num_boxes, image, workers = num_boxes, image, len(windowSizes) 70 | pool_none = False 71 | if pool is None: 72 | pool = Pool(4) 73 | pool_none = True 74 | im1, im2, im3 = pool.map(func_star, itertools.izip(itertools.repeat(num_boxes), 75 | itertools.repeat(image), 76 | windowSizes , 77 | itertools.repeat(workers))) 78 | if pool_none: 79 | pool.close() 80 | pool.join() 81 | pool.terminate() 82 | result = im1 83 | result = np.append(result, im2, axis=0) 84 | result = np.append(result, im3, axis=0) 85 | #result = im1 + im2 + im3 86 | #print result[0] 87 | #print len(result) 88 | return result 89 | 
-------------------------------------------------------------------------------- /lib/bt_datasets/open_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bt_datasets.imdb import imdb 3 | import numpy as np 4 | import scipy.sparse 5 | import scipy.io as sio 6 | import os.path as osp 7 | from bt_net.config import cfg 8 | 9 | 10 | 11 | class open_images(imdb): 12 | def __init__(self, image_set): 13 | imdb.__init__(self, image_set) 14 | self._image_set = image_set 15 | self._devkit_path = os.path.join(cfg.IMAGE_DATA_DIR, 'openimages') 16 | self._data_path = os.path.join(self._devkit_path) 17 | 18 | if 'train' in self._image_set: 19 | train = osp.join(self._devkit_path,'oi_train_classes.txt') 20 | self._classes = self._generate_class_list(train) 21 | elif 'test' in self._image_set: 22 | test = osp.join(self._devkit_path, 'oi_test_classes.txt') 23 | self._classes = self._generate_class_list(test) 24 | elif 'validation' in self._image_set: 25 | test = osp.join(self._devkit_path, 'oi_test_classes.txt') 26 | self._classes = self._generate_class_list(test) 27 | self._image_index = self._load_image_set_index() 28 | 29 | 30 | 31 | assert os.path.exists(self._data_path), \ 32 | 'Path does not exist: {}'.format(self._data_path) 33 | 34 | def _generate_class_list(self,filename, file_name=None): 35 | self._classes = ('__background__',) 36 | with open(filename) as f: 37 | for line in f: 38 | self._classes = self._classes + (line.strip(),) 39 | return self._classes 40 | 41 | def image_path_at(self, i): 42 | """ 43 | Return the absolute path to image i in the image sequence. 
44 | """ 45 | return self.image_path_from_index(self._image_index[i]) 46 | 47 | def get_image_data_clean_train(self, index, lang_db): 48 | ''' 49 | 50 | :param index: Take index 0 - imdb.num_images 51 | :param lang_db: Can be glove_wiki_300D for example 52 | :return: path for image, word_vector and name of word_vectors 53 | ''' 54 | data = self._image_index[index] 55 | image_data = data[0] 56 | lbls = data[1:] 57 | 58 | img_path = osp.join(self._data_path, image_data) 59 | word_vectors = [] 60 | for lbl in lbls: 61 | word_vec = lang_db.word_vector(lbl) 62 | word_vectors.append(word_vec) 63 | return img_path, word_vectors, lbls 64 | 65 | def get_image_data_clean(self, index, lang_db): 66 | ''' 67 | :param index: Take index 0 - imdb.num_images 68 | :param lang_db: Can be glove_wiki_300D for example 69 | :return: path for image and word_vector 70 | ''' 71 | data = self._image_index[index] 72 | image_data = data[0] 73 | lbls = data[1:] 74 | img_path = osp.join(self._data_path, image_data) 75 | word_vectors = [] 76 | for lbl in lbls: 77 | word_vec = lang_db.word_vector(lbl) 78 | word_vectors.append(word_vec) 79 | return img_path, word_vectors 80 | 81 | def image_path_from_index(self, index): 82 | """ 83 | Construct an image path from the image's "index" identifier. 
84 | """ 85 | image_path = os.path.join(self._data_path, 86 | index) 87 | assert os.path.exists(image_path), \ 88 | 'Path does not exist: {}'.format(image_path) 89 | return image_path 90 | 91 | def _load_image_set_index(self): 92 | """ 93 | Load the indexes listed in this dataset's image set file 94 | """ 95 | # Example path to image set file: 96 | # self._data_path + /ImageSets/val.txt 97 | image_set_file = os.path.join(self._data_path, self._image_set + '.txt') 98 | assert os.path.exists(image_set_file), 'Path does not exist: {}'.format(image_set_file) 99 | with open(image_set_file) as f: 100 | image_index = [x.strip().split(",") for x in f] 101 | return image_index 102 | 103 | 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-Shot MultiLabel 2 | ###### code for [master thesis](https://brage.bibsys.no/xmlui/bitstream/handle/11250/2459946/15964_FULLTEXT.pdf?sequence=1&isAllowed=y) on Zero-Shot Classification with multilabel data. 3 | 4 | ## Abstract 5 | Visual recognition systems are often limited to the object categories they were trained on and thus suffer in their ability to scale. This is in part due to the difficulty of acquiring sufficient labeled images as the number of object categories grows. To solve this, earlier research has presented models that use other sources, such as text data, to help classify object categories unseen during training. However, most of these models are limited to images with a single label, whereas most images can contain more than one object category, and therefore more than one label. This master's thesis implements a model capable of classifying unseen categories for both single- and multi-labeled images. 
6 | 7 | The architecture consists of several modules: a pre-trained neural network that generates image features for each image, a model trained on text that represents words as vectors, and a neural network that projects the image features to the dimension native to the vector representation of words. On this architecture, we compared two approaches to generate word vectors using GloVe and Word2vec, with different vector dimensions and on spaces containing different numbers of word vectors. The model was adapted to multi-label predictions comparing three approaches for image box generation: YOLOv2, Faster R-CNN and randomly generated boxes. Here each box represents a section of the image cut out and this approach was chosen to fit each label to one of these boxes. 8 | 9 | The results showed that increasing the word vector dimension increased the accuracy, with Word2vec outperforming GloVe, and when adding more words to the word vector space the accuracy dropped. In the single-label scenario the model achieves similar results to existing models with similar architecture. In the multi-label scenario, the model trained on boxes generated by Faster R-CNN and predicting on randomly generated boxes had the highest accuracy, but was not able to outperform comparable alternatives. The architecture gives promising results, but more investigation is needed to answer if the results can be improved further. 
10 | 11 | ## Dependencies 12 | Usage: 13 | - Python 2.7 14 | - [Glove](https://nlp.stanford.edu/projects/glove/) 15 | - [Gensim](https://radimrehurek.com/gensim/) 16 | - [ANNOY](https://github.com/spotify/annoy) 17 | - [h5py](http://www.h5py.org/) 18 | - opencv2 - available from pip 19 | - numpy - available from pip 20 | 21 | ## Object detection frameworks 22 | - [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn) 23 | - [YOLO](https://github.com/philipperemy/yolo-9000) 24 | 25 | 26 | ## Downloadables 27 | - Weightfile for yolo9000 can be downloaded at [http://pjreddie.com/media/files/yolo9000.weights]( http://pjreddie.com/media/files/yolo9000.weights). 28 | - Language model file consisting of word2vec and Glove vectors: [Language models](https://mega.nz/#!GZ0xWSZQ!NEe8ck_KvOHB03t7vlePbJNtXy3vvQAREqwTV59afVk) 29 | - Single- and multi-label weights trained on Glove and Word2vec (50D,150D,300D): [Single- and multi-label weights](https://drive.google.com/open?id=0B3B4Y0zc2AQEbnlTNXZhTDlGak0) 30 | 31 | ## Before training and testing 32 | - Download the pre-trained language model vectors. 33 | - Use py-faster-rcnn or YOLO to compute region-of-interest boxes. 34 | 35 | ## Train Zero-Shot model 36 | ### Single-label data 37 | ``` 38 | python tools/train_brute_force.py --imdb dataset --lm language model (e.g. w2v_wiki_300D) --loss squared_hinge --iters 10000 39 | ``` 40 | ### Multi-label data 41 | ``` 42 | python tools/train_ml_brute_force.py --imdb dataset --lm language model (e.g. 
w2v_wiki_300D) --loss squared_hinge --model ZSL_model (pre-trained on single-label data) --boxes (random, frcnn or yolo)--iters 10000 43 | ``` 44 | 45 | ## Test Zero-Shot model 46 | ### Single-label data 47 | ``` 48 | python tools/test_brute_force.py --lm glove_wiki_300 --imdb imagenet_zs --ckpt output/train_bts/model_glove_wiki_300.hdf5 --singlelabel_predict 49 | ``` 50 | ### Multi-label data 51 | ``` 52 | python tools/test_brute_force.py --lm glove_wiki_300 --imdb imagenet_zs --ckpt output/train_bts/model_glove_wiki_300.hdf5 --boxes faster_rcnn 53 | ``` 54 | 55 | -------------------------------------------------------------------------------- /lib/preprocessing/imagenet_imagelist_cleaning.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import xml.etree.ElementTree as ET 3 | import xml.dom.minidom as minidom 4 | import numpy as np 5 | 6 | import random 7 | import os.path as osp 8 | 9 | def generate_no_extra(imagelist_path, name = "train_no_extra.txt", origin = 'train.txt', size = 1000): 10 | name = os.path.join(imagelist_path, name) 11 | origin = os.path.join(imagelist_path, origin) 12 | content = [] 13 | with open(origin) as f, open(name, 'w') as wf: 14 | content = [x.strip() for x in f.readlines()] 15 | random.shuffle(content) 16 | for item in content: 17 | if 'extra' in item: continue 18 | wf.write("%s\n" % item) 19 | 20 | def generate_small_train(name, origin, size): 21 | content = [] 22 | with open(origin) as f, open(name, 'w') as wf: 23 | content = [x.strip() for x in f.readlines()] 24 | random.shuffle(content) 25 | for item in content[:size]: 26 | wf.write("%s\n" % item) 27 | print("Successfully generated small dataset ", name) 28 | 29 | def check_image_size(index, filename, num_classes = 201): 30 | tree = ET.parse(filename) 31 | objs = tree.findall('object') 32 | im_size = tree.find("size") 33 | im_width = float(im_size.find('width').text) 34 | im_height = float(im_size.find('height').text) 35 | 
im_ratio = float(im_width / im_height) 36 | #print im_ratio 37 | if not(0.117 < im_ratio < 15.5): 38 | #print("Image ignored due to im_ratio out of boundaries, ", index) 39 | return False 40 | if (im_width < 127 and im_height < 96) or (im_width>500 and im_height>500): 41 | return False 42 | 43 | return True 44 | 45 | def clean_imagelist(imagenet_path): 46 | added = 0 47 | with open(imagenet_path + '/data/imagenet/data/ImageSets/DET/train_no_extra.txt') as f, open(imagenet_path + 'clean_imagelist.txt', 'w') as wf: 48 | i = 0 49 | for line in f: 50 | index = line.split() 51 | filename = os.path.join(imagenet_path, 'ILSVRC/Annotations/DET/train', index[0] + '.xml') 52 | if check_image_size(index, filename): 53 | added = added + 1 54 | wf.write("%s\n" % line.strip()) 55 | i = i+1 56 | if i%1000==0: 57 | print added,'/',i 58 | 59 | def random_line(filename): 60 | line_num = 0 61 | selected_line = '' 62 | with open(filename) as f: 63 | while 1: 64 | line = f.readline() 65 | if not line: break 66 | line_num += 1 67 | if random.uniform(0, line_num) < 1: 68 | selected_line = line 69 | return selected_line.strip() 70 | 71 | 72 | def combine_synsets(): 73 | line1 = random_line(cfg.ROOT_DIR + "/data/imagenet/data/ImageSets/DET/train_no_extra.txt") 74 | synset1 = line1.split("/")[1] 75 | line2 = random_line(cfg.ROOT_DIR + "/data/imagenet/data/ImageSets/DET/train_no_extra.txt") 76 | synset2 = line2.split("/")[1] 77 | 78 | variable = 0 79 | if synset1 == synset2: 80 | variable =1 81 | 82 | return line1, line2, variable 83 | 84 | 85 | def combine_synsets2(eq=None): 86 | line1 = random_line(cfg.ROOT_DIR + "/data/imagenet/data/ImageSets/DET/train_no_extra.txt") 87 | synset1 = line1.split("/")[1] 88 | get_new = True 89 | while get_new: 90 | line2 = random_line(cfg.ROOT_DIR + "/data/imagenet/data/ImageSets/DET/train_no_extra.txt") 91 | synset2 = line2.split("/")[1] 92 | 93 | if synset1 == synset2: 94 | variable =1 95 | else: 96 | variable = 0 97 | if variable == eq: 98 | get_new = 
False 99 | return line1, line2, variable 100 | 101 | 102 | if __name__ == "__main__": 103 | sys.path.insert(0, osp.join(osp.dirname(__file__), '..')) 104 | from config import cfg 105 | 106 | with open(cfg.ROOT_DIR + "/data/imagenet/data/ImageSets/DET/train_siamese.txt", "w") as save: 107 | for i in xrange(100): 108 | eq = 0 109 | if i % 2 == 0: 110 | eq = 1 111 | line1, line2, variable =combine_synsets2(eq=eq) 112 | save.write(line1 + " " + line2 + " " + str(variable)) 113 | 114 | #root_path = cfg.ROOT_DIR 115 | #clean_imagelist(root_path) 116 | 117 | 118 | 119 | #imagelist_path = os.path.join(imagenet_path, "ILSVRC/ImageSets/DET") 120 | #generate_no_extra(imagelist_path) 121 | -------------------------------------------------------------------------------- /lib/bt_net/visualize_graph.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os.path as osp 3 | import numpy as np 4 | from operator import itemgetter 5 | from collections import Counter 6 | from language_models.language_factory import get_language_model 7 | from bt_datasets.factory import get_imdb 8 | from sklearn.metrics.pairwise import cosine_distances 9 | 10 | def vis(x, y, xlabel, ylabel): 11 | plt.xlabel(xlabel) 12 | plt.ylabel(ylabel) 13 | plt.plot(x, y, 'ro') 14 | plt.show() 15 | 16 | def vis_distance(lang_db, imdb, seen_labels): 17 | # Precision results 18 | evaluation_path = osp.join('output', 'evaluate_results') 19 | 20 | with open(osp.join(evaluation_path, 'results_pr_class.txt')) as f: 21 | scores = [] 22 | distances = [] 23 | for line in f: 24 | words = line.strip().split(' ') 25 | words = [x.strip() for x in words] 26 | c = words[0] 27 | c_vector = [lang_db.word_vector(c)] 28 | closest_labels = lang_db.closest_labels(c_vector, k_closest = 5) 29 | 30 | closest_seen = None 31 | for closest in closest_labels: 32 | if closest in seen_labels: 33 | closest_seen = closest 34 | break 35 | if closest_seen is None: 36 | print 
"closest is none" 37 | 38 | cv = [np.array(lang_db.word_vector(closest_seen))] 39 | distances.append(float(cosine_distances(c_vector, cv))) 40 | scores.append(words[1]) 41 | #print distances 42 | #print scores 43 | vis(distances, scores, 'distance', 'score') 44 | 45 | def vis_top_classes(lang_db, imdb, seen_labels): 46 | evaluation_path = osp.join('output', 'evaluate_results') 47 | 48 | c_vector = [lang_db.word_vector('sunglass')] 49 | closest_labels = lang_db.closest_labels(c_vector, k_closest = 10) 50 | print closest_labels 51 | with open(osp.join(evaluation_path, 'results_pr_class.txt')) as f: 52 | classes = [] 53 | for line in f: 54 | c = {} 55 | words = line.strip().split(' ') 56 | words = [x.strip() for x in words] 57 | c['label'] = words[0] 58 | c['score'] = words[1] 59 | c_vector = [lang_db.word_vector(c['label'])] 60 | closest_labels = lang_db.closest_labels(c_vector, k_closest = 10) 61 | c_labels = closest_labels 62 | ''' 63 | c_labels = [] 64 | i = 0 65 | for closest in closest_labels: 66 | if closest in seen_labels: 67 | c_labels.append(closest) 68 | i += 1 69 | if i == 5: break 70 | ''' 71 | c['nn'] = c_labels 72 | classes.append(c) 73 | sorted_classes = sorted(classes, key=lambda k: k['score'], reverse=True) 74 | for c in sorted_classes[:15]: 75 | print "Class: {}, Score: {}, Closest: {}".format(c['label'], c['score'], c['nn']) 76 | 77 | def find_distribution_accuracy(scores): 78 | dist = [0.0, 1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 85.0, 90.0, 100.0] 79 | current = 0.0 80 | i = 0 81 | total = 0 82 | print scores 83 | for s in scores: 84 | if s > dist[i]: 85 | print "{}: {}".format(dist[i], total) 86 | total = 1 87 | i += 1 88 | continue 89 | total += 1 90 | print "{}: {}".format(dist[i], total) 91 | 92 | 93 | def vis_bar_top(): 94 | evaluation_path = osp.join('output', 'evaluate_results') 95 | with open(osp.join(evaluation_path, 'imagenet_sl_results_pr_class.txt')) as f: 96 | scores = [] 97 | for line in f: 98 | c = {} 99 | words = line.strip().split(' ') 
100 | words = [x.strip() for x in words] 101 | scores.append(float(words[1])) 102 | 103 | sort = sorted(scores, key=float) 104 | find_distribution_accuracy(sort) 105 | x = range(len(sort)) 106 | width = 1 107 | 108 | plt.bar(x, sort, width) 109 | plt.show() 110 | 111 | def vis_graph(lang_name, imdb_name, vis): 112 | if 0: 113 | vis_bar_top() 114 | 115 | if 1: 116 | 117 | lang_db = get_language_model(lang_name) 118 | imdb = get_imdb(imdb_name) 119 | unseen_labels = imdb.get_labels(1) 120 | unseen_seen_labels = imdb.get_labels(2) 121 | seen_labels = [] 122 | for s in unseen_seen_labels: 123 | if s in unseen_labels: 124 | continue 125 | seen_labels.append(s) 126 | 127 | lang_db.build_tree(2, unseen_seen_labels, imdb_name) 128 | 129 | #vis_distance(lang_db, imdb, seen_labels) 130 | vis_top_classes(lang_db, imdb, seen_labels) 131 | -------------------------------------------------------------------------------- /lib/bt_net/faster_rcnn_boxes.py: -------------------------------------------------------------------------------- 1 | ### Returns boxes and its matrices generated by faster-rcnn 2 | import os, sys, cv2 3 | import os.path as osp 4 | import pickle 5 | import itertools 6 | import numpy as np 7 | from bt_net.nms import py_cpu_nms, py_area_fix 8 | from bt_net.config import cfg 9 | from preprocessing.preprocess_images import preprocess_array 10 | from operator import itemgetter 11 | from random import randint 12 | import matplotlib as mpl 13 | mpl.use('Agg') 14 | import matplotlib.pyplot as plt 15 | 16 | def vis_proposals(img_path, im, dets, thresh = 0.8): 17 | inds = np.where(dets[:, -1] >= thresh)[0] 18 | if len(inds) == 0: 19 | return 20 | 21 | class_name = 'obj' 22 | im = im[:, :, (2, 1, 0)] 23 | fig, ax = plt.subplots(figsize=(12, 12)) 24 | ax.imshow(im, aspect='equal', interpolation='nearest') 25 | for i in inds: 26 | bbox = dets[i, :4] 27 | score = dets[i, -1] 28 | make_color = lambda: (float(randint(50, 255))/255, float(randint(50, 255))/255, float(randint(50, 
255))/255) 29 | color = make_color() 30 | ax.add_patch( 31 | plt.Rectangle((bbox[0], bbox[1]), 32 | bbox[2] - bbox[0], 33 | bbox[3] - bbox[1], fill=False, 34 | edgecolor=color, linewidth=3.5) 35 | ) 36 | ''' 37 | ax.text(bbox[0], bbox[1] - 2, 38 | '{:s} {:.3f}'.format(class_name, score), 39 | bbox=dict(facecolor='blue', alpha=0.5), 40 | fontsize=14, color='white') 41 | 42 | ax.set_title(('{} detections with ' 43 | 'p({} | box) >= {:.1f}').format(class_name, class_name, 44 | thresh), fontsize=14) 45 | ''' 46 | plt.axis('off') 47 | plt.tight_layout() 48 | #plt.draw() 49 | print img_path 50 | plt.savefig(img_path) 51 | print 'Saved: {}'.format(img_path) 52 | 53 | def boxes_to_images(img_path, image, dets, target_size = (299, 299)): 54 | images = [] 55 | ''' 56 | import cv2 57 | img = cv2.imread(img_path) 58 | vis_proposals(img_path[:-4] + 'frcnn.jpg', img, dets, thresh = 0.0) 59 | sys.exit(0) 60 | ''' 61 | for i in range(len(dets)): 62 | bbox = dets[i, :4] 63 | ''' 64 | y1:y2, x1:x2 65 | ''' 66 | box = image[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] 67 | h,w,color = box.shape 68 | if h == 0 or w ==0: 69 | ''' 70 | edge case for preprocess_array 71 | ''' 72 | continue 73 | img = preprocess_array(box) 74 | images.append(img) 75 | h, w, color = image.shape 76 | if h != 0 or w != 0: 77 | ''' 78 | edge case for preprocess_array 79 | ''' 80 | images.append(preprocess_array(image)) 81 | return np.array(images) 82 | 83 | def load_frcnn_boxes(imdb): 84 | """ 85 | Loads the precomputed boxes and features from py-faster-rcnn 86 | """ 87 | location = osp.join(cfg.FASTER_RCNN_DIR, 'output', 'default', imdb.name, 'VGG16_faster_rcnn_final') 88 | data_file = osp.join(location, 'VGG16_faster_rcnn_final_rpn_proposals.pkl') 89 | scores_file = osp.join(location, 'VGG16_faster_rcnn_final_rpn_proposals_scores.pkl') 90 | 91 | data = np.array(pickle.load(open(data_file, 'rb'))) 92 | scores = np.array(pickle.load(open(scores_file, 'rb'))) 93 | 94 | boxes = [] 95 | for d, s in 
itertools.izip(data, scores): 96 | dets = np.hstack((d, s)) 97 | keep = py_cpu_nms(dets[:30, :], 1) 98 | boxes.append(np.array([dets[i] for i in keep])) 99 | return boxes 100 | 101 | def load_all_boxes_yolo(imdb): 102 | root_dir = cfg.ROOT_DIR 103 | if "test" in imdb.name or "Test" in imdb.name: 104 | pickle_dir = osp.join(root_dir, "snapshots", "nus_wide_test_81_nms.p") 105 | else: 106 | pickle_dir = osp.join(root_dir, "snapshots","nus_wide_train_920.p") 107 | dictionary = pickle.load(open(pickle_dir, "rb")) 108 | return dictionary 109 | 110 | def load_image_boxes_yolo(image, dictionary, x_path): 111 | 112 | img_path = x_path 113 | 114 | 115 | x_path = x_path.split("/") 116 | x_path = x_path[-1].split(".") 117 | x_path = x_path[0] 118 | #print "path",x_path 119 | if x_path not in dictionary: 120 | return np.array([]) # returning empty array if not existing 121 | #### 122 | #### x1 y1 x2 y2 prob 123 | #### 124 | boxes = dictionary[x_path] 125 | ''' 126 | Sort on prob 127 | ''' 128 | boxes = sorted(boxes, key=itemgetter(4)) 129 | boxes = np.array(boxes[:100]) 130 | #print "len of boxes before nms", len(boxes) 131 | keep = py_cpu_nms(boxes, 0.5) 132 | boxes = np.array([boxes[ind] for ind in keep]) 133 | ''' 134 | import cv2 135 | img = cv2.imread(img_path) 136 | vis_proposals(img_path[:-4] + 'yolo.jpg', img, boxes, thresh = 0.0) 137 | sys.exit(0) 138 | ''' 139 | #keep = py_area_fix(image, boxes, 0.3) 140 | #boxes = np.array([boxes[ind] for ind in keep]) 141 | #print "len of boxes after nms", len(boxes) 142 | boxes = boxes_to_images(image, boxes, target_size=(299, 299)) 143 | return boxes 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /lib/bt_datasets/places205.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Made by Bjoernar Remmen 3 | # Short is better 4 | 5 | import os, sys 6 | import os.path as osp 7 | from copy import deepcopy 8 | 
from random import shuffle 9 | from bt_datasets.imdb import imdb 10 | from bt_net.config import cfg 11 | import numpy as np 12 | import scipy.io as sio 13 | import cPickle 14 | import xml.etree.ElementTree as ET 15 | import xml.dom.minidom as minidom 16 | from preprocessing.preprocess_images import read_img 17 | 18 | from random import shuffle 19 | 20 | 21 | class places205(imdb): 22 | def __init__(self, image_set): 23 | imdb.__init__(self, 'places205_' + image_set) 24 | self._image_set = image_set 25 | self._devkit_path = osp.join(cfg.IMAGE_DATA_DIR, 'places205') 26 | self._data_path = osp.join(self._devkit_path, 'data','vision', 'torralba', 'deeplearning','images256') 27 | 28 | 29 | self._classes = ('__background__',) 30 | # Change to labels.txt file if generating ZS test for NUS-WIDE 31 | with open(osp.join(self._devkit_path, 'labels_no_overlap.txt')) as lbls: 32 | for lbl_line in lbls: 33 | self._classes = self._classes + (lbl_line.rstrip(),) 34 | print "classes", len(self._classes) 35 | 36 | self._image_ext = ['.JPG'] 37 | self._image_index = self._load_image_set_index() 38 | print "image_index", len(self._image_index) 39 | 40 | 41 | 42 | # Specific config options 43 | self.config = {'cleanup': True, 44 | 'use_salt': True, 45 | 'top_k': 2000} 46 | 47 | assert os.path.exists(self._devkit_path), \ 48 | 'Devkit path does not exist: {}'.format(self._devkit_path) 49 | assert os.path.exists(self._data_path), \ 50 | 'Path does not exist: {}'.format(self._data_path) 51 | 52 | 53 | def _load_image_set_index(self): 54 | """ 55 | Load the indexes listed in this dataset's image set file. 
56 | """ 57 | # Example path to image set file: 58 | # self._data_path + /ImageSets/val.txt or self._data_path + /ImageSets/zero_shot.txt 59 | image_set_file = os.path.join(self._devkit_path, self._image_set + '.txt') 60 | print "img_file", image_set_file 61 | assert os.path.exists(image_set_file), \ 62 | 'Path does not exist: {}'.format(image_set_file) 63 | 64 | with open(image_set_file) as f: 65 | image_index = [] 66 | for line in f: 67 | line = line.split() 68 | image_index.append(line) 69 | #image_index = [line.split() for line in f] 70 | return image_index 71 | 72 | def load_classes(self): 73 | return self._classes[1:] # Return classes except background 74 | def get_existing_lbls(self,lang_db, exclude =()): 75 | lbls = self.get_classes() 76 | for ex in exclude: 77 | try: 78 | lbls.remove(ex) 79 | except: 80 | lolz = 1 81 | 82 | existing = [] 83 | for lbl in lbls: 84 | 85 | vec = lang_db.word_vector(lbl) 86 | if vec is not None: existing.append(lbl) 87 | 88 | return existing 89 | 90 | # exclude = (bridge, castle, harbor, mountain, ocean, sky, tower, valley) 91 | def generate_zs_split(self,lang_db, exclude = ("bridge", "castle", "harbor", "mountain", "ocean", "sky", "tower", "valley")): 92 | ''' 93 | 94 | :param lang_db: language model. Such as glove_wiki_300D 95 | :param exclude: Excluding the overlapping classes from NUS-WIDE 96 | :return: None. 
Just write file 97 | ''' 98 | lbls = self.get_existing_lbls(lang_db, exclude=exclude) 99 | shuffle(lbls) 100 | counter = 0 101 | if len(exclude) > 0: 102 | #filename = "train_"+str(len(exclude))+".txt" 103 | filename = "test_8_val.txt" 104 | else: 105 | #filename = "train.txt" 106 | filename = "test_val.txt" 107 | with open(osp.join(cfg.IMAGE_DATA_DIR, 'places205', 'val_places205.txt')) as file, open(osp.join(cfg.IMAGE_DATA_DIR, 'places205', filename), "w") as train: 108 | for line in file: 109 | if counter % 100 ==0: 110 | print("written", counter, "lines") 111 | 112 | lbl_ind = int(line.rstrip().split(" ")[1]) 113 | lbl = self._classes[lbl_ind] 114 | if lbl in lbls: 115 | counter += 1 116 | line = line.rstrip().split(" ") 117 | path = line[0] 118 | train.write(path + " " + lbl + "\n") 119 | 120 | 121 | 122 | def get_image_data_clean(self, index, lang_db): 123 | data = self._image_index[index] 124 | image_data = data[0] 125 | lbls = [x.strip() for x in data[1:]] 126 | img_path = osp.join(self._data_path, image_data) 127 | word_vectors = [] 128 | for lbl in lbls: 129 | word_vec = lang_db.word_vector(lbl) 130 | word_vectors.append(word_vec) 131 | return img_path , word_vectors 132 | 133 | def image_path_at(self,i): 134 | data = self._image_index[i] 135 | image_data = data[0] 136 | img_path = osp.join(self._data_path, image_data) 137 | return img_path 138 | 139 | def gt_at(self,i): 140 | data = self._image_index[i] 141 | lbls = [x.strip() for x in data[1:]] 142 | return lbls 143 | 144 | 145 | 146 | def get_classes(self): 147 | return self._classes[1:] 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /lib/preprocessing/load_langmod.py: -------------------------------------------------------------------------------- 1 | # /usr/local/bin/python 2 | # Made by Bjoernar Remmen 3 | # Short is better 4 | import csv 5 | import time 6 | 7 | import h5py 8 | from glove import Glove 9 | 10 | root = "../../" 11 | 
glove_vec_fold = "data/lm_data/glove/" 12 | 13 | import numpy as np 14 | import math 15 | from sklearn.preprocessing import normalize 16 | #from keras import backend as K 17 | 18 | 19 | def load_glove(name): 20 | model = Glove.load_stanford(root + glove_vec_fold + name) 21 | return model 22 | 23 | def get_features(model): 24 | ''' 25 | :param glove_model: 26 | :return list of tuple(word,feature): 27 | ''' 28 | features = [] 29 | words = [] 30 | for word, index in model.dictionary.iteritems(): 31 | features.append(model.word_vectors[index]) 32 | words.append(word) 33 | 34 | return features, words 35 | 36 | 37 | def save_h5py(arrays, string_arrs, names, filename="glove.h5"): 38 | with h5py.File(filename, "w") as hf: 39 | for i in range(len(arrays)): 40 | hf.create_dataset(names[i], data=arrays[i]) 41 | string_dt = h5py.special_dtype(vlen=str) 42 | hf.create_dataset(names[i] + "_words", data=string_arrs[i], dtype=string_dt) 43 | 44 | return True 45 | 46 | 47 | def load_h5py(name, filename="glove.h5"): 48 | with h5py.File(filename, "r") as hf: 49 | data = hf[name][:] 50 | string_arr = hf[name + "_words"][:] 51 | return data, string_arr 52 | 53 | 54 | def group_list(l, group_size): 55 | """ 56 | :param l: list 57 | :param group_size: size of each group 58 | :return: Yields successive group-sized lists from l. 
59 | """ 60 | for i in xrange(0, len(l), group_size): 61 | yield l[i:i+group_size] 62 | 63 | def load_pretrained(): 64 | 65 | #glove_vec = ["glove_wiki_50","glove_wiki_150","glove_wiki_300"] 66 | glove_vec = ["glove_wiki_300"] 67 | #glove_vec = ["glove_wiki_50"] 68 | filename = 'glove_pretrained.h5' 69 | #import tensorflow as tf 70 | #sess = tf.InteractiveSession() 71 | 72 | features, words = load_h5py('glove_wiki_300',filename=root + glove_vec_fold + filename) 73 | filename = 'glove.h5' 74 | features = normalize(np.array(features), axis=1, norm='l2') 75 | with h5py.File(root + glove_vec_fold + filename, "w") as hf: 76 | hf.create_dataset(glove_vec[0], data=features) 77 | string_dt = h5py.special_dtype(vlen=str) 78 | hf.create_dataset(glove_vec[0] + "_words", data=words, dtype=string_dt) 79 | 80 | for vec in glove_vec: 81 | data, words = load_h5py(vec, filename=root + glove_vec_fold + "glove.h5") 82 | print(data.shape, words.shape) 83 | time.sleep(5) 84 | 85 | if __name__ == "__main__": 86 | #glove_vec = ["glove_wiki_50","glove_wiki_150","glove_wiki_300"] 87 | 88 | 89 | glove_vec = ["glove_wiki_300"] 90 | filename = 'glove.h5' 91 | ######### 92 | 93 | load_pretrained() 94 | 95 | 96 | with h5py.File(root + glove_vec_fold + filename, "w") as hf: 97 | for vec in glove_vec: 98 | model = load_glove(vec+".txt") 99 | #ii = model.most_similar('woman',number=5) 100 | feature_list, words = get_features(model) 101 | #feature_list = feature_list.tolist() 102 | counter = 0 103 | #for i in range(len(words) - 1, -1, -1): 104 | #try: 105 | # words[i] = words[i].decode('utf-8') 106 | #except: 107 | # print "exception", words[i] 108 | # feature_list.pop(i) 109 | # words.pop(i) 110 | 111 | 112 | feature_list = np.array(feature_list) 113 | #feature_list = normalize(np.array(feature_list), axis=1, norm='l2') 114 | #feature_list = K.l2_normalize(np.array(feature_list),axis=1).eval() 115 | print("adding to word array") 116 | words = np.array(words, dtype=object) 117 | del model 118 | 
print("writing to file") 119 | hf.create_dataset(vec, data=feature_list) 120 | string_dt = h5py.special_dtype(vlen=str) 121 | hf.create_dataset(vec + "_words", data=words, dtype=string_dt) 122 | 123 | #del model # Trigger garbage collector to free memory. 124 | #save_h5py(feature_list, words, vec, filename=root + glove_vec_fold + "glove") 125 | 126 | 127 | 128 | for vec in glove_vec: 129 | data, words = load_h5py(vec, filename=root + glove_vec_fold + "glove.h5") 130 | print(data.shape, words.shape) 131 | time.sleep(5) 132 | 133 | 134 | 135 | ''' 136 | 137 | def get_id_dict(filename=None): 138 | id_to_word = {} 139 | with open(filename, 'rb') as f: 140 | reader = csv.reader(f) 141 | 142 | for line in reader: 143 | id_to_word[line[0]] = line[1].lower() 144 | 145 | return id_to_word 146 | def get_oi_labels(id_to_word): 147 | unique = {} 148 | open_images = "/media/bjotta/13f2cffb-0a7d-41b9-946f-36d679d1e9f6/home/GloVe/data/machine_ann_2016_08/train/labels.csv" 149 | with open(open_images, 'rb') as file: 150 | reader = csv.reader(file) 151 | reader.next() 152 | for line in reader: 153 | label_ID = line[2].lower() 154 | if id_to_word[label_ID] not in unique: 155 | unique[id_to_word[label_ID]] = 1 156 | return unique.keys() 157 | 158 | id_to_word = get_id_dict(filename="/media/bjotta/13f2cffb-0a7d-41b9-946f-36d679d1e9f6/home/GloVe/data/dict.csv") 159 | unique_words = get_oi_labels(id_to_word) 160 | glove_vec = ['glove_oi_50'] 161 | counter = 0 162 | not_in ={} 163 | for vec in glove_vec: 164 | model = load_glove(vec+".txt") 165 | feature_list, glove_words = get_features(model) 166 | for word in unique_words: 167 | if word not in glove_words and word not in not_in: 168 | not_in[word] =1 169 | counter +=1 170 | print word, counter 171 | 172 | ''' 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /lib/bt_net/test_bts.py: -------------------------------------------------------------------------------- 1 | import 
os, cv2, operator 2 | import os.path as osp 3 | from keras.models import load_model 4 | from bt_net.generate_boxes import generate_n_boxes 5 | from bt_net.train_bts import euclidean_distance 6 | from bt_net.faster_rcnn_boxes import load_frcnn_boxes, vis_proposals, boxes_to_images, load_all_boxes_yolo, load_image_boxes_yolo 7 | from preprocessing.preprocess_images import read_img 8 | import numpy as np 9 | from multiprocessing import Pool 10 | 11 | 12 | def predict_image_singlelabel(img_path, model, lang_db, extra = None): 13 | img = read_img(img_path, target_size=(299, 299)) 14 | img = np.expand_dims(img, axis=0) 15 | pred_vector = model.predict(img) 16 | predictions = lang_db.closest_labels(pred_vector, k_closest = 20) 17 | return predictions 18 | 19 | def predict_ml_frcnn(img_path, model, lang_db, extra, k_closest = 1): 20 | img = read_img(img_path) 21 | boxes = boxes_to_images(img_path, img, extra) 22 | pred_vects = model.predict_on_batch(boxes) # Return predictions on the boxes, predictions 300d vector 23 | 24 | predictions = lang_db.closest_labels(pred_vects, k_closest=k_closest) # Return closest label to each vector box 25 | dict_counter = {} 26 | preds = [] 27 | for k in range(k_closest): 28 | for i in range(len(predictions)): 29 | preds.append(predictions[i][k]) 30 | 31 | for pred in preds: 32 | if pred in dict_counter: 33 | dict_counter[pred] += 1 34 | else: 35 | dict_counter[pred] = 1 36 | sorted_dict = sorted(dict_counter.items(), key=operator.itemgetter(1), reverse=True) 37 | return [x[0] for x in sorted_dict[:25]] 38 | 39 | 40 | def predict_ml_random(img_path, model, lang_db, extra): 41 | # Generate boxes for image 42 | img = read_img(img_path, target_size=(299, 299)) 43 | boxes = generate_n_boxes(img, pool = extra) 44 | pred_vects = model.predict_on_batch(boxes) # Return predictions on the boxes, predictions 300d vector 45 | predictions = lang_db.closest_labels(pred_vects, k_closest=1) # Return closest label to each vector box 46 | dict_counter = {} 
47 | for pred in predictions: 48 | pred = pred[0] 49 | # Retrieve closest word in semantic space 50 | if pred in dict_counter: 51 | dict_counter[pred] += 1 52 | else: 53 | dict_counter[pred] = 1 54 | 55 | sorted_dict = sorted(dict_counter.items(), key=operator.itemgetter(1), reverse=True) 56 | return [x[0] for x in sorted_dict[:20]] 57 | 58 | def predict_ml_yolo(img_path, model, lang_db, dictionary,k_closest = 1): 59 | img = read_img(img_path) 60 | boxes = load_image_boxes_yolo(img, dictionary, img_path) 61 | if len(boxes) == 0: return [] 62 | pred_vects = model.predict_on_batch(boxes) # Return predictions on the boxes, predictions 300d vector 63 | 64 | predictions = lang_db.closest_labels(pred_vects, k_closest=k_closest) # Return closest label to each vector box 65 | dict_counter = {} 66 | preds = [] 67 | 68 | for k in range(k_closest): 69 | for i in range(len(predictions)): 70 | preds.append(predictions[i][k]) 71 | 72 | for pred in preds: 73 | if pred in dict_counter: 74 | dict_counter[pred] += 1 75 | else: 76 | dict_counter[pred] = 1 77 | sorted_dict = sorted(dict_counter.items(), key=operator.itemgetter(1), reverse=True) 78 | return [x[0] for x in sorted_dict[:25]] 79 | 80 | def test_bts(lang_db, imdb, checkpoint, euc_loss, singlelabel_predict, space, box_method): 81 | # Load trained model 82 | if euc_loss: 83 | print "Detected euclidean loss" 84 | model = load_model(checkpoint, custom_objects={'euclidean_distance': euclidean_distance}) 85 | model_loss = "euclidean" 86 | else: 87 | model = load_model(checkpoint) 88 | model_loss = str(model.loss) 89 | 90 | index = [ix for ix in range(imdb.num_images)] 91 | output_path = osp.join('output', 'predict_results') 92 | if not osp.exists(output_path): 93 | os.makedirs(output_path) 94 | 95 | if singlelabel_predict: 96 | predict_image = predict_image_singlelabel 97 | pred = '_sl_' 98 | elif box_method == 'yolo': 99 | predict_image = predict_ml_yolo 100 | ''' 101 | dictionary 102 | ''' 103 | extra = 
load_all_boxes_yolo(imdb) 104 | pred = '_ml_yolo_' 105 | elif box_method == 'frcnn': 106 | predict_image = predict_ml_frcnn 107 | boxes = load_frcnn_boxes(imdb) 108 | pred = '_ml_frcnn_' 109 | else: 110 | predict_image = predict_ml_random 111 | pred = '_ml_random' 112 | extra = Pool(4) 113 | 114 | output_file = osp.join(output_path, 'results_' + imdb.name + '_' + str(lang_db.name) + pred + str(model_loss) + '_' + str(space) + '_l2_nw.txt') 115 | with open(output_file, 'w+') as wf: 116 | for i in index: 117 | if i%100==0: print "Generating proposals: {}/{}".format(i, imdb.num_images) 118 | if box_method == 'frcnn': extra = boxes[i] 119 | img_path = imdb.image_path_at(i) 120 | predictions = predict_image(img_path, model, lang_db, extra) 121 | 122 | if box_method == 'yolo': 123 | if len(predictions) == 0: continue 124 | # write predictions to file 125 | # id, num_actual, a1, a2, ..., an, num_predicted, p1, p2, ..., pn 126 | actual = imdb.gt_at(i) 127 | wf.write('%s' % imdb.image_index_path(i)) 128 | if isinstance(actual, basestring): # Check if gt is a simple string and not array 129 | wf.write(', %s' % 1) # Length of ground_truth is 1 130 | wf.write(', %s' % actual) 131 | else: 132 | wf.write(', %s' % len(actual)) 133 | for a in actual: 134 | wf.write(', %s' % a) 135 | 136 | wf.write(', %s' % len(predictions)) 137 | if len(predictions) > 1: 138 | for p in predictions: 139 | wf.write(', %s' % p) 140 | else: 141 | wf.write(', %s' % predictions) 142 | 143 | wf.write('\n') 144 | 145 | -------------------------------------------------------------------------------- /lib/bt_datasets/nus_wide.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | from bt_datasets.imdb import imdb 4 | from bt_net.config import cfg 5 | import xml.dom.minidom as minidom 6 | import numpy as np 7 | import scipy.sparse 8 | import scipy.io as sio 9 | import cPickle 10 | from random import shuffle 11 | import subprocess 
12 | import xml.etree.ElementTree as ET 13 | 14 | import linecache 15 | 16 | class nus_wide(imdb): 17 | def __init__(self, image_set): 18 | imdb.__init__(self, 'nus_wide_' + image_set) 19 | self._image_set = image_set 20 | 21 | self._devkit_path = osp.join(cfg.IMAGE_DATA_DIR, 'nus_wide') 22 | self.image_list = osp.join(self._devkit_path, 'ImageList') 23 | self._data_path = osp.join(self._devkit_path, 'Images') 24 | 25 | if 'Train' in self._image_set: 26 | label_file_1k = osp.join(self._devkit_path, 'Concepts', 'Train_Tags920_clean') 27 | self._classes = self._generate_class_list(label_file_1k) 28 | elif "_us_" in self._image_set: 29 | label_file_1k = osp.join(self._devkit_path, 'Concepts', 'Train_Tags920_clean') 30 | label_file_81 = osp.join(self._devkit_path, 'Concepts', 'Concepts81.txt') 31 | self._classes = self._generate_class_list(label_file_1k, label_file_81) 32 | elif "_u_" in self._image_set: 33 | label_file_81 = osp.join(self._devkit_path, 'Concepts', 'Concepts81.txt') 34 | print "Generating class list from label_file concepts, only unseen" 35 | self._classes = self._generate_class_list(label_file_81) 36 | elif '1k' in self._image_set: 37 | label_file_1k = osp.join(self._devkit_path, 'Concepts', 'TagList1k.txt') 38 | self._classes = self._generate_class_list(label_file_1k) 39 | print "classes", len(self._classes) 40 | 41 | if "train_clean" in self._image_set: 42 | self.label_file = osp.join(self._devkit_path,"Train_Tags81_clean") 43 | else: 44 | self.label_file = osp.join(self._devkit_path,"Test_Tags81_clean.txt") 45 | 46 | print("Num classes: {}").format(len(self._classes)) 47 | 48 | self._image_index = self._load_image_set_index() 49 | self._image_ext = ['.JPG'] 50 | 51 | 52 | 53 | # Specific config options 54 | self.config = {'cleanup' : True, 55 | 'use_salt' : True, 56 | 'top_k' : 2000} 57 | 58 | assert osp.exists(self._devkit_path), \ 59 | 'Devkit path does not exist: {}'.format(self._devkit_path) 60 | assert osp.exists(self._data_path), \ 61 | 'Path 
def _generate_class_list(self, filename, file_name=None):
    """
    Read class labels from one or two label files (one label per line).

    Every line of ``filename`` is kept, in file order and including
    duplicates. Labels from the optional ``file_name`` are appended only
    when they have not been seen yet.

    The list is built locally instead of in ``self._classes``: callers that
    only want a label list (``get_labels``) must not silently overwrite the
    dataset's own class tuple. ``__init__`` assigns the returned tuple itself.

    :param filename: path to the primary label file
    :param file_name: optional path to a second label file to merge in
    :return: tuple of label strings
    """
    print(filename)
    with open(filename) as primary:
        classes = [line.strip() for line in primary]

    if file_name is not None:
        # Set lookup keeps the merge linear instead of scanning the tuple
        # for every candidate label.
        seen = set(classes)
        with open(file_name) as extra:
            for line in extra:
                lbl = line.strip()
                if lbl not in seen:
                    seen.add(lbl)
                    classes.append(lbl)
    return tuple(classes)
def get_labels(self, space):
    """
    Return the label list that spans the requested prediction space.

    :param space: 1 for the labels in ``Concepts81.txt`` only,
                  2 for ``Train_Tags920_clean`` merged with ``Concepts81.txt``
    :return: sequence of label strings
    :raises ValueError: for any other ``space`` (this used to surface as an
                        UnboundLocalError on the ``return`` statement)
    """
    concepts_dir = osp.join(self._devkit_path, "Concepts")
    label_file_81 = osp.join(concepts_dir, "Concepts81.txt")

    if space == 2:  # If space is 2, include seen labels in space
        label_file_1k = osp.join(concepts_dir, "Train_Tags920_clean")
        labels = self._generate_class_list(label_file_1k, label_file_81)
        print("the labels {}".format(len(labels)))
    elif space == 1:
        # NOTE(review): the first entry is dropped here, while __init__ uses
        # the same file without dropping it -- confirm this is intended.
        labels = self._generate_class_list(label_file_81)[1:]
        print("the labels {}".format(len(labels)))
        print(labels)
    else:
        raise ValueError(
            "Unsupported label space {!r}; expected 1 or 2".format(space))
    return labels
def create_best(boxes, lbl_features, indices):
    """
    Pair the selected boxes with their label features.

    The n-th entry of ``indices`` selects a box from ``boxes``; it is paired
    with the n-th entry of ``lbl_features``.

    :param boxes: indexable collection of candidate boxes
    :param lbl_features: label features, aligned with ``indices``
    :param indices: positions into ``boxes`` of the boxes to keep
    :return: (selected boxes, matching label features), both as lists
    """
    picked = [(boxes[box_ix], lbl_features[pos])
              for pos, box_ix in enumerate(indices)]
    new_box = [box for box, _ in picked]
    new_lbls = [feature for _, feature in picked]
    return new_box, new_lbls
total_batch = int(imdb.num_images / batch_size) 50 | index = [ix for ix in range(imdb.num_images)] 51 | 52 | if box_method == 'random': 53 | pool = Pool(4) 54 | elif box_method == 'frcnn': 55 | print "Loading Faster R-CNN boxes..." 56 | all_boxes = load_frcnn_boxes(imdb) 57 | elif box_method == 'yolo': 58 | print "Loading YOLO boxes..." 59 | all_boxes = load_all_boxes_yolo(imdb) 60 | 61 | while 1: 62 | shuffle(index) 63 | i = 0 64 | for _ in range(total_batch): 65 | X_train = [] 66 | y_train = [] 67 | j = 0 68 | while j < batch_size: 69 | if i >= len(index): i = 0 70 | x_path, y_s, g_t = imdb.get_image_data_clean(index[i], lang_db) 71 | img = cv2.imread(x_path) 72 | if box_method == 'random': 73 | boxes = generate_n_boxes(img, pool=pool) 74 | k = 1 75 | elif box_method == 'frcnn': 76 | boxes = boxes_to_images(img, all_boxes[index[i]]) 77 | k = 1 78 | elif box_method == "yolo": 79 | boxes = load_image_boxes_yolo(img, all_boxes, x_path) 80 | k = 1 81 | if len(boxes) == 0: 82 | i = i + 1 83 | ''' 84 | If the path does not exist, len is 0. 85 | Jump over this picture. 86 | 87 | To jump over the picture next time, increment i. 
def euclidean_distance(y_true, y_pred):
    """
    Euclidean distance between target and prediction, used as a Keras loss.

    The squared differences are summed over axis 1 with ``keepdims=True``.

    :param y_true: target tensor
    :param y_pred: predicted tensor
    :return: tensor of distances
    """
    squared_sum = K.sum(K.square(y_true - y_pred), axis=1, keepdims=True)
    # The derivative of sqrt is unbounded at 0, so an exact match between
    # prediction and target would produce NaN gradients. Clamp the sum to the
    # backend epsilon before taking the root.
    return K.sqrt(K.maximum(squared_sum, K.epsilon()))
def train_bts(lang_db, imdb, max_iters = 1000, loss = 'squared_hinge'):
    """
    Build the network and train it on single-label image/word-vector pairs.

    Checkpoints and the final model are written below
    ``output/bts_ckpt/<imdb.name>``; TensorBoard logs go to
    ``output/bts_logs/<imdb.name>``.

    :param lang_db: language model; supplies ``vector_size`` and ``name``
    :param imdb: image database; supplies ``name`` and ``load_val_data``
    :param max_iters: maximum number of epochs
    :param loss: loss identifier passed on to ``define_network``
    """
    # Define network
    network = define_network(lang_db.vector_size, loss)

    # Make sure both output folders exist before any callback writes to them.
    ckpt_dir = osp.join('output', 'bts_ckpt', imdb.name)
    log_dir = osp.join('output', 'bts_logs', imdb.name)
    for folder in (ckpt_dir, log_dir):
        if not osp.exists(folder):
            os.makedirs(folder)

    # Create callback_list.
    weights_template = osp.join(
        ckpt_dir, lang_db.name + "_" + loss + "_weights-{epoch:02d}.hdf5")
    callback_list = [
        ModelCheckpoint(weights_template, monitor='val_loss', verbose=1,
                        save_best_only=True),
        EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=0,
                      mode='auto'),
        TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True,
                    write_images=False),
    ]

    train_batches = load_data(imdb, lang_db)
    val_batches = imdb.load_val_data(lang_db)
    network.fit_generator(train_batches,
                          steps_per_epoch=5000,
                          epochs=max_iters,
                          verbose=1,
                          validation_data=val_batches,
                          validation_steps=20000,  # number of images to validate on
                          callbacks=callback_list,
                          workers=1)

    final_name = 'model_' + imdb.name + '_' + lang_db.name + '_' + loss + '_l2.hdf5'
    network.save(osp.join(ckpt_dir, final_name))
def make_all_file(id_lbl_dict, file_dirs_file, newfile, labelsfile):
    '''
    Write one "path,label,label,..." line for every image whose labels are known.

    Images whose id is missing from ``id_lbl_dict``, or that reference a label
    id missing from ``labelsfile``, are skipped; the number of skipped images
    is printed at the end.

    :param id_lbl_dict: image_id to lbls (label ids)
    :param file_dirs_file: file with all images in, one relative path per line;
                           the image id is the third "/"-separated component
                           without its extension
    :param newfile: name of new file
    :param labelsfile: "label_id,label" file; surrounding double quotes are
                       stripped and spaces in labels become underscores
    :return: None
    '''
    lbl_id_to_lbl = {}
    with open(labelsfile) as reader:
        for line in reader:
            if not line.strip():
                continue  # tolerate blank lines instead of raising IndexError
            fields = line.split(",")
            lbl = fields[1].strip().replace(" ", "_")
            lbl_id = fields[0].strip()
            if lbl.startswith('"') and lbl.endswith('"'):
                lbl = lbl[1:-1]
            if lbl_id.startswith('"') and lbl_id.endswith('"'):
                lbl_id = lbl_id[1:-1]
            lbl_id_to_lbl[lbl_id] = lbl

    skipped = 0
    with open(file_dirs_file) as paths, open(newfile, "w") as new:
        for path in paths:
            path = path.rstrip()
            if not path:
                continue
            parts = path.split("/")
            if len(parts) < 3:
                # Not of the form <split>/<folder>/<image>: no id to look up.
                skipped += 1
                continue
            picture_id = parts[2].split(".")[0]
            print("picture_id {}".format(picture_id))
            try:
                # Resolve every label before writing, so a missing id never
                # leaves a half-written line behind.
                lbls = [lbl_id_to_lbl[x] for x in id_lbl_dict[picture_id]]
            except KeyError:
                skipped += 1
                continue
            new.write(path + "," + ",".join(lbls) + "\n")
    print(skipped)
def make_zs_split(all_file, train_classes, test_classes, test_file, train_file,
                  total="5280911"):
    '''
    Split the "path,label,..." file into a zero-shot train and test file.

    Labels that belong to neither class set are dropped. An image goes to the
    test file when its first remaining label is a test class (written with its
    test labels only); otherwise it goes to the train file (written with its
    train labels only). Images with no remaining label are left out.

    NOTE(review): only the first remaining label decides the side, so an
    image whose first label is a train class lands in the train file even if
    it also carries test labels. Confirm this is the intended split.

    :param all_file: file with all images as well as lbls for each image
    :param train_classes: train classes for zs split (any container supporting ``in``)
    :param test_classes: test classes for zs split (any container supporting ``in``)
    :param test_file: placement and name of test file
    :param train_file: placement and name of train file
    :param total: total shown in the progress output; defaults to the
                  previously hard-coded value
    :return: None. Just make test and train file
    '''
    with open(all_file) as source, \
            open(test_file, "w") as test, \
            open(train_file, "w") as train:
        counter = 0
        for line in source:
            fields = line.strip().split(",")
            path = fields[0]
            lbls = fields[1:]

            # Labels usable in this split (it can be made from e.g.
            # glove_wiki_300D), plus the per-side subsets.
            train_legal = [lbl for lbl in lbls if lbl in train_classes]
            test_legal = [lbl for lbl in lbls if lbl in test_classes]
            legal_lbls = [lbl for lbl in lbls
                          if lbl in train_classes or lbl in test_classes]
            if not legal_lbls:
                continue

            # Every legal label is in one of the two sets, so the first one
            # alone decides where the image goes.
            if legal_lbls[0] in test_classes:
                test.write(path + "," + ",".join(test_legal) + "\n")
            else:
                train.write(path + "," + ",".join(train_legal) + "\n")
            print("{} / {}".format(counter, total))
            counter += 1
def make_val_openimages(val_classes, labelsfile, validation_file, newfile, id_to_dir):
    '''
    Write the "path,label,..." validation file for OpenImages.

    Images are written in the order they first appear in ``validation_file``.
    Images without any label in ``val_classes`` are not written; images that
    are missing from ``id_to_dir`` are counted and the count is printed.

    :param val_classes: Same as test classes (any container supporting ``in``)
    :param labelsfile: id_to_label file ("label_id,label" per line)
    :param validation_file: file that says labels for each image_id; the first
                            line is a header, column 0 is the image id and
                            column 2 the label id
    :param newfile: name and placement of validation file
    :param id_to_dir: image_id to dir of files dictionary
    :return: Nothing. Just makes file
    '''
    lbl_id_to_lbl = {}
    with open(labelsfile) as reader:
        for line in reader:
            if not line.strip():
                continue  # tolerate blank lines instead of raising IndexError
            fields = line.split(",")
            lbl = fields[1].strip().replace(" ", "_")
            lbl_id = fields[0].strip()
            if lbl.startswith('"') and lbl.endswith('"'):
                lbl = lbl[1:-1]
            if lbl_id.startswith('"') and lbl_id.endswith('"'):
                lbl_id = lbl_id[1:-1]
            lbl_id_to_lbl[lbl_id] = lbl

    image_id_labels = {}
    image_order = []  # first-seen order keeps the output deterministic
    with open(validation_file) as rows:
        # Skip the header; the builtin next() works on Python 2 and 3 and the
        # default keeps an empty file from raising StopIteration.
        next(rows, None)
        for line in rows:
            if not line.strip():
                continue
            fields = line.strip().split(",")
            image_id = fields[0]
            label = lbl_id_to_lbl[fields[2]]
            if image_id not in image_id_labels:
                image_id_labels[image_id] = []
                image_order.append(image_id)
            if label in val_classes:
                image_id_labels[image_id].append(label)

    missing = 0
    with open(newfile, "w") as out:
        for image_id in image_order:
            try:
                image_dir = id_to_dir[image_id]
            except KeyError:
                missing += 1
                continue
            lbls = image_id_labels[image_id]
            if lbls:
                out.write(image_dir + "," + ",".join(lbls) + "\n")
    print(missing)
validation_file = "../../data/image_data/openimages/labels_val.csv" 238 | new_val_file = "../../data/image_data/openimages/validation.txt" 239 | #make_zs_split(newfile, train_classes, test_classes, test_file, train_file) 240 | #val_imgs_list(osp.join("..","..","data","image_data","openimages")) 241 | id_to_dir = val_id_to_dir(osp.join("..","..","data","image_data","openimages","validation_imgs.txt")) 242 | make_val_openimages(test_classes,labelsfile, validation_file, new_val_file, id_to_dir) 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | -------------------------------------------------------------------------------- /lib/bt_datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | from random import shuffle 4 | 5 | from bt_datasets.imdb import imdb 6 | from bt_net.config import cfg 7 | import numpy as np 8 | import scipy.io as sio 9 | import cPickle 10 | import xml.etree.ElementTree as ET 11 | import xml.dom.minidom as minidom 12 | from preprocessing.preprocess_images import read_img 13 | 14 | from random import shuffle 15 | 16 | 17 | class imagenet(imdb): 18 | def __init__(self, image_set): 19 | imdb.__init__(self, 'imagenet_' + image_set) 20 | self._image_set = image_set 21 | self._devkit_path = osp.join(cfg.IMAGE_DATA_DIR, 'imagenet') 22 | self._data_path = osp.join(self._devkit_path, 'Data/DET', self._image_folder()) 23 | synsets = sio.loadmat(osp.join(self._devkit_path, 'meta_det.mat')) 24 | self._classes = ('__background__',) 25 | self._wnid = (0,) 26 | for i in xrange(200): 27 | self._classes = self._classes + (synsets['synsets'][0][i][2][0],) 28 | self._wnid = self._wnid + (synsets['synsets'][0][i][1][0],) 29 | self._wnid_to_ind = dict(zip(self._wnid, xrange(self.num_classes))) 30 | self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) 31 | self._image_ext = ['.JPEG'] 32 | self._image_index = self._load_image_set_index() 33 | #print 
def _image_folder(self):
    """
    Return the name of the data sub-folder that belongs to this image set.

    The image-set name is matched by substring, in order: 'train', then
    'test_zs', then 'test'. Anything else maps to 'val'.
    """
    # Order matters: 'test_zs' must be tried before its prefix 'test'.
    for folder in ('train', 'test_zs', 'test'):
        if folder in self._image_set:
            return folder
    return 'val'
def load_image_lbl(self):
    """
    Endlessly yield (image_path, labels) pairs in shuffled order.

    For every image its XML annotation is read and the distinct object
    classes are collected in first-seen order. The image order is reshuffled
    after every full pass. With no images the generator ends immediately.

    NOTE(review): annotations are read from 'Annotations/<folder>/', while
    get_image_label reads 'Annotations/DET/<folder>/' -- confirm which layout
    is correct.
    """
    def get_data_from_tag(node, tag):
        return node.getElementsByTagName(tag)[0].childNodes[0].data

    order = list(range(self.num_images))
    while order:
        shuffle(order)
        for image_ix in order:
            image_dir = self._image_index[image_ix]
            image_path = self.image_path_at(image_ix)
            filename = os.path.join(self._devkit_path, 'Annotations',
                                    self._image_folder(), image_dir[0] + '.xml')
            # Parse once; the earlier ElementTree parse of the same file was
            # never used.
            data = minidom.parse(filename)

            unique = []
            for obj in data.getElementsByTagName('object'):
                wnid = str(get_data_from_tag(obj, "name")).lower().strip()
                cls = self._classes[self._wnid_to_ind[wnid]]
                if cls not in unique:
                    unique.append(cls)
            yield image_path, [str(x) for x in unique]
image_dir[0] + '.xml') 130 | tree = ET.parse(filename) 131 | def get_data_from_tag(node, tag): 132 | return node.getElementsByTagName(tag)[0].childNodes[0].data 133 | 134 | with open(filename) as f: 135 | data = minidom.parseString(f.read()) 136 | 137 | objs = data.getElementsByTagName('object') 138 | num_objs = len(objs) 139 | if num_objs != 1: return False 140 | 141 | cls = self._classes[self._wnid_to_ind[str(get_data_from_tag(objs[0], "name")).lower().strip()]] 142 | y = lang_db.word_vector(cls) 143 | if y is None: return False 144 | return [image_dir[0], cls] 145 | 146 | 147 | def load_val_data(self, lang_db): 148 | val_file_set = osp.join(self._devkit_path, 'ImageSets/DET/val_bts.txt') 149 | val_folder_path = osp.join(self._devkit_path, 'Data', 'DET', 'val') 150 | with open(val_file_set) as f: 151 | images_data = [line.split() for line in f] 152 | while 1: 153 | for img_data in images_data: 154 | img_path = osp.join(val_folder_path, img_data[0] + self._image_ext[0]) 155 | x = read_img(img_path) 156 | word_vec = lang_db.word_vector(img_data[1]) 157 | yield np.expand_dims(x, axis=0), np.expand_dims(word_vec, axis=0) 158 | 159 | def get_image_data_clean(self, index, lang_db): 160 | image_data = self._image_index[index] 161 | img_path = osp.join(self._data_path, image_data[0] + self._image_ext[0]) 162 | word_vec = lang_db.word_vector(image_data[1]) 163 | return img_path, word_vec 164 | 165 | def get_image_label(self, index, lang_db): 166 | image_dir = self._image_index[index] 167 | filename = os.path.join(self._devkit_path, 'Annotations/DET/', self._image_folder(), image_dir[0] + '.xml') 168 | tree = ET.parse(filename) 169 | def get_data_from_tag(node, tag): 170 | return node.getElementsByTagName(tag)[0].childNodes[0].data 171 | 172 | with open(filename) as f: 173 | data = minidom.parseString(f.read()) 174 | 175 | objs = data.getElementsByTagName('object') 176 | num_objs = len(objs) 177 | if num_objs != 1: return False 178 | 179 | cls = 
def load_generated_box_features(self, proposals_file):
    """
    Load the boxes generated by py-faster-rcnn.

    The pickle is read from '<devkit>/Generated_proposals/<proposals_file>'
    and must hold one entry per image in the image list. Progress (the first
    box of every 100th image) is printed while walking the proposals.

    :param proposals_file: file name of the pickled proposals
    :return: the unpickled proposals (previously nothing was returned)
    :raises AssertionError: when the number of entries differs from the
                            number of images
    """
    try:
        import cPickle as pickle_module   # Python 2
    except ImportError:
        import pickle as pickle_module    # Python 3

    num_images = self.num_images

    # Load pickle with proposals
    proposals_path = os.path.join(self._devkit_path, 'Generated_proposals', proposals_file)
    print(proposals_path)
    # SECURITY: unpickling runs arbitrary code; only load proposal files that
    # were produced locally.
    with open(proposals_path, 'rb') as input_file:
        proposals = pickle_module.load(input_file)

    # proposals has shape (num_images, 2000, 4) and a few images (600, 4)

    num_proposals = len(proposals)
    # The check used to be inverted (!=) and fired exactly when the counts
    # DID match.
    assert num_images == num_proposals, \
        'Mismatch between number of images in imagelist and the proposal list, {}, {}'.format(num_images, num_proposals)

    # Clean boxes smaller than 0.3 of image, have top 3 or top 5 of bounding boxes? Or base it on scores?
    for i, boxes in enumerate(proposals):
        if i % 100 == 0:
            boxes = np.array(boxes)
            if len(boxes):  # an image may have no proposals at all
                print(boxes[0])
            print("images {}/{}".format(i, num_images))
    return proposals
def vector_dict(self):
    """
    Return a dictionary with the vector values of the classes.

    The i-th word maps to the i-th feature. Keys are lower-cased, so when two
    words differ only by case the later one wins.
    """
    return {word.lower(): self.features[i]
            for i, word in enumerate(self.words)}
def get_lbl_features(self, pic_labels, k_closest=1):
    """
    Look up the tree vectors for the labels of one picture.

    Labels that are not part of the tree are dropped. As before, the dropped
    labels are also removed from ``pic_labels`` itself and the same list
    object is returned.

    :param pic_labels: labels of the picture
    :param k_closest: unused; kept for interface compatibility
    :return: (vectors of the known labels, the known labels)
    """
    kept = []
    vectors = []
    for label in pic_labels:
        position = self.is_existing(label)
        # is_existing returns the index or False. Compare by identity:
        # `position == False` is also true for the valid index 0.
        if position is False or position is None:
            continue
        kept.append(label)
        vectors.append(self._tree.get_item_vector(position))

    # Filter after the scan. Popping while indexing skipped the entry after
    # each removal and ran past the end of the shrunken list.
    if len(kept) != len(pic_labels):
        pic_labels[:] = kept
    return vectors, pic_labels
= [list(x) for x in zip(*sorted(zip(distance, indices)))] 135 | distances = x[0] 136 | indices = x[1] 137 | return indices[-k:] 138 | 139 | def get_closest(self,img_feature, lbl_features): 140 | smallest_distance = 999999999 141 | smallest_ind = None 142 | for i in range(len(lbl_features)): 143 | value = distance.euclidean(img_feature,lbl_features[i]) 144 | if value < smallest_distance: 145 | smallest_distance = value 146 | smallest_ind = i 147 | 148 | return smallest_distance, smallest_ind 149 | ''' 150 | def get_best_match4(self,pred_vects, lbl_features, labels=None, k=2, num_trees=100): 151 | 152 | #Goal: get k closest for every label 153 | 154 | 155 | t = AnnoyIndex(len(lbl_features[0])) 156 | for i in range(len(pred_vects)): 157 | t.add_item(i, pred_vects[i]) 158 | t.build(num_trees) 159 | feature_inds = [] 160 | label_inds = [] 161 | for x in range(len(lbl_features)): 162 | index = t.get_nns_by_vector(lbl_features[x], k, include_distances=False) 163 | for ind in index: 164 | label_inds.append(x) 165 | feature_inds.append(ind) 166 | return feature_inds, label_inds 167 | ''' 168 | 169 | 170 | def get_best_match(self, pred_vects, lbl_features, labels, k, expand=True, num_trees=100): 171 | '''' 172 | Goal: get 2 closest for every label 173 | if expand = True - Avoid using closest index in image vectors for another label 174 | else: Can get same index for another label 175 | 176 | Method to get match. 
A little faster than original code 177 | ''' 178 | t = AnnoyIndex(len(lbl_features[0])) 179 | for i in range(len(pred_vects)): 180 | t.add_item(i, pred_vects[i]) 181 | t.build(num_trees) 182 | image_feature_inds = [] 183 | final_lbl_featues = [] 184 | temp = 0 185 | used_inds = {} 186 | for x in range(len(lbl_features)): 187 | if expand == False: 188 | temp = 2 189 | else: 190 | temp += k 191 | indices = t.get_nns_by_vector(lbl_features[x], temp, include_distances=False) 192 | add = 0 193 | for ind in indices: 194 | if (add > k): break 195 | if temp == len(pred_vects): # Edge case. If every index is used. 196 | final_lbl_featues.append(lbl_features[x]) 197 | image_feature_inds.append(indices[0]) 198 | return image_feature_inds, final_lbl_featues 199 | if ind in used_inds and expand: 200 | continue 201 | else: 202 | if ind not in used_inds: used_inds[ind] = 1 203 | #label_inds.append(x) 204 | final_lbl_featues.append(lbl_features[x]) 205 | image_feature_inds.append(ind) 206 | add += 1 207 | return image_feature_inds, final_lbl_featues 208 | 209 | 210 | 211 | 212 | def get_best_match_OLD(self, pred_vects, lbl_features, labels, k=2, num_trees=1): 213 | #print("lblfeature",len(lbl_features),"lbls", len(labels)) 214 | #print("pred_vects", pred_vects) 215 | 216 | # Make one sublist for each label 217 | all_closest = [[] for x in range(len(lbl_features))] 218 | 219 | # Iterate through all lbl_features 220 | # For each lbl_feature find distance to every image_vector 221 | # Add with (distance, image index, lbl index) 222 | 223 | for i in range(len(lbl_features)): 224 | closest = [] 225 | for ind, element in enumerate(pred_vects): 226 | value = distance.euclidean(pred_vects[ind], lbl_features[i]) 227 | if len(closest) > 0: 228 | added = False 229 | for cls_ind in range(len(closest)): 230 | if closest[cls_ind][0] > value: 231 | # THe value for the distance. 
Smallest is best 232 | # Index for best vector 233 | # index for best lbl 234 | closest.insert(cls_ind,(value,ind,i)) 235 | added = True 236 | break 237 | if added == False: 238 | closest.append((value, ind, i)) 239 | else: 240 | closest.append((value, ind,i)) 241 | #pred_vects.pop(ind) 242 | all_closest[i] = closest 243 | 244 | # Find the lbl that is closest to some ind. 245 | # This is one of the first in every sub array in the all_closest. 246 | # This is because the first is always the smallest. 247 | closest_ind, closest_val = zip(*sorted(enumerate([x[0][0] for x in all_closest]), key=operator.itemgetter(1))) 248 | closest_ind, closest_val = list(closest_ind),list(closest_val) 249 | best_indices = [] 250 | temp_best_inds = [] 251 | used_boxes = [] 252 | for ind in closest_ind: 253 | for clos in all_closest[ind]: 254 | if len(temp_best_inds) < k and clos[1] not in used_boxes: 255 | used_boxes.append(clos[1]) 256 | temp_best_inds.append(clos) 257 | else: 258 | continue 259 | best_indices += temp_best_inds 260 | final_best_ind = [] 261 | final_lbl_features = [] 262 | for best in best_indices: 263 | final_best_ind.append(best[1]) 264 | final_lbl_features.append(lbl_features[best[2]]) 265 | return final_best_ind, final_lbl_features 266 | 267 | def get_word_features(self): 268 | return self.features, self.words 269 | 270 | def word_vector(self, label): 271 | """ Load in the word vector for the given label """ 272 | try: 273 | return self._vectors[label] 274 | except: 275 | print "Missing label: ", label 276 | return None 277 | 278 | -------------------------------------------------------------------------------- /lib/bt_datasets/imagenet1k.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | from copy import deepcopy 4 | from random import shuffle 5 | 6 | from bt_datasets.imdb import imdb 7 | from bt_net.config import cfg 8 | import numpy as np 9 | import scipy.io as sio 10 | import 
cPickle 11 | import xml.etree.ElementTree as ET 12 | import xml.dom.minidom as minidom 13 | from preprocessing.preprocess_images import read_img 14 | 15 | from random import shuffle 16 | 17 | 18 | class imagenet1k(imdb): 19 | def __init__(self, image_set): 20 | imdb.__init__(self, 'imagenet1k_' + image_set) 21 | self._image_set = image_set 22 | self._devkit_path = osp.join(cfg.IMAGE_DATA_DIR, 'imagenet1k') 23 | self._data_path = osp.join(self._devkit_path, 'Data/CLS-LOC', self._image_folder()) 24 | self._synsets = open((osp.join(self._devkit_path, 'map_clsloc.txt'))).readlines() 25 | 26 | self._classes = ('__background__',) 27 | self._wnid = (0,) 28 | for i in xrange(1000): 29 | line = self._synsets[i].rstrip().split(" ") 30 | self._classes = self._classes + (line[2],) 31 | self._wnid = self._wnid + (line[0],) 32 | self._wnid_to_ind = dict(zip(self._wnid, xrange(self.num_classes))) 33 | self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) 34 | self._image_ext = ['.JPEG'] 35 | self._image_index = self._load_image_set_index() 36 | # print "image_index", self._image_index 37 | 38 | 39 | # Specific config options 40 | self.config = {'cleanup': True, 41 | 'use_salt': True, 42 | 'top_k': 2000} 43 | 44 | assert os.path.exists(self._devkit_path), \ 45 | 'Devkit path does not exist: {}'.format(self._devkit_path) 46 | assert os.path.exists(self._data_path), \ 47 | 'Path does not exist: {}'.format(self._data_path) 48 | 49 | def image_index_path(self, i): 50 | return self._image_index[i][0] 51 | 52 | def _image_folder(self): 53 | image_folder = "" 54 | if 'val' in self._image_set: 55 | return 'val' 56 | #elif 'test_zs' in self._image_set: 57 | # return 'test_zs' 58 | #elif 'test' in self._image_set: 59 | # return 'test' 60 | return 'train' 61 | 62 | def image_path_at(self, i): 63 | """ 64 | Return the absolute path to image i in the image sequence. 
65 | """ 66 | return self.image_path_from_index(self._image_index[i]) 67 | 68 | def image_path_from_index(self, index): 69 | """ 70 | Construct an image path from the image's "index" identifier. 71 | """ 72 | image_path = os.path.join(self._data_path, 73 | index[0] + self._image_ext[0]) 74 | assert os.path.exists(image_path), \ 75 | 'Path does not exist: {}'.format(image_path) 76 | return image_path 77 | 78 | def _load_image_set_index(self): 79 | """ 80 | Load the indexes listed in this dataset's image set file. 81 | """ 82 | # Example path to image set file: 83 | # self._data_path + /ImageSets/val.txt or self._data_path + /ImageSets/zero_shot.txt 84 | image_set_file = osp.join(self._devkit_path, 'ImageSets','CLS-LOC', self._image_set + '.txt') 85 | assert os.path.exists(image_set_file), \ 86 | 'Path does not exist: {}'.format(image_set_file) 87 | 88 | with open(image_set_file) as f: 89 | image_index = [line.split() for line in f] 90 | return image_index 91 | 92 | def get_labels(self, space): 93 | labels = [] 94 | imageset_path = osp.join(self._devkit_path,'ImageSets', 'CLS-LOC',) 95 | unseen_f = open(osp.join(imageset_path, 'test_200_classes.txt')) 96 | labels = [x.strip() for x in unseen_f] 97 | 98 | if space == 2: # If space is 2, include seen labels in space 99 | seen_f = open(osp.join(imageset_path, 'train_800_classes.txt')) 100 | for line in seen_f: 101 | labels.append(line.strip()) 102 | 103 | return labels 104 | 105 | def generate_zs_split(self): 106 | new_list = deepcopy(self._synsets) 107 | shuffle(new_list) 108 | train_split = os.path.join(self._devkit_path, "train_800_classes.txt") 109 | test_split = os.path.join(self._devkit_path, "test_200_classes.txt") 110 | 111 | with open(train_split,"w") as train, open(test_split,"w") as test: 112 | for i in range(800): 113 | train.write(new_list[i]) 114 | for x in range(800,1000): 115 | test.write(new_list[x]) 116 | print "Wrote train and test split to:" 117 | print os.path.join(self._devkit_path, 
'ImageSets/CLS-LOC/') 118 | 119 | def clean_zs_split(self, lang_db): 120 | train_800 = os.path.join(self._devkit_path, "ImageSets", "CLS-LOC", "train_800.txt") 121 | test_200 = os.path.join(self._devkit_path, "ImageSets", "CLS-LOC", "test_200.txt") 122 | dict_200 = {} 123 | dict_800 = {} 124 | with open(train_800) as train, open(osp.join(self._devkit_path, "train_800_fixed_" + lang_db.name + ".txt"),"w") as train_fix, open("train_800_classes_"+ lang_db.name + ".txt", "w") as train_lbl: 125 | for line in train: 126 | wnid = line.rstrip().split("/")[0] 127 | cls = self._classes[self._wnid_to_ind[wnid]].lower() 128 | if lang_db.word_vector(cls.lower()) is not None: 129 | train_fix.write(line) 130 | if cls not in dict_800: 131 | dict_800[cls] = 1 132 | train_lbl.write(cls + "\n") 133 | with open(test_200) as train, open(osp.join(self._devkit_path, "test_200_fixed_" + lang_db.name + ".txt"), "w") as test_fix, open("test_200_classes_" + lang_db.name + ".txt", "w") as test_lbl: 134 | for line in train: 135 | wnid = line.rstrip().split("/")[0] 136 | cls = self._classes[self._wnid_to_ind[wnid]].lower() 137 | if lang_db.word_vector(cls.lower()) is not None: 138 | test_fix.write(line) 139 | if cls not in dict_200: 140 | dict_200[cls] = 1 141 | test_lbl.write(cls + "\n") 142 | 143 | print "dict_200", len(dict_200.keys()) 144 | print "dict_800", len(dict_800.keys()) 145 | 146 | 147 | 148 | 149 | def zs_text_files(self): 150 | train_zs = open(os.path.join(self._devkit_path, "train_800_classes.txt")).readlines() 151 | train_zs_wnid = [x.split(" ")[0] for x in train_zs] 152 | test_zs = open(os.path.join(self._devkit_path, "test_200_classes.txt")).readlines() 153 | test_zs_wnid = [x.split(" ")[0] for x in test_zs] 154 | 155 | #train_split = os.path.join(self._devkit_path, 'ImageSets/CLS-LOC/','train_800.txt') 156 | #test_split = os.path.join(self._devkit_path, 'ImageSets/CLS-LOC/','test_200.txt') 157 | train_split = os.path.join(self._devkit_path, 'train_800.txt') 158 | test_split 
= os.path.join(self._devkit_path, 'test_200.txt') 159 | 160 | with open(train_split,"w") as train, open(test_split, "w") as test: 161 | index = [ix for ix in range(self.num_images)] 162 | for i in range(self.num_images): 163 | image_dir = self._image_index[index[i]] 164 | image_dir0 = image_dir[0] 165 | print image_dir 166 | wn_id = image_dir0.split("/")[0] 167 | if wn_id in train_zs_wnid: 168 | train.write(" ".join(image_dir) +"\n") 169 | elif wn_id in test_zs_wnid: 170 | test.write(" ".join(image_dir) + "\n") 171 | 172 | def load_classes(self): 173 | return self._classes[1:] # Return classes except background 174 | 175 | def load_image_lbl(self): 176 | i = 0 177 | index = [ix for ix in range(self.num_images)] 178 | shuffle(index) 179 | while i < self.num_images: 180 | 181 | image_dir = self._image_index[index[i]][0] 182 | wn_id = image_dir.split("/")[0] 183 | image_path = self.image_path_at(index[i]) 184 | i += 1 185 | if i >= self.num_images: 186 | i = 0 187 | shuffle(index) 188 | yield image_path, [self._classes[self._wnid_to_ind[wn_id]].lower()] 189 | 190 | def load_image_lbl_imdir(self): 191 | i = 0 192 | index = [ix for ix in range(self.num_images)] 193 | shuffle(index) 194 | while i < self.num_images: 195 | 196 | image_dir = self._image_index[index[i]][0] 197 | full_image_dir = self._image_index[index[i]] 198 | wn_id = image_dir.split("/")[0] 199 | image_path = self.image_path_at(index[i]) 200 | i += 1 201 | yield image_path, [self._classes[self._wnid_to_ind[wn_id]].lower()], full_image_dir 202 | 203 | def get_image_data_clean(self, index, lang_db): 204 | image_data = self._image_index[index] 205 | img_path = osp.join(self._data_path, image_data[0] + self._image_ext[0]) 206 | word_vec = lang_db.word_vector(image_data[1]) 207 | return img_path, word_vec 208 | 209 | def gt_at(self, index): 210 | image_dir = self._image_index[index][0] 211 | wn_id = image_dir.split("/")[0] 212 | return self._classes[self._wnid_to_ind[wn_id]].lower() 213 | 214 | def 
get_image_label(self, index, lang_db): 215 | image_dir = self._image_index[index][0] 216 | wn_id = image_dir.split("/")[0] 217 | image_path = self.image_path_at(index) 218 | 219 | cls = self._classes[self._wnid_to_ind[wn_id]].lower() 220 | y = lang_db.word_vector(cls) 221 | if y is None: return False 222 | return [image_dir, cls] 223 | 224 | def get_val_label(self, index, lang_db): 225 | image_dir = self._image_index[index] 226 | filename = os.path.join(self._devkit_path, 'Annotations/CLS-LOC/', self._image_folder(), image_dir[0] + '.xml') 227 | tree = ET.parse(filename) 228 | 229 | def get_data_from_tag(node, tag): 230 | return node.getElementsByTagName(tag)[0].childNodes[0].data 231 | 232 | with open(filename) as f: 233 | data = minidom.parseString(f.read()) 234 | 235 | objs = data.getElementsByTagName('object') 236 | num_objs = len(objs) 237 | if num_objs != 1: return False 238 | 239 | cls = self._classes[self._wnid_to_ind[str(get_data_from_tag(objs[0], "name")).lower().strip()]] 240 | y = lang_db.word_vector(cls) 241 | if y is None: return False 242 | return [image_dir[0], cls] 243 | 244 | 245 | def load_val_data(self, lang_db): 246 | val_file_set = osp.join(self._devkit_path, 'ImageSets/CLS-LOC/val_bts.txt') 247 | val_folder_path = osp.join(self._devkit_path, 'Data', 'CLS-LOC', 'val') 248 | with open(val_file_set) as f: 249 | images_data = [line.split() for line in f] 250 | shuffle(images_data) 251 | while 1: 252 | for img_data in images_data: 253 | img_path = osp.join(val_folder_path, img_data[0] + self._image_ext[0]) 254 | x = read_img(img_path) 255 | word_vec = lang_db.word_vector(img_data[1]) 256 | yield np.expand_dims(x, axis=0), np.expand_dims(word_vec, axis=0) 257 | 258 | def load_generated_box_features(self, proposals_file): 259 | """ Load the boxes generated by py-faster-rcnn """ 260 | num_images = self.num_images 261 | 262 | # Load pickle with proposals 263 | proposals_path = os.path.join(self._devkit_path, 'Generated_proposals', proposals_file) 264 
| print(proposals_path) 265 | with open(proposals_path, 'rb') as input_file: 266 | proposals = cPickle.load(input_file) 267 | 268 | # proposals has shape (num_images, 2000, 4) and a few images (600, 4) 269 | 270 | num_proposals = len(proposals) 271 | assert num_images != num_proposals, \ 272 | 'Mismatch between number of images in imagelist and the proposal list, {}, {}'.format(num_images, 273 | num_proposals) 274 | i = 0 275 | 276 | # Clean boxes smaller than 0.3 of image, have top 3 or top 5 of bounding boxes? Or base it on scores? 277 | for ix in proposals: 278 | if i % 100 == 0: 279 | ix = np.array(ix) 280 | print ix[0] 281 | print("images {}/{}".format(i, num_images)) 282 | i = i + 1 283 | 284 | 285 | -------------------------------------------------------------------------------- /lib/preprocessing/clean_nus_wide.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os, sys 4 | import os.path as osp 5 | from itertools import izip 6 | 7 | 8 | def clean_backslashes(imagelist, clean_list): 9 | # Change backslashes to forwardslashes 10 | with open(imagelist) as f, open(clean_list, 'w') as wf: 11 | for line in f: 12 | words = line.split('\\') 13 | wf.write('%s' % words[0]) 14 | wf.write('/%s\n' % words[1].strip()) 15 | 16 | 17 | def remove_no_tags(imagelist, clean_list, tag_list, clean_tag_list): 18 | # Read line for line in taglist, if vector doesnt contain a 1 (sum of vector is 0), do not add to list, if sum > 0, add image to imagelist, and vector to clean taglist 19 | 20 | with open(imagelist) as ilf, \ 21 | open(clean_list, 'w') as clf, \ 22 | open(tag_list) as tlf, \ 23 | open(clean_tag_list, 'w') as ctlf: 24 | 25 | tot_img = 0 26 | clean_img = 0 27 | for tag_line, image_line in izip(tlf, ilf): 28 | tot_img = tot_img + 1 29 | numbers = tag_line.strip().split() 30 | 31 | numbers = [int(x) for x in numbers] 32 | if sum(numbers) > 0: 33 | ctlf.write(tag_line) 34 | clf.write(image_line) 35 | 
clean_img = clean_img + 1 36 | 37 | print "Finished with total images: {} of total {}".format(clean_img, tot_img) 38 | 39 | 40 | def make_925_train(nwh): 41 | labels_81 = open(osp.join(nwh, "Concepts", "Concepts81.txt")).readlines() 42 | labels_1k = open(osp.join(nwh, "Concepts", "TagList1k.txt")).readlines() 43 | 44 | label_index = [] 45 | for i in range(len(labels_81)): 46 | for j in range(len(labels_1k)): 47 | if labels_81[i].rstrip().lower() == labels_1k[j].rstrip().lower(): 48 | label_index.append(j) 49 | 50 | return label_index 51 | 52 | 53 | def check_overlap(indices, numbers): 54 | for ind in indices: 55 | if numbers[ind] == 1: 56 | return True 57 | 58 | 59 | def remove_overlap(nwh, imagelist, clean_list, tag_list, clean_tag_list, test_zs_u_im, test_us_im): 60 | # ImageList_1k 61 | # ImageList 81k 62 | # Write labels from original im_list at that same place. Labels 81 and 1k 63 | 64 | overlapping = make_925_train(nwh) 65 | with open(imagelist) as im_list, \ 66 | open(clean_list, 'w') as cl_im_list, \ 67 | open(tag_list) as tag_list, \ 68 | open(clean_tag_list, 'w') as cl_tg_list: 69 | tot_img = 0 70 | clean_img = 0 71 | for tag_line, image_line in izip(tag_list, im_list): 72 | tot_img += 1 73 | numbers = tag_line.strip().split() 74 | numbers = [int(x) for x in numbers] 75 | if check_overlap(overlapping, numbers): 76 | continue 77 | else: 78 | cl_im_list.write(image_line) 79 | cl_tg_list.write(tag_line) 80 | clean_img += 1 81 | print "Finished with total images: {} of total {}".format(clean_img, tot_img) 82 | orginal_1k_tag = open(osp.join(nwh, "Train_Tags1k.txt")).readlines() 83 | original_81_tag = open(osp.join(nwh, "Train_Tags81.txt")).readlines() 84 | labels_81 = open(osp.join(nwh, "Concepts", "Concepts81.txt")).readlines() 85 | labels_1k = open(osp.join(nwh, "Concepts", "TagList1k.txt")).readlines() 86 | original_imagelist = open(osp.join(nwh, "ImageList", "TrainImagelist.txt")).readlines() 87 | seen_unseen_counter = 0 88 | unseen_counter = 0 89 | with 
open(test_zs_u_im, "w") as test_zs_u_img, \ 90 | open(test_us_im, "w") as test_zs_us_img: 91 | imagelist1k_zs = open(clean_list).readlines() 92 | for i in range(len(original_imagelist)): 93 | img = original_imagelist[i] 94 | if img not in imagelist1k_zs: 95 | 96 | line_1k = orginal_1k_tag[i].rstrip().split() 97 | line_81 = original_81_tag[i].rstrip().split() 98 | vector_1k_ind = [i for i, z in enumerate([int(x) for x in line_1k]) if z != 0] 99 | vector_81 = [vec_81_ind for vec_81_ind, val in enumerate([int(y) for y in line_81]) if val != 0] 100 | lbl81 = [] 101 | lbl1k = [] 102 | for n in vector_1k_ind: 103 | lbl1k.append(labels_1k[n].strip()) 104 | for m in vector_81: 105 | lbl81.append(labels_81[m].strip()) 106 | seen_unseen_list = list(set(lbl81 + lbl1k)) 107 | unseen_list = list(set(lbl81)) 108 | if len(seen_unseen_list) != 0: 109 | test_zs_us_img.write(img.strip() + "," + ",".join(seen_unseen_list) + "\n") 110 | seen_unseen_counter += 1 111 | if len(unseen_list) != 0: 112 | test_zs_u_img.write(img.strip() + "," + ",".join(unseen_list) + "\n") 113 | unseen_counter += 1 114 | 115 | print "Finished with total u/s: {} , u: {}".format(seen_unseen_counter, unseen_counter) 116 | 117 | 118 | def non_overlap_file(classes1, classes2, nwh, filename): 119 | non_overlap = {} 120 | 121 | for cls in classes2: 122 | if cls.rstrip().lower() not in non_overlap and cls not in classes1: 123 | non_overlap[cls.rstrip().lower()] = 1 124 | 125 | with open(osp.join(nwh, filename), "w") as file_wr: 126 | for the_class in non_overlap.keys(): 127 | file_wr.write(the_class.rstrip().lower() + "\n") 128 | 129 | 130 | def resulting_classes(nwh): 131 | class_ind = {} 132 | tag_file = osp.join(nwh, 'Train_Tags925.txt') 133 | with open(tag_file) as tags: 134 | for line in tags: 135 | numbers = line.strip().split() 136 | numbers = [int(x) for x in numbers] 137 | indices = [i for i, x in enumerate(numbers) if x != 0] 138 | for ind in indices: 139 | if ind not in class_ind: 140 | class_ind[ind] = 
ind 141 | 142 | labels_1k = open(osp.join(nwh, "Concepts", "TagList1k.txt")).readlines() 143 | with open(osp.join(nwh, "Train_Tags" + str(len(class_ind.keys())) + "_clean"), "w") as file: 144 | for index in class_ind.keys(): 145 | file.write(labels_1k[index].rstrip() + "\n") 146 | 147 | 148 | def train_img_lbl_clean(imagelist, image_tags, labels_file, newfile): 149 | imagelist = open(imagelist).readlines() 150 | image_tags = open(image_tags).readlines() 151 | labels = open(labels_file).readlines() 152 | with open(newfile, "w") as file: 153 | for img, tag in izip(imagelist, image_tags): 154 | tag = tag.strip().split() 155 | indices = [i for i, x in enumerate([int(x) for x in tag]) if x != 0] 156 | lbls = [labels[ind].strip() for ind in indices] 157 | file.write(img.rstrip() + "," + ",".join(lbls) + "\n") 158 | 159 | 160 | def clean_train_test(lang_name, train_u_s, train_u, test): 161 | from language_models.language_factory import get_language_model 162 | lang_db = get_language_model(lang_name) 163 | print lang_db 164 | train_u_s_read = open(train_u_s).readlines() 165 | print "len(train_us)", len(train_u_s_read) 166 | unique_lbl_t = {} 167 | unique_lbl_u = {} 168 | unique_lbl_us = {} 169 | 170 | train_u_read = open(train_u).readlines() 171 | print "len(train_u)", len(train_u_s_read) 172 | 173 | test_read = open(test).readlines() 174 | print "len(train)", len(train_u_s_read) 175 | 176 | 177 | with open(train_u_s, "w") as train_u_s_file, \ 178 | open(train_u, "w") as train_u_file, \ 179 | open(test, "w") as test_file: 180 | for img_lbls in train_u_s_read: 181 | line = img_lbls.split(",") 182 | 183 | img = line[0].strip() 184 | lbls = [x.strip() for x in line[1:]] 185 | new_lbls = [] 186 | for lbl in lbls: 187 | lbl = lbl.strip() 188 | vec = lang_db.word_vector(lbl) 189 | if vec is not None: 190 | new_lbls.append(lbl) 191 | if lbl is not None: 192 | unique_lbl_us[lbl] = 1 193 | img_path = img.split("\\") 194 | path = img_path[0] + "/" + img_path[1] 195 | 
train_u_s_file.write(path + "," + ",".join(set(new_lbls)) + "\n") 196 | print "s/u",path, new_lbls 197 | 198 | 199 | for img_lbls in train_u_read: 200 | line = img_lbls.split(",") 201 | img = line[0].strip() 202 | lbls = [x.strip() for x in line[1:]] 203 | new_lbls = [] 204 | for lbl in lbls: 205 | lbl = lbl.strip() 206 | vec = lang_db.word_vector(lbl.strip()) 207 | if vec is not None: 208 | new_lbls.append(lbl.strip()) 209 | if lbl is not unique_lbl_us: 210 | unique_lbl_u[lbl] = 1 211 | img_path = img.split("\\") 212 | path = img_path[0] + "/" + img_path[1] 213 | print "unseen",path, new_lbls 214 | train_u_file.write(path + "," + ",".join(set(new_lbls)) + "\n") 215 | 216 | for img_lbls in test_read: 217 | line = img_lbls.split(",") 218 | img = line[0].strip() 219 | lbls = [x.strip() for x in line[1:]] 220 | new_lbls = [] 221 | 222 | for lbl in lbls: 223 | lbl = lbl.strip() 224 | vec = lang_db.word_vector(lbl) 225 | if vec is not None: 226 | new_lbls.append(lbl) 227 | if lbl is not unique_lbl_us: 228 | unique_lbl_t[lbl] = 1 229 | 230 | img_path = img.split("\\") 231 | path = img_path[0] + "/" + img_path[1] 232 | print "test",path, new_lbls 233 | 234 | test_file.write(path + "," + ",".join(set(new_lbls)) + "\n") 235 | 236 | def train_1k(tags, imagelist, newfilename): 237 | orginal_1k_labels = open(osp.join(nwh, "Concepts", "TagList1k.txt")).readlines() 238 | with open(tags) as tag, open(imagelist) as img_l, open(newfilename, "w") as newfile: 239 | for line in tag: 240 | line = line.rstrip().split() 241 | numbers = [int(x) for x in line] 242 | indices = [i for i,x in enumerate(numbers) if x != 0] 243 | print indices 244 | labels = [] 245 | for ind in indices: 246 | labels.append(orginal_1k_labels[ind].rstrip()) 247 | img_dir = img_l.next() 248 | newfile.write(img_dir.rstrip() + "," + ",".join(labels)+ "\n") 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | if __name__ == '__main__': 257 | nwh = '../../data/image_data/nus_wide' 258 | ''' 259 | # 
clean_backslashes('train.txt', 'train_clean.txt') 260 | remove_no_tags(osp.join(nwh, 'ImageList', 'TrainImagelist.txt'), 261 | osp.join(nwh, 'ImageList', 'TrainImagelist_clean_81.txt'), 262 | osp.join(nwh, 'Train_Tags81.txt'), 263 | osp.join(nwh, 'Train_Tags81_clean.txt') 264 | ) 265 | 266 | remove_no_tags(osp.join(nwh, 'ImageList', 'TestImagelist.txt'), 267 | osp.join(nwh, 'ImageList', 'TestImagelist_clean.txt'), 268 | osp.join(nwh, 'Test_Tags81.txt'), 269 | osp.join(nwh, 'Test_Tags81_clean.txt') 270 | ) 271 | 272 | remove_no_tags(osp.join(nwh, 'ImageList', 'TestImagelist.txt'), 273 | osp.join(nwh, 'ImageList', 'TestImagelist_clean.txt'), 274 | osp.join(nwh, 'Test_Tags1k.txt'), 275 | osp.join(nwh, 'Test_Tags1k_clean.txt') 276 | ) 277 | remove_no_tags(osp.join(nwh, 'ImageList', 'TestImagelist.txt'), 278 | osp.join(nwh, 'ImageList', 'TestImagelist_clean.txt'), 279 | osp.join(nwh, 'Train_Tags1k.txt'), 280 | osp.join(nwh, 'Train_Tags1k_clean.txt') 281 | ) 282 | 283 | clean_backslashes(osp.join(nwh, 'ImageList', 'TrainImagelist_clean.txt'), 284 | osp.join(nwh, 'train_clean.txt') 285 | ) 286 | clean_backslashes(osp.join(nwh, 'ImageList', 'TestImagelist_clean.txt'), 287 | osp.join(nwh, 'test_clean.txt') 288 | ) 289 | 290 | ' ''' 291 | ''' 292 | remove_no_tags(osp.join(nwh, 'ImageList', 'TrainImagelist.txt'), 293 | osp.join(nwh, 'ImageList', 'TrainImagelist_clean_81.txt'), 294 | osp.join(nwh, 'Train_Tags81.txt'), 295 | osp.join(nwh, 'Train_Tags81_clean.txt') 296 | ) 297 | remove_no_tags(osp.join(nwh, 'ImageList', 'TrainImagelist.txt'), 298 | osp.join(nwh, 'ImageList', 'TrainImagelist_clean_1k.txt'), 299 | osp.join(nwh, 'Train_Tags1k.txt'), 300 | osp.join(nwh, 'Train_Tags1k_clean.txt') 301 | ) 302 | 303 | 304 | remove_overlap(nwh, \ 305 | osp.join(nwh, 'ImageList', 'TrainImagelist_clean_1k.txt') ,\ 306 | osp.join(nwh, 'ImageList', 'TrainImagelist_925.txt'), \ 307 | osp.join(nwh, 'Train_Tags1k_clean.txt'), \ 308 | osp.join(nwh, 'Train_Tags925.txt'),\ 309 | osp.join(nwh, 
'ImageList', 'Test_zs_u_img_lbl.txt'), \ 310 | osp.join(nwh, "ImageList", "Test_zs_us_img_lbl.txt")) 311 | 312 | train_img_lbl_clean(osp.join(nwh, 'ImageList', 'TrainImagelist_clean_1k.txt'), \ 313 | osp.join(nwh, 'Train_Tags1k_clean.txt'),\ 314 | osp.join(nwh, "Concepts", "TagList1k.txt"), \ 315 | osp.join(nwh,"ImageList", "Train_zs_920_img_lbl.txt")) 316 | ''' 317 | ''' 318 | this_dir = osp.dirname(__file__) 319 | lib_path = osp.join(this_dir, '..', '..', 'lib') 320 | sys.path.insert(0, lib_path) 321 | nwh = osp.join(this_dir, "..", "..", "data", "image_data", "nus_wide") 322 | import os 323 | 324 | print("abspath", os.path.abspath(nwh)) 325 | print os.path.exists(osp.join(nwh, 'ImageList', 'Test_zs_u_img_lbl.txt')) 326 | print os.path.exists(osp.join(nwh, "ImageList", "Test_zs_us_img_lbl.txt")) 327 | print os.path.exists(osp.join(nwh, 'ImageList', 'Test_zs_u_img_lbl.txt')) 328 | 329 | clean_train_test('glove_wiki_300', osp.join(nwh, 'ImageList', 'Test_zs_us_img_lbl.txt'), 330 | osp.join(nwh, "ImageList", "Test_zs_u_img_lbl.txt"), \ 331 | osp.join(nwh, "ImageList", "Train_zs_920_img_lbl.txt")) 332 | ''' 333 | # labels_81 = open(osp.join(nwh, "Concepts", "Concepts81.txt")).readlines() 334 | # labels_1k= open(osp.join(nwh, "Concepts", "TagList1k.txt")).readlines() 335 | # non_overlap_file(labels_81, labels_1k, nwh, "Train_925_classes.txt") 336 | # resulting_classes(nwh) 337 | # create_test(nwh, osp.join(nwh, 'ImageList', 'TrainImagelist_clean.txt'),\ 338 | # osp.join(nwh, 'ImageList', 'TrainImagelist_925.txt'), \ 339 | # osp.join(nwh, 'Train_Tags81_clean.txt'), \ 340 | # osp.join(nwh, 'ImageList', 'Test_imagelist_zs.txt'), osp.join(nwh, 'Test_tags_all_zs.txt'), 341 | # osp.join(nwh, 'Test_tags_unseen.txt')) 342 | 343 | clean_backslashes(osp.join(nwh,"ImageList", "TrainImagelist_clean_1k.txt"), osp.join(nwh,"TrainImagelist_bs_clean_1k.txt")) 344 | train_1k(osp.join(nwh,"Train_Tags1k_clean.txt"), osp.join(nwh, "TrainImagelist_bs_clean_1k.txt"), osp.join(nwh, 
#
# File to evaluate results
#
# Per-Class recall, per-class precision, overall-recall,
# overall-precision, percentage of recalled labels in all labels (N+)
# mean-average-precision (map)
#
# Prediction files are plain text with one image per line:
#   id, num_actual, a1, a2, ..., an, num_predicted, p1, p2, ..., pn
#

import operator
import os
import os.path as osp
from collections import defaultdict

import numpy as np
from scipy.spatial import distance

# scikit-learn and annoy are only needed by total_cosine(); guarding the
# imports keeps the pure-Python metrics usable when they are not installed.
try:
    from sklearn.metrics.pairwise import cosine_similarity
except ImportError:
    cosine_similarity = None
try:
    from annoy import AnnoyIndex
except ImportError:
    AnnoyIndex = None


def _ratio(numerator, denominator, ndigits=5):
    """Return numerator / denominator rounded to ndigits, or 0.0 if denominator is 0."""
    if not denominator:
        return 0.0
    return round(float(numerator) / denominator, ndigits)


def overall_precision(correctly_annotated, n_preds):
    """Overall precision: total correct annotations / total predictions.

    Both arguments are dicts keyed by label; only labels present in
    ``correctly_annotated`` are summed. Returns 0.0 when nothing was predicted.
    """
    n_correct = sum(correctly_annotated.values())
    n_predicted = sum(n_preds[key] for key in correctly_annotated)
    return _ratio(n_correct, n_predicted)


def per_class_precision(correctly_annotated, n_preds):
    """Per-class precision: mean over labels of correct / predicted.

    Labels with no correct annotation or no prediction contribute 0.
    Returns 0.0 for an empty ``correctly_annotated``.
    """
    precision = 0.0
    for key, value in correctly_annotated.items():
        if value != 0 and n_preds[key] != 0:
            precision += float(value) / n_preds[key]
    return _ratio(precision, len(correctly_annotated))


def overall_recall(correctly_annotated, gt):
    """Overall recall: total correct annotations / total ground-truth labels.

    Returns 0.0 when there are no ground-truth labels.
    """
    n_correct = sum(correctly_annotated.values())
    n_ground_truth = sum(gt[key] for key in correctly_annotated)
    return _ratio(n_correct, n_ground_truth)


def per_class_recall(correctly_annotated, gt):
    """Per-class recall: mean over labels of correct / ground-truth count.

    Labels with no correct annotation or no ground truth contribute 0.
    Returns 0.0 for an empty ``correctly_annotated``.
    """
    recall = 0.0
    for key, value in correctly_annotated.items():
        if value != 0 and gt[key] != 0:
            recall += float(value) / gt[key]
    return _ratio(recall, len(correctly_annotated))


def apk(actual, predicted, k):
    """Average precision at k.

    Only the first ``k`` predictions are considered and a repeated
    prediction is credited once. Returns 0.0 when ``actual`` is empty or
    ``k`` is not positive.
    """
    if not actual or k <= 0:
        return 0.0

    relevant = set(actual)
    seen = set()
    score = 0.0
    num_hits = 0.0
    for rank, label in enumerate(predicted[:k], 1):
        if label in relevant and label not in seen:
            num_hits += 1.0
            score += num_hits / rank
        seen.add(label)

    return round(score / min(len(actual), k), 5)


def mapk(actual, predicted, k):
    """
    Computes the mean average precision at k, as a percentage.

    ``actual`` and ``predicted`` are parallel lists of label lists.
    Returns 0.0 when there is nothing to average.
    """
    scores = [apk(a, p, k) for a, p in zip(actual, predicted)]
    if not scores:
        return 0.0
    return round(float(np.mean(scores)), 4) * 100


def flat_hit(actual, predicted, k):
    """Flat hit at k: percentage of items whose first actual label is in the top k.

    Returns 0.0 for an empty ``actual``.
    """
    if not actual:
        return 0.0
    hits = sum(1 for a, p in zip(actual, predicted) if a[0] in p[:k])
    return round(hits / float(len(actual)), 4) * 100  # return percentage


def _word_cosine(lang_mod, first, second):
    """Cosine similarity (float) of the lower-cased word vectors of two words."""
    similarity = cosine_similarity([lang_mod.word_vector(first.lower())],
                                   [lang_mod.word_vector(second.lower())])
    # cosine_similarity returns a (1, 1) array; reduce it to a plain float.
    return float(similarity[0][0])


def total_cosine(total_actual, predicted, singlelabel, lang_mod, weighted=False,
                 vector_dim=300):
    """Print the average cosine similarity between actual and predicted words.

    total_actual / predicted: parallel lists of label lists.
    singlelabel: True compares only the first actual label with the first
        prediction; False compares every prediction with its nearest
        ground-truth word (Annoy index, euclidean metric).
    lang_mod: object exposing ``word_vector(word)``.
    weighted: single-label only; weights each pair by
        (largest label count / count of this label).
    vector_dim: dimension of the word vectors handed to the Annoy index.

    Raises ImportError if scikit-learn (or annoy, for multi label) is missing.
    """
    if cosine_similarity is None:
        raise ImportError("total_cosine requires scikit-learn")

    if singlelabel:
        if weighted:
            count_dict, largest = total_dict(total_actual)
        total = 0.0
        n_pairs = 0
        for gt, pred in zip(total_actual, predicted):
            similarity = _word_cosine(lang_mod, gt[0], pred[0])
            if weighted:
                # total_dict keys are the raw labels, so look up the raw label;
                # float() avoids Python 2 integer division of the weight.
                total += float(largest) / count_dict[gt[0]] * similarity
            else:
                total += similarity
            n_pairs += 1
        print("Average cosine similarity - Single label")
        print(_ratio(total, len(total_actual) if weighted else n_pairs))
        return

    if AnnoyIndex is None:
        raise ImportError("multi-label total_cosine requires annoy")

    missed_total = 0.0   # sum over images of the mean cosine of wrong predictions
    all_total = 0.0      # sum of cosines over every predicted word
    n_words = 0
    for gt, pred in zip(total_actual, predicted):
        if not gt:
            continue  # nothing to compare the predictions against
        index = AnnoyIndex(vector_dim, metric="euclidean")
        for i, gt_word in enumerate(gt):
            index.add_item(i, lang_mod.word_vector(gt_word.lower()))
        index.build(10)  # 10 trees

        missed_sum = 0.0
        n_missed = 0
        for pred_word in pred:
            nearest = index.get_nns_by_vector(
                lang_mod.word_vector(pred_word.lower()), 1,
                search_k=-1, include_distances=False)
            nearest_gt = gt[nearest[0]]
            similarity = _word_cosine(lang_mod, nearest_gt, pred_word)
            all_total += similarity
            n_words += 1
            # Same case-folding as the vectors: an exact match is not a miss.
            if nearest_gt.lower() != pred_word.lower():
                missed_sum += similarity
                n_missed += 1
        if n_missed:
            missed_total += missed_sum / n_missed

    print("Average cosine similarity - Multi label")
    print("missed avg cosine {}".format(_ratio(missed_total, len(total_actual))))
    print("Regular cosine {}".format(_ratio(all_total, n_words)))


def total_dict(total_actual):
    """Count label occurrences, used to weight the importance of wrong labels.

    Returns (count_dict, largest): a dict label -> occurrences over all
    label lists, and the highest count (0 when there are no labels).
    """
    count_dict = {}
    for labels in total_actual:
        for word in labels:
            count_dict[word] = count_dict.get(word, 0) + 1
    largest = max(count_dict.values()) if count_dict else 0
    return count_dict, largest


def _parse_prediction_line(line):
    """Parse 'id, num_actual, a1..an, num_predicted, p1..pn'.

    Returns (image_id, actual, num_pred, preds), or None for a blank line.
    Malformed lines raise ValueError / IndexError.
    """
    if not line.strip():
        return None
    words = [w.strip() for w in line.strip().split(',')]
    num_actual = int(words[1])
    actual = words[2:2 + num_actual]
    num_pred = int(words[2 + num_actual])
    preds = words[3 + num_actual:]
    return words[0], actual, num_pred, preds


def _read_predictions(fn, progress_every=10000, announce_multilabel=True):
    """Read a prediction file.

    Returns (image_ids, total_actual, predicted, num_pred_list, single_label);
    single_label is False as soon as one line has a label count other than 1
    (that count is printed once when announce_multilabel is set).
    """
    image_ids, total_actual, predicted, num_pred_list = [], [], [], []
    single_label = True
    with open(fn) as f:
        for line in f:
            parsed = _parse_prediction_line(line)
            if parsed is None:
                continue
            if len(total_actual) % progress_every == 0:
                print("Number of lines evaluated: {}".format(len(total_actual)))
            image_id, actual, num_pred, preds = parsed
            if single_label and len(actual) != 1:
                if announce_multilabel:
                    print(len(actual))
                single_label = False  # Check if trying to predict multilabel
            image_ids.append(image_id)
            total_actual.append(actual)
            predicted.append(preds)
            num_pred_list.append(num_pred)
    return image_ids, total_actual, predicted, num_pred_list, single_label


def _count_annotations(total_actual, predicted, k):
    """Count, per label, correct top-k annotations, ground truths and predictions.

    Returns (correct_ann, n_gt, total_preds). Every ground-truth label has an
    entry in all three dicts, and the first occurrence is counted.
    """
    correct_ann, n_gt, total_preds = {}, {}, {}
    for actual, preds in zip(total_actual, predicted):
        top_k = preds[:k]  # only select the k values
        for label in actual:
            n_gt[label] = n_gt.get(label, 0) + 1
            correct_ann[label] = correct_ann.get(label, 0) + (1 if label in top_k else 0)
            total_preds.setdefault(label, 0)
        for label in top_k:
            total_preds[label] = total_preds.get(label, 0) + 1
    return correct_ann, n_gt, total_preds


def average_cosine(fn, lang_mod):
    """Read the prediction file ``fn`` and print its average cosine similarity."""
    print("Reading from file...")
    _, total_actual, predicted, _, single_label = _read_predictions(fn)
    total_cosine(total_actual, predicted, single_label, lang_mod)


def evaluate_flat_map(fn, k_values=(1, 2, 5, 10)):
    """Print MAP@k (and flat hit@k for single-label files) for each k in k_values."""
    print("Reading from file...")
    _, total_actual, predicted, _, single_label = _read_predictions(fn)

    # Measure results
    print("Measuring results...")
    cells = ['_'] + [str(mapk(total_actual, predicted, k)) for k in k_values]
    if single_label:
        cells += [str(flat_hit(total_actual, predicted, k)) for k in k_values]

    # Print table; the header lists the k values actually evaluated.
    k_label = ' '.join(str(k) for k in k_values)
    headers = 'Map ' + k_label
    if single_label:
        headers += ' | Flat ' + k_label
    print(headers)
    print(' & '.join(cells))


def evaluate_results(fn, k_values=(1, 2, 5, 10)):
    """Print a LaTeX table of all metrics, one row per k in k_values.

    Each row holds MAP@k, flat hit@k (single-label files only), per-class
    recall/precision and overall recall/precision computed on the top-k
    predictions, and is preceded by a '% k = N' LaTeX comment line.
    """
    print("Reading from file")
    _, total_actual, predicted, _, single_label = _read_predictions(
        fn, progress_every=1000)

    # Raw strings: '\t' must stay a backslash + 't' and '\\' a LaTeX line break.
    if single_label:
        headers = r'\thead{Loss} & \thead{MAP@k} & \thead{Flat hit@k}& \thead{Per-class\\ recall} & \thead{Per-class \\ precision} & \thead{Overall\\ Recall} & \thead{Overall\\ precision}\\'
    else:
        headers = r'\thead{Loss} & \thead{MAP@k} & \thead{Per-class\\ recall} & \thead{Per-class \\ precision} & \thead{Overall\\ Recall} & \thead{Overall\\ precision}\\'
    print(headers)

    # Measure results
    for k in k_values:
        correct_ann, n_gt, total_preds = _count_annotations(total_actual, predicted, k)
        cells = ['_', mapk(total_actual, predicted, k)]
        if single_label:
            cells.append(flat_hit(total_actual, predicted, k))
        cells.append(per_class_recall(correct_ann, n_gt))
        cells.append(per_class_precision(correct_ann, total_preds))
        cells.append(overall_recall(correct_ann, n_gt))
        cells.append(overall_precision(correct_ann, total_preds))
        print('% k = {}'.format(k))
        print(' & '.join(str(cell) for cell in cells))


def evaluate_pr_class(fn, k=5):
    """Write per-label recall@k (percent) to output/evaluate_results/results_pr_class.txt."""
    print("Reading from file...")
    _, total_actual, predicted, _, _ = _read_predictions(fn)
    correct_ann, n_gt, _ = _count_annotations(total_actual, predicted, k)

    output_path = osp.join('output', 'evaluate_results')
    if not osp.exists(output_path):
        os.makedirs(output_path)

    with open(osp.join(output_path, 'results_pr_class.txt'), 'w+') as wf:
        for label, gt_count in n_gt.items():
            recall_pct = round(correct_ann[label] / float(gt_count), 4) * 100
            wf.write('%s %s\n' % (label, recall_pct))


def show_actual_predicted(fn):
    """Write, per actual label, its five most frequent top-1 predictions.

    Uses the first actual label and the first prediction of every line;
    lines without a label or prediction are skipped. Output goes to
    output/evaluate_results/img_sl_top_predicted_pr_class.txt.
    """
    print("Reading from file...")
    _, total_actual, predicted, _, _ = _read_predictions(
        fn, announce_multilabel=False)

    top1_counts = {}
    for actual, preds in zip(total_actual, predicted):
        if not actual or not preds:
            continue
        per_label = top1_counts.setdefault(actual[0], {})
        per_label[preds[0]] = per_label.get(preds[0], 0) + 1

    output_path = osp.join('output', 'evaluate_results')
    if not osp.exists(output_path):
        os.makedirs(output_path)

    with open(osp.join(output_path, 'img_sl_top_predicted_pr_class.txt'), 'w+') as wf:
        for key, value in top1_counts.items():
            sorted_a = sorted(value.items(), key=operator.itemgetter(1), reverse=True)
            wf.write('%s %s\n' % (key, sorted_a[:5]))


def average_gt_classes(fn):
    """Print label/prediction averages and plot average precision per label count.

    The curve is average precision (k=1001) grouped by the number of actual
    labels of each line; it is saved as TikZ to 'test.tex'.
    """
    _, total_actual, predicted, num_pred_list, _ = _read_predictions(
        fn, announce_multilabel=False)
    line_count = len(total_actual)
    if line_count == 0:
        print("No predictions found in {}".format(fn))
        return

    num_actual_list = [len(actual) for actual in total_actual]
    for num_actual in num_actual_list:
        if num_actual > 50:
            print(num_actual)
    scores = [apk(actual, preds, 1001) for actual, preds in zip(total_actual, predicted)]

    avg_pred = float(sum(num_pred_list)) / line_count
    avg_actual = float(sum(num_actual_list)) / line_count
    print("Average actual:  {}".format(avg_actual))
    print("Average pred:  {}".format(avg_pred))
    print("Mean average precision:  {}".format(round(float(np.mean(scores)), 4) * 100))

    # Calculate average score for plots, grouped by number of actual labels
    # (matches the x-axis label and range below).
    scores_by_count = defaultdict(list)
    for num_actual, score in zip(num_actual_list, scores):
        scores_by_count[num_actual].append(score)
    x, y = zip(*sorted((n, np.mean(v)) for n, v in scores_by_count.items()))

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.ylabel('Average Precision')
    plt.xlabel('Number of actual')
    plt.plot(x, y, 'r-', linewidth=2)
    plt.figtext(.8, .8, "Average actual: " + str(avg_actual))
    plt.figtext(.8, .75, "Average predicted: " + str(avg_pred))

    plt.axis([0, 25, 0, 0.5])
    #plt.show()

    from matplotlib2tikz import save as tikz_save
    tikz_save('test.tex')