├── LICENSE ├── Report and PPT ├── MAIN.pdf ├── REVIEW 2.pptx └── test ├── __pycache__ └── __init__.cpython-36.pyc ├── app.py ├── banner.jpeg ├── bg.jpg ├── cap.py ├── captionbot ├── Readme.md ├── captionbot │ ├── __init__.py │ └── captionbot.py └── setup.py ├── data ├── mscoco_label_map.pbtxt ├── pascal_label_map.pbtxt └── pet_label_map.pbtxt ├── image caption ├── __pycache__ │ ├── load_data.cpython-36.pyc │ ├── model.cpython-36.pyc │ └── preprocessing.cpython-36.pyc ├── demo ├── load_data.py ├── model.py ├── model.pyc ├── model_data │ ├── description.txt │ ├── descriptions.txt │ ├── model_19.h5 │ └── tokenizer.pkl ├── preprocessing.py ├── readme.txt ├── test.py ├── test_data │ ├── ai3.jpg │ ├── deep.jpg │ ├── instruction.txt │ └── modern class.jpg ├── train_val.py ├── train_val_data │ └── demo └── utils │ ├── load_data.py │ ├── model.py │ └── preprocessing.py ├── images ├── image3.png ├── image4.png ├── image5.png ├── image6.png └── logo.png ├── opencv.jpg ├── protos ├── BUILD ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── string_int_label_map_pb2.cpython-36.pyc ├── anchor_generator.proto ├── anchor_generator_pb2.py ├── argmax_matcher.proto ├── argmax_matcher_pb2.py ├── bipartite_matcher.proto ├── bipartite_matcher_pb2.py ├── box_coder.proto ├── box_coder_pb2.py ├── box_predictor.proto ├── box_predictor_pb2.py ├── eval.proto ├── eval_pb2.py ├── faster_rcnn.proto ├── faster_rcnn_box_coder.proto ├── faster_rcnn_box_coder_pb2.py ├── faster_rcnn_pb2.py ├── grid_anchor_generator.proto ├── grid_anchor_generator_pb2.py ├── hyperparams.proto ├── hyperparams_pb2.py ├── image_resizer.proto ├── image_resizer_pb2.py ├── input_reader.proto ├── input_reader_pb2.py ├── losses.proto ├── losses_pb2.py ├── matcher.proto ├── matcher_pb2.py ├── mean_stddev_box_coder.proto ├── mean_stddev_box_coder_pb2.py ├── model.proto ├── model_pb2.py ├── optimizer.proto ├── optimizer_pb2.py ├── pipeline.proto ├── pipeline_pb2.py ├── post_processing.proto ├── post_processing_pb2.py ├── preprocessor.proto ├── preprocessor_pb2.py ├── region_similarity_calculator.proto ├── region_similarity_calculator_pb2.py ├── square_box_coder.proto ├── square_box_coder_pb2.py ├── ssd.proto ├── ssd_anchor_generator.proto ├── ssd_anchor_generator_pb2.py ├── ssd_pb2.py ├── string_int_label_map.proto ├── string_int_label_map_pb2.py ├── train.proto └── train_pb2.py ├── readme.md ├── requirements.txt ├── speech.py ├── static.zip ├── templates.zip ├── tests ├── comedor.jpg ├── familia.jpg ├── fiesta.jpg └── mascotas.jpg ├── uploads ├── 1.jpg ├── 1002674143_1b742ab4b8.jpg ├── 1015584366_dfcec3c85a.jpg ├── 1019077836_6fc9b15408.jpg ├── 10815824_2997e03d76.jpg ├── 1110208841_5bb6806afe.jpg ├── 1260816604_570fc35836.jpg ├── 12830823_87d2654e31.jpg ├── 17273391_55cfc7d3d4.jpg ├── 2018-06-sample-gallery.png ├── 2043427251_83b746da8e.jpg ├── 23445819_3a458716c1.jpg ├── 24.jpg ├── 2658009523_b49d611db8.jpg ├── 2661567396_cbe4c2e5be.jpg ├── 27782020_4dab210360.jpg ├── 29871656_938329389659068_5796396155926156498_o.jpg ├── 3047751696_78c2efe5e6.jpg ├── 3134092148_151154139a.jpg ├── 3262793378_773b21ec19.jpg ├── 3301754574_465af5bf6d.jpg ├── 33108590_d685bfe51c.jpg ├── 3354883962_170d19bfe4.jpg ├── 3422458549_f3f3878dbf.jpg ├── 3518687038_964c523958.jpg ├── 35506150_cbdb630f4f.jpg ├── 3595216998_0a19efebd0.jpg ├── 3597210806_95b07bb968.jpg ├── 3601508034_5a3bfc905e.jpg ├── 3601843201_4809e66909.jpg ├── 3607489370_92683861f7.jpg ├── 3613375729_d0b3c41556.jpg ├── 3637013_c675de7705.jpg ├── 3640743904_d14eea0a0b.jpg ├── 36422830_55c844bc2d.jpg ├── 
3710971182_cb01c97d15.jpg ├── 37375513_2262615307097058_8585448275321028608_o.jpg ├── 403523132_73b9a1a4b3.jpg ├── 41999070_838089137e.jpg ├── 42637986_135a9786a6.jpg ├── 42637987_866635edf6.jpg ├── 44129946_9eeb385d77.jpg ├── 460195978_fc522a4979.jpg ├── 47871819_db55ac4699.jpg ├── 54501196_a9ac9d66f2.jpg ├── 6.jpg ├── 667626_18933d713e.jpg ├── 72218201_e0e9c7d65b.jpg ├── 818340833_7b963c0ee3.jpg ├── 93922153_8d831f7f01.jpg ├── 99679241_adc853a5c0.jpg ├── CR3.jpg ├── DSC00221.JPG ├── IMG_20180406_162945.jpg ├── IMG_20181003_193709_mh1538575655201.jpg ├── IMG_20190409_112950_mh1554795435105.jpg ├── Penguins.jpg ├── WhatsApp Image 2018-07-16 at 7.32.07 PM.jpeg ├── WhatsApp_Image_2018-07-16_at_7.32.07_PM_1.jpeg ├── ai3.jpg ├── canon-eos-sample-photo.jpg ├── college-grad-temping-feature.jpg ├── comedor.jpg ├── crop635w_accomlishedstudent0Small.jpg ├── deep.jpg ├── depositphotos_14060460-stock-illustration-3d-heart-protection-vector-icon.jpg ├── download (1).jpg ├── download.jpg ├── e020fc9a4def5c66ba435e27109a0890.jpg ├── familia.jpg ├── fiesta.jpg ├── geek.png ├── images (1).jpg ├── images (2).jpg ├── images.jpg ├── latest_mobile.png ├── mascotas.jpg ├── modern_class.jpg ├── molecular-model-colorful_f57072f6-2ab1-11e9-b115-35431bcc9744.jpg ├── opencv.jpg ├── sai.jpg ├── sample-5.jpg ├── trail.jpg ├── trail2.jpg └── yash4.jpg └── utils ├── BUILD ├── __pycache__ ├── __init__.cpython-36.pyc ├── label_map_util.cpython-36.pyc └── visualization_utils.cpython-36.pyc ├── category_util.py ├── category_util_test.py ├── dataset_util.py ├── dataset_util_test.py ├── label_map_util.py ├── label_map_util_test.py ├── learning_schedules.py ├── learning_schedules_test.py ├── metrics.py ├── metrics_test.py ├── np_box_list.py ├── np_box_list_ops.py ├── np_box_list_ops_test.py ├── np_box_list_test.py ├── np_box_ops.py ├── np_box_ops_test.py ├── object_detection_evaluation.py ├── object_detection_evaluation_test.py ├── ops.py ├── ops_test.py ├── per_image_evaluation.py ├── per_image_evaluation_test.py ├── shape_utils.py ├── shape_utils_test.py ├── static_shape.py ├── static_shape_test.py ├── test_utils.py ├── test_utils_test.py ├── variables_helper.py ├── variables_helper_test.py ├── visualization_utils.py └── visualization_utils_test.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yaswanth Sai Palaghat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Report and PPT/MAIN.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/Report and PPT/MAIN.pdf -------------------------------------------------------------------------------- /Report and PPT/REVIEW 2.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/Report and PPT/REVIEW 2.pptx -------------------------------------------------------------------------------- /Report and PPT/test: -------------------------------------------------------------------------------- 1 | ... 2 | -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from flask import Flask, render_template, request, redirect, url_for, send_from_directory 3 | from flask_bootstrap import Bootstrap 4 | from werkzeug import secure_filename 5 | import numpy as np 6 | import os 7 | import six.moves.urllib as urllib 8 | import sys 9 | import tensorflow as tf 10 | from collections import defaultdict 11 | from io import StringIO 12 | from PIL import Image 13 | sys.path.append("..") 14 | from utils import label_map_util 15 | from utils import visualization_utils as vis_util 16 | MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 17 | PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 18 | PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 19 | NUM_CLASSES = 90 20 | 21 | detection_graph = tf.Graph() 22 | with detection_graph.as_default(): 23 | od_graph_def = tf.GraphDef() 24 | with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 25 | serialized_graph = fid.read() 26 | od_graph_def.ParseFromString(serialized_graph) 27 | tf.import_graph_def(od_graph_def, name='') 28 | label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 29 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 30 | category_index = label_map_util.create_category_index(categories) 31 | 32 | 33 | def load_image_into_numpy_array(image): 34 | (im_width, im_height) = image.size 35 | return np.array(image.getdata()).reshape( 36 | (im_height, im_width, 3)).astype(np.uint8) 37 | 38 | 39 | app = Flask(__name__) 40 | bootstrap = Bootstrap(app) 41 | 42 | app.config['UPLOAD_FOLDER'] = 'uploads/' 43 | app.config['ALLOWED_EXTENSIONS'] = set(['png', 'jpg', 'jpeg']) 44 | 45 | def allowed_file(filename): 46 | return '.' 
in filename and \ 47 | filename.rsplit('.', 1)[1] in app.config['ALLOWED_EXTENSIONS'] 48 | 49 | 50 | @app.route('/') 51 | def index(): 52 | return render_template('index.html') 53 | 54 | 55 | @app.route('/upload', methods=['POST']) 56 | def upload(): 57 | file = request.files['file'] 58 | if file and allowed_file(file.filename): 59 | filename = secure_filename(file.filename) 60 | file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename)) 61 | return redirect(url_for('uploaded_file', 62 | filename=filename)) 63 | 64 | 65 | @app.route('/uploads/') 66 | def uploaded_file(filename): 67 | PATH_TO_TEST_IMAGES_DIR = app.config['UPLOAD_FOLDER'] 68 | TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR,filename.format(i)) for i in range(1, 2) ] 69 | IMAGE_SIZE = (12, 8) 70 | 71 | with detection_graph.as_default(): 72 | with tf.Session(graph=detection_graph) as sess: 73 | for image_path in TEST_IMAGE_PATHS: 74 | image = Image.open(image_path) 75 | image_np = load_image_into_numpy_array(image) 76 | image_np_expanded = np.expand_dims(image_np, axis=0) 77 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 78 | boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 79 | scores = detection_graph.get_tensor_by_name('detection_scores:0') 80 | classes = detection_graph.get_tensor_by_name('detection_classes:0') 81 | num_detections = detection_graph.get_tensor_by_name('num_detections:0') 82 | (boxes, scores, classes, num_detections) = sess.run( 83 | [boxes, scores, classes, num_detections], 84 | feed_dict={image_tensor: image_np_expanded}) 85 | vis_util.visualize_boxes_and_labels_on_image_array( 86 | image_np, 87 | np.squeeze(boxes), 88 | np.squeeze(classes).astype(np.int32), 89 | np.squeeze(scores), 90 | category_index, 91 | use_normalized_coordinates=True, 92 | line_thickness=8) 93 | im = Image.fromarray(image_np) 94 | im.save('uploads/'+filename) 95 | 96 | return send_from_directory(app.config['UPLOAD_FOLDER'], 97 | filename) 98 | if __name__ == '__main__': 99 | app.run(debug=True,host='0.0.0.0',port=5000) 100 | 101 | -------------------------------------------------------------------------------- /banner.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/banner.jpeg -------------------------------------------------------------------------------- /bg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/bg.jpg -------------------------------------------------------------------------------- /cap.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | from flask import Flask 4 | from flask import render_template 5 | from flask import request 6 | from captionbot import CaptionBot 7 | from flask_bootstrap import Bootstrap 8 | from werkzeug import secure_filename 9 | import numpy as np 10 | import os 11 | import six.moves.urllib as urllib 12 | import sys 13 | import tensorflow as tf 14 | from collections import defaultdict 15 | from flask_gtts import gtts 16 | from gtts import gTTS 17 | from playsound import playsound 18 | from io import StringIO 19 | from PIL import Image 20 | sys.path.append("..") 21 | from utils 
import label_map_util 22 | from utils import visualization_utils as vis_util 23 | 24 | 25 | language = 'en' 26 | c=CaptionBot() 27 | 28 | app = Flask(__name__) 29 | bootstrap = Bootstrap(app) 30 | 31 | UPLOAD_FOLDER = os.path.basename('/uploads') 32 | app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER 33 | 34 | @app.route('/') 35 | def index(): 36 | return render_template('index.html') 37 | 38 | @app.route('/upload', methods=['POST']) 39 | def upload_file(): 40 | file = request.files['file'] 41 | 42 | f = os.path.join(app.config['UPLOAD_FOLDER'], file.filename) 43 | # add your custom code to check that the uploaded file is a valid image and not a malicious file (out-of-scope for this post) 44 | file.save(f) 45 | print(f) 46 | image_caption = c.file_caption(f) 47 | return render_template('index.html', image_caption=image_caption) 48 | 49 | @app.route("/success", methods = ['POST','GET']) 50 | def success(): 51 | text = c.file_caption(f) 52 | tts = gTTS(text=text, lang='en') 53 | return render_template("success.html",value = text) 54 | tts.save("text.mp3") 55 | 56 | if __name__ == '__main__': 57 | app.run(debug=True, port=3000) 58 | -------------------------------------------------------------------------------- /captionbot/Readme.md: -------------------------------------------------------------------------------- 1 | # captionbot 2 | 3 | [![PyPi Package Version](https://img.shields.io/pypi/v/captionbot.svg)](https://pypi.python.org/pypi/captionbot) 4 | 5 | Captionbot is a simple API wrapper for https://www.captionbot.ai/ 6 | 7 | ## Installation 8 | 9 | You can install captionbot using pip: 10 | ```bash 11 | $ pip install captionbot 12 | ``` 13 | ## Usage 14 | 15 | To use, simply do: 16 | 17 | ```python 18 | >>> from captionbot import CaptionBot 19 | >>> c = CaptionBot() 20 | >>> c.url_caption('your image url here') 21 | >>> c.file_caption('your local image filename here') 22 | ``` 23 | -------------------------------------------------------------------------------- /captionbot/captionbot/__init__.py: -------------------------------------------------------------------------------- 1 | from .captionbot import CaptionBot -------------------------------------------------------------------------------- /captionbot/captionbot/captionbot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mimetypes 3 | import os 4 | import requests 5 | try: 6 | from urllib.parse import urlencode 7 | except ImportError: 8 | from urllib import urlencode 9 | import logging 10 | logger = logging.getLogger("captionbot") 11 | 12 | 13 | class CaptionBotException(Exception): 14 | pass 15 | 16 | 17 | class CaptionBot: 18 | UPLOAD_URL = "https://www.captionbot.ai/api/upload" 19 | MESSAGES_URL = "https://captionbot.azurewebsites.net/api/messages" 20 | 21 | @staticmethod 22 | def _resp_error(resp): 23 | if not resp.ok: 24 | data = resp.json() 25 | msg = "HTTP error: {}".format(resp.status_code) 26 | if type(data) == dict and "Message" in data: 27 | msg += ", " + data.get("Message") 28 | raise CaptionBotException(msg) 29 | 30 | def __init__(self): 31 | self.session = requests.Session() 32 | 33 | def _upload(self, filename): 34 | url = self.UPLOAD_URL 35 | mime = mimetypes.guess_type(filename)[0] 36 | name = os.path.basename(filename) 37 | files = {'file': (name, open(filename, 'rb'), mime)} 38 | resp = self.session.post(url, files=files) 39 | logger.debug("upload: {}".format(resp)) 40 | self._resp_error(resp) 41 | res = resp.text 42 | if res: 43 | return res[1:-1] 44 | 45 | 
def url_caption(self, image_url): 46 | data = { 47 | "Content": image_url, 48 | "Type": "CaptionRequest", 49 | } 50 | headers = { 51 | "Content-Type": "application/json; charset=utf-8" 52 | } 53 | url = self.MESSAGES_URL 54 | resp = self.session.post(url, data=json.dumps(data), headers=headers) 55 | logger.info("get_caption: {}".format(resp)) 56 | if not resp.ok: 57 | return None 58 | res = resp.text[1:-1].replace('\\"', '"').replace('\\n', '\n') 59 | logger.info(res) 60 | return res 61 | 62 | def file_caption(self, filename): 63 | upload_filename = self._upload(filename) 64 | return self.url_caption(upload_filename) 65 | -------------------------------------------------------------------------------- /captionbot/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='captionbot', 4 | version='0.1.4', 5 | description='Simple API wrapper for https://www.captionbot.ai/', 6 | url='http://github.com/krikunts/captionbot', 7 | author='Tatiana Krikun', 8 | author_email='krikunts@gmail.com', 9 | license='MIT', 10 | packages=['captionbot'], 11 | install_requires=[ 12 | 'requests', 13 | ], 14 | zip_safe=False) -------------------------------------------------------------------------------- /data/pascal_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'aeroplane' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'bicycle' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'bird' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'boat' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'bottle' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'bus' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'car' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'cat' 39 | } 40 | 41 | item { 42 | id: 9 43 | name: 'chair' 44 | } 45 | 46 | item { 47 | id: 10 48 | name: 'cow' 49 | } 50 | 51 | item { 52 | id: 11 53 | name: 'diningtable' 54 | } 55 | 56 | item { 57 | id: 12 58 | name: 'dog' 59 | } 60 | 61 | item { 62 | id: 13 63 | name: 'horse' 64 | } 65 | 66 | item { 67 | id: 14 68 | name: 'motorbike' 69 | } 70 | 71 | item { 72 | id: 15 73 | name: 'person' 74 | } 75 | 76 | item { 77 | id: 16 78 | name: 'pottedplant' 79 | } 80 | 81 | item { 82 | id: 17 83 | name: 'sheep' 84 | } 85 | 86 | item { 87 | id: 18 88 | name: 'sofa' 89 | } 90 | 91 | item { 92 | id: 19 93 | name: 'train' 94 | } 95 | 96 | item { 97 | id: 20 98 | name: 'tvmonitor' 99 | } 100 | -------------------------------------------------------------------------------- /data/pet_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'Abyssinian' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'american_bulldog' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'american_pit_bull_terrier' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'basset_hound' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'beagle' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'Bengal' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'Birman' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'Bombay' 39 | } 40 | 41 | item { 42 | id: 9 43 | name: 'boxer' 44 | } 45 | 46 | item { 47 | id: 10 48 | name: 'British_Shorthair' 49 | } 50 | 51 | item { 52 | id: 11 53 | name: 'chihuahua' 54 | } 55 | 56 | item { 57 | id: 12 58 | name: 'Egyptian_Mau' 59 | } 60 | 61 | item { 62 | id: 13 63 | name: 'english_cocker_spaniel' 64 | } 65 | 66 | item { 67 | id: 14 68 | name: 'english_setter' 69 | } 70 | 71 | item { 72 | id: 
15 73 | name: 'german_shorthaired' 74 | } 75 | 76 | item { 77 | id: 16 78 | name: 'great_pyrenees' 79 | } 80 | 81 | item { 82 | id: 17 83 | name: 'havanese' 84 | } 85 | 86 | item { 87 | id: 18 88 | name: 'japanese_chin' 89 | } 90 | 91 | item { 92 | id: 19 93 | name: 'keeshond' 94 | } 95 | 96 | item { 97 | id: 20 98 | name: 'leonberger' 99 | } 100 | 101 | item { 102 | id: 21 103 | name: 'Maine_Coon' 104 | } 105 | 106 | item { 107 | id: 22 108 | name: 'miniature_pinscher' 109 | } 110 | 111 | item { 112 | id: 23 113 | name: 'newfoundland' 114 | } 115 | 116 | item { 117 | id: 24 118 | name: 'Persian' 119 | } 120 | 121 | item { 122 | id: 25 123 | name: 'pomeranian' 124 | } 125 | 126 | item { 127 | id: 26 128 | name: 'pug' 129 | } 130 | 131 | item { 132 | id: 27 133 | name: 'Ragdoll' 134 | } 135 | 136 | item { 137 | id: 28 138 | name: 'Russian_Blue' 139 | } 140 | 141 | item { 142 | id: 29 143 | name: 'saint_bernard' 144 | } 145 | 146 | item { 147 | id: 30 148 | name: 'samoyed' 149 | } 150 | 151 | item { 152 | id: 31 153 | name: 'scottish_terrier' 154 | } 155 | 156 | item { 157 | id: 32 158 | name: 'shiba_inu' 159 | } 160 | 161 | item { 162 | id: 33 163 | name: 'Siamese' 164 | } 165 | 166 | item { 167 | id: 34 168 | name: 'Sphynx' 169 | } 170 | 171 | item { 172 | id: 35 173 | name: 'staffordshire_bull_terrier' 174 | } 175 | 176 | item { 177 | id: 36 178 | name: 'wheaten_terrier' 179 | } 180 | 181 | item { 182 | id: 37 183 | name: 'yorkshire_terrier' 184 | } 185 | -------------------------------------------------------------------------------- /image caption/__pycache__/load_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/__pycache__/load_data.cpython-36.pyc -------------------------------------------------------------------------------- /image caption/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /image caption/__pycache__/preprocessing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/__pycache__/preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /image caption/demo: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /image caption/model.py: -------------------------------------------------------------------------------- 1 | from numpy import argmax 2 | from keras.applications.inception_v3 import InceptionV3 3 | from keras.models import Model 4 | from keras.layers import Input 5 | from keras.layers import Dense 6 | from keras.layers import LSTM 7 | from keras.layers import Embedding 8 | from keras.layers import Dropout 9 | from keras.layers.merge import add 10 | from keras.preprocessing.sequence import pad_sequences 
11 | from keras.preprocessing.image import load_img, img_to_array 12 | from nltk.translate.bleu_score import corpus_bleu 13 | 14 | from keras.models import Sequential 15 | from keras.layers import LSTM, Embedding, TimeDistributed, RepeatVector, Activation 16 | from keras.layers.core import Layer, Dense 17 | from keras.optimizers import Adam, RMSprop 18 | from keras.layers.wrappers import Bidirectional 19 | 20 | 21 | # define the CNN model 22 | def defineCNNmodel(): 23 | model = InceptionV3() 24 | model.layers.pop() 25 | model = Model(inputs=model.inputs, outputs=model.layers[-1].output) 26 | #print(model.summary()) 27 | return model 28 | 29 | # define the RNN model 30 | def defineRNNmodel(vocab_size, max_len): 31 | embedding_size = 300 32 | # Input dimension is 2048 since we will feed it the encoded version of the image. 33 | image_model = Sequential([ 34 | Dense(embedding_size, input_shape=(2048,), activation='relu'), 35 | RepeatVector(max_len) 36 | ]) 37 | # Since we are going to predict the next word using the previous words(length of previous words changes with every iteration over the caption), we have to set return_sequences = True. 38 | caption_model = Sequential([ 39 | Embedding(vocab_size, embedding_size, input_length=max_len), 40 | LSTM(256, return_sequences=True), 41 | TimeDistributed(Dense(300)) 42 | ]) 43 | # Merging the models and creating a softmax classifier 44 | final_model = Sequential([ 45 | keras.layers.Concatenate([image_model, caption_model], mode='concat', concat_axis=1), 46 | Bidirectional(LSTM(256, return_sequences=False)), 47 | Dense(vocab_size), 48 | Activation('softmax') 49 | ]) 50 | final_model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy']) 51 | final_model.summary() 52 | return final_model 53 | 54 | 55 | # map an integer to a word 56 | def word_for_id(integer, tokenizer): 57 | for word, index in tokenizer.word_index.items(): 58 | if index == integer: 59 | return word 60 | return None 61 | 62 | # generate a description for an image, given a pre-trained model and a tokenizer to map integer back to word 63 | def generate_desc(model, tokenizer, photo, max_length): 64 | # seed the generation process 65 | in_text = 'startseq' 66 | # iterate over the whole length of the sequence 67 | for i in range(max_length): 68 | # integer encode input sequence 69 | sequence = tokenizer.texts_to_sequences([in_text])[0] 70 | # pad input 71 | sequence = pad_sequences([sequence], maxlen=max_length) 72 | # predict next word 73 | yhat = model.predict([photo,sequence], verbose=0) 74 | # convert probability to integer 75 | yhat = argmax(yhat) 76 | # map integer to word 77 | word = word_for_id(yhat, tokenizer) 78 | # stop if we cannot map the word 79 | if word is None: 80 | break 81 | # append as input for generating the next word 82 | in_text += ' ' + word 83 | # stop if we predict the end of the sequence 84 | if word == 'endseq': 85 | break 86 | return in_text 87 | 88 | 89 | def evaluate_model(model, photos, descriptions, tokenizer, max_length): 90 | actual, predicted = list(), list() 91 | 92 | for key, desc_list in descriptions.items(): 93 | yhat = generate_desc(model, tokenizer, photos[key], max_length) 94 | references = [d.split() for d in desc_list] 95 | actual.append(references) 96 | predicted.append(yhat.split()) 97 | 98 | print('BLEU-1: %f' % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0))) 99 | print('BLEU-2: %f' % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0))) 100 | print('BLEU-3: %f' % corpus_bleu(actual, predicted, 
weights=(0.3, 0.3, 0.3, 0))) 101 | print('BLEU-4: %f' % corpus_bleu(actual, predicted, weights=(0.25, 0.25, 0.25, 0.25))) 102 | -------------------------------------------------------------------------------- /image caption/model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/model.pyc -------------------------------------------------------------------------------- /image caption/model_data/description.txt: -------------------------------------------------------------------------------- 1 | When you run the project, some files will be generated which'll be stored here 2 | 3 | descriptions.txt : contains the saved text features 4 | features.pkl : contains the saved image features 5 | tokenizer.pkl : contains the saved tokenizer 6 | model : the trained model 7 | -------------------------------------------------------------------------------- /image caption/model_data/model_19.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/model_data/model_19.h5 -------------------------------------------------------------------------------- /image caption/model_data/tokenizer.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/model_data/tokenizer.pkl -------------------------------------------------------------------------------- /image caption/preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import string 3 | from os import listdir 4 | from pickle import dump 5 | from model import * 6 | from keras.applications.inception_v3 import preprocess_input 7 | from keras.preprocessing.image import load_img, img_to_array 8 | 9 | #The function returns a dictionary of image identifier to image features. 
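# A minimal reading of the function below: each image is loaded at 299x299 (the
# InceptionV3 input size), expanded to a batch of one, passed through
# preprocess_input and the CNN from defineCNNmodel(), and the resulting 2048-d
# feature vector is keyed by the file name with its extension stripped (the same
# image id used by the caption descriptions).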
10 | def extract_features(path): 11 | model = defineCNNmodel() 12 | # extract features from each photo 13 | features = dict() 14 | for name in listdir(path): 15 | # load an image from file 16 | filename = path + '/' + name 17 | image = load_img(filename, target_size=(299, 299)) 18 | # convert the image pixels to a numpy array 19 | image = img_to_array(image) 20 | # reshape data for the model 21 | image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2])) 22 | # prepare the image for the VGG model 23 | image = preprocess_input(image) 24 | # get features 25 | feature = model.predict(image, verbose=0) 26 | # get image id 27 | image_id = name.split('.')[0] 28 | # store feature 29 | features[image_id] = feature 30 | return features 31 | 32 | # extract descriptions for images 33 | def load_descriptions(filename): 34 | file = open(filename, 'r') 35 | doc = file.read() 36 | file.close() 37 | mapping = dict() 38 | # process lines by line 39 | for line in doc.split('\n'): 40 | # split line by white space 41 | tokens = line.split() 42 | if len(line) < 2: 43 | continue 44 | # take the first token as the image id, the rest as the description 45 | image_id, image_desc = tokens[0], tokens[1:] 46 | # remove filename from image id 47 | image_id = image_id.split('.')[0] 48 | # convert description tokens back to string 49 | image_desc = ' '.join(image_desc) 50 | # create the list if needed 51 | if image_id not in mapping: 52 | mapping[image_id] = list() 53 | # store description 54 | mapping[image_id].append(image_desc) 55 | return mapping 56 | 57 | def clean_descriptions(descriptions): 58 | # prepare translation table for removing punctuation 59 | table = str.maketrans('', '', string.punctuation) 60 | for key, desc_list in descriptions.items(): 61 | for i in range(len(desc_list)): 62 | desc = desc_list[i] 63 | # tokenize 64 | desc = desc.split() 65 | # convert to lower case 66 | desc = [word.lower() for word in desc] 67 | # remove punctuation from each token 68 | desc = [w.translate(table) for w in desc] 69 | # remove hanging 's' and 'a' 70 | desc = [word for word in desc if len(word)>1] 71 | # remove tokens with numbers in them 72 | desc = [word for word in desc if word.isalpha()] 73 | # store as string 74 | desc_list[i] = ' '.join(desc) 75 | 76 | # save descriptions to file, one per line 77 | def save_descriptions(descriptions, filename): 78 | lines = list() 79 | for key, desc_list in descriptions.items(): 80 | for desc in desc_list: 81 | lines.append(key + ' ' + desc) 82 | data = '\n'.join(lines) 83 | file = open(filename, 'w') 84 | file.write(data) 85 | file.close() 86 | 87 | def preprocessData(): 88 | # extract features from all images 89 | path = 'train_val_data/Flicker8k_Dataset' 90 | print('Generating image features...') 91 | features = extract_features(path) 92 | print('Completed. 
Saving now...') 93 | # save to file 94 | dump(features, open('model_data/features.pkl', 'wb')) 95 | print("Save Complete.") 96 | 97 | # load descriptions containing file and parse descriptions 98 | descriptions_path = 'train_val_data/Flickr8k.token.txt' 99 | 100 | descriptions = load_descriptions(descriptions_path) 101 | print('Loaded Descriptions: %d ' % len(descriptions)) 102 | 103 | # clean descriptions 104 | clean_descriptions(descriptions) 105 | 106 | # save descriptions 107 | save_descriptions(descriptions, 'model_data/descriptions.txt') 108 | 109 | 110 | # Now descriptions.txt is of form : 111 | # Example : 2252123185_487f21e336 stadium full of people watch game 112 | -------------------------------------------------------------------------------- /image caption/readme.txt: -------------------------------------------------------------------------------- 1 | Required Libraries for python 2 | Keras 3 | Pillow 4 | nltk 5 | Matplotlib 6 | Important: After downloading the dataset, put the reqired files in train_val_data folder 7 | Procedure to Train Model 8 | 1)Put the required files in train_val_data Folder 9 | 2)Run train_val.py 10 | Procedure to Test on new images 11 | 1)Put the test image in test_data folder 12 | 2)Run test.py -------------------------------------------------------------------------------- /image caption/test.py: -------------------------------------------------------------------------------- 1 | from pickle import load 2 | from model import * 3 | from keras.models import load_model 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | # extract features from each photo in the directory 10 | def extract_features(filename): 11 | model = defineCNNmodel() 12 | # load the photo 13 | image = load_img(filename, target_size=(224, 224)) 14 | # convert the image pixels to a numpy array 15 | image = img_to_array(image) 16 | # reshape data for the model 17 | image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2])) 18 | # prepare the image for the VGG model 19 | image = preprocess_input(image) 20 | # get features 21 | feature = model.predict(image, verbose=0) 22 | return feature 23 | 24 | # load the tokenizer 25 | tokenizer_path = 'model_data/tokenizer.pkl' 26 | tokenizer = load(open(tokenizer_path, 'rb')) 27 | 28 | # pre-define the max sequence length (from training) 29 | max_length = 34 30 | 31 | # load the model 32 | model_path = 'model_data/model_19.h5' 33 | model = load_model(model_path) 34 | 35 | # load and prepare the photograph 36 | test_path = 'test_data' 37 | for image_file in os.listdir(test_path): 38 | try: 39 | image_type = imghdr.what(os.path.join(test_path, image_file)) 40 | if not image_type: 41 | continue 42 | except IsADirectoryError: 43 | continue 44 | image = extract_features(image_file) 45 | 46 | # generate description 47 | description = generate_desc(model, tokenizer, image, max_length) 48 | 49 | # remove startseq and endseq 50 | caption = 'Caption: ' + description.split()[1].capitalize() 51 | for x in description.split()[2:len(description.split())-1]: 52 | caption = caption + ' ' + x 53 | caption += '.' 
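# The caption-building lines above strip the 'startseq'/'endseq' markers and
# capitalize the first word to form the display caption. Two caveats when running
# this script as-is: os.listdir/imghdr.what are used but neither os nor imghdr is
# imported at the top of the file, and image_file is a bare file name, so the
# extract_features(image_file) call above and Image.open(image_file, 'r') below
# likely need os.path.join(test_path, image_file), as the imghdr.what call uses.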
54 | 55 | # Show image and it's caption 56 | pil_im = Image.open(image_file, 'r') 57 | fig, ax = plt.subplots(figsize=(8, 8)) 58 | ax.get_xaxis().set_visible(False) 59 | ax.get_yaxis().set_visible(False) 60 | _ = ax.imshow(np.asarray(pil_im), interpolation='nearest') 61 | _ = ax.set_title(caption,fontdict={'fontsize': '20','fontweight' : '40'}) 62 | -------------------------------------------------------------------------------- /image caption/test_data/ai3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/test_data/ai3.jpg -------------------------------------------------------------------------------- /image caption/test_data/deep.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/test_data/deep.jpg -------------------------------------------------------------------------------- /image caption/test_data/instruction.txt: -------------------------------------------------------------------------------- 1 | Put here the images you want to test the model on. -------------------------------------------------------------------------------- /image caption/test_data/modern class.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/image caption/test_data/modern class.jpg -------------------------------------------------------------------------------- /image caption/train_val.py: -------------------------------------------------------------------------------- 1 | from pickle import load 2 | from model import * 3 | from load_data import * 4 | 5 | # Load Data 6 | # X1 : image features 7 | # X2 : text features 8 | X1train, X2train, max_length = loadTrainData(path = 'train_val_data/Flickr_8k.trainImages.txt',preprocessDataReady=False) 9 | 10 | X1val, X2val = loadValData(path = 'train_val_data/Flickr_8k.devImages.txt') 11 | 12 | # load the tokenizer 13 | tokenizer_path = 'model_data/tokenizer.pkl' 14 | tokenizer = load(open(tokenizer_path, 'rb')) 15 | vocab_size = len(tokenizer.word_index) + 1 16 | 17 | # prints 34 18 | print('Max Length : ',max_length) 19 | 20 | # We already have the image features from CNN model so we only need to define the RNN model now. 
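# Training note: the loop below runs the epochs manually -- a fresh generator is
# created for each epoch, fit_generator trains for that single epoch, and the
# model is checkpointed as model_data/model_<epoch>.h5 (model_19.h5, the file the
# test script loads, is the checkpoint from the last of the 20 epochs).
# data_generator comes from load_data.py (not reproduced in this listing); it is
# assumed to yield batches of ([image feature, padded caption prefix], one-hot
# next word) pairs matching the two inputs of the RNN model defined below.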
21 | # define the RNN model 22 | model = defineRNNmodel(vocab_size, max_length) 23 | 24 | # train the model, run epochs manually and save after each epoch 25 | epochs = 20 26 | steps_train = len(X2train) 27 | steps_val = len(X2val) 28 | for i in range(epochs): 29 | # create the train data generator 30 | generator_train = data_generator(X1train, X2train, tokenizer, max_length) 31 | # create the val data generator 32 | generator_val = data_generator(X1val, X2val, tokenizer, max_length) 33 | # fit for one epoch 34 | model.fit_generator(generator_train, epochs=1, steps_per_epoch=steps_train, 35 | verbose=1, validation_data=generator_val, validation_steps=steps_val) 36 | # save model 37 | model.save('model_data/model_' + str(i) + '.h5') 38 | 39 | # Evaluate the model on validation data and ouput BLEU score 40 | # evaluate_model(model, X1val, X2val, tokenizer, max_length) 41 | -------------------------------------------------------------------------------- /image caption/train_val_data/demo: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /image caption/utils/model.py: -------------------------------------------------------------------------------- 1 | from numpy import argmax 2 | from keras.applications.inception_v3 import InceptionV3 3 | from keras.models import Model 4 | from keras.layers import Input 5 | from keras.layers import Dense 6 | from keras.layers import LSTM 7 | from keras.layers import Embedding 8 | from keras.layers import Dropout 9 | from keras.layers.merge import add 10 | from keras.preprocessing.sequence import pad_sequences 11 | from keras.preprocessing.image import load_img, img_to_array 12 | from nltk.translate.bleu_score import corpus_bleu 13 | 14 | from keras.models import Sequential 15 | from keras.layers import LSTM, Embedding, TimeDistributed, Dense, RepeatVector, Merge, Activation, Flatten 16 | from keras.optimizers import Adam, RMSprop 17 | from keras.layers.wrappers import Bidirectional 18 | 19 | 20 | # define the CNN model 21 | def defineCNNmodel(): 22 | model = InceptionV3() 23 | model.layers.pop() 24 | model = Model(inputs=model.inputs, outputs=model.layers[-1].output) 25 | #print(model.summary()) 26 | return model 27 | 28 | # define the RNN model 29 | def defineRNNmodel(vocab_size, max_len): 30 | embedding_size = 300 31 | # Input dimension is 2048 since we will feed it the encoded version of the image. 32 | image_model = Sequential([ 33 | Dense(embedding_size, input_shape=(2048,), activation='relu'), 34 | RepeatVector(max_len) 35 | ]) 36 | # Since we are going to predict the next word using the previous words(length of previous words changes with every iteration over the caption), we have to set return_sequences = True. 
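# Architecture note: image_model above projects the 2048-d CNN feature into the
# same 300-d embedding space as the words and repeats it max_len times, so both
# branches emit (max_len, 300) sequences that can be joined timestep by timestep.
# The Merge layer used in final_model below is Keras 1.x API that is no longer
# available in Keras 2; the copy of this file at the top level of "image caption"
# swaps it for a Concatenate call, and under Keras 2 this merge step would
# typically be expressed with the functional API rather than Sequential.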
37 | caption_model = Sequential([ 38 | Embedding(vocab_size, embedding_size, input_length=max_len), 39 | LSTM(256, return_sequences=True), 40 | TimeDistributed(Dense(300)) 41 | ]) 42 | # Merging the models and creating a softmax classifier 43 | final_model = Sequential([ 44 | Merge([image_model, caption_model], mode='concat', concat_axis=1), 45 | Bidirectional(LSTM(256, return_sequences=False)), 46 | Dense(vocab_size), 47 | Activation('softmax') 48 | ]) 49 | final_model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy']) 50 | final_model.summary() 51 | return final_model 52 | 53 | 54 | # map an integer to a word 55 | def word_for_id(integer, tokenizer): 56 | for word, index in tokenizer.word_index.items(): 57 | if index == integer: 58 | return word 59 | return None 60 | 61 | # generate a description for an image, given a pre-trained model and a tokenizer to map integer back to word 62 | def generate_desc(model, tokenizer, photo, max_length): 63 | # seed the generation process 64 | in_text = 'startseq' 65 | # iterate over the whole length of the sequence 66 | for i in range(max_length): 67 | # integer encode input sequence 68 | sequence = tokenizer.texts_to_sequences([in_text])[0] 69 | # pad input 70 | sequence = pad_sequences([sequence], maxlen=max_length) 71 | # predict next word 72 | yhat = model.predict([photo,sequence], verbose=0) 73 | # convert probability to integer 74 | yhat = argmax(yhat) 75 | # map integer to word 76 | word = word_for_id(yhat, tokenizer) 77 | # stop if we cannot map the word 78 | if word is None: 79 | break 80 | # append as input for generating the next word 81 | in_text += ' ' + word 82 | # stop if we predict the end of the sequence 83 | if word == 'endseq': 84 | break 85 | return in_text 86 | 87 | 88 | def evaluate_model(model, photos, descriptions, tokenizer, max_length): 89 | actual, predicted = list(), list() 90 | 91 | for key, desc_list in descriptions.items(): 92 | yhat = generate_desc(model, tokenizer, photos[key], max_length) 93 | references = [d.split() for d in desc_list] 94 | actual.append(references) 95 | predicted.append(yhat.split()) 96 | 97 | print('BLEU-1: %f' % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0))) 98 | print('BLEU-2: %f' % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0))) 99 | print('BLEU-3: %f' % corpus_bleu(actual, predicted, weights=(0.3, 0.3, 0.3, 0))) 100 | print('BLEU-4: %f' % corpus_bleu(actual, predicted, weights=(0.25, 0.25, 0.25, 0.25))) -------------------------------------------------------------------------------- /image caption/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import string 3 | from os import listdir 4 | from pickle import dump 5 | from keras.applications.inception_v3 import preprocess_input 6 | from keras.preprocessing.image import load_img, img_to_array 7 | 8 | #The function returns a dictionary of image identifier to image features. 
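# This utils copy mirrors the top-level preprocessing.py with two differences
# worth noting: it does not include that file's `from model import *`, so the
# defineCNNmodel() call inside extract_features() below has nothing to resolve
# against as written, and its preprocessData() looks for the images under
# 'Flicker8k_Dataset' rather than 'train_val_data/Flicker8k_Dataset'.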
9 | def extract_features(path): 10 | model = defineCNNmodel() 11 | # extract features from each photo 12 | features = dict() 13 | for name in listdir(path): 14 | # load an image from file 15 | filename = path + '/' + name 16 | image = load_img(filename, target_size=(299, 299)) 17 | # convert the image pixels to a numpy array 18 | image = img_to_array(image) 19 | # reshape data for the model 20 | image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2])) 21 | # prepare the image for the VGG model 22 | image = preprocess_input(image) 23 | # get features 24 | feature = model.predict(image, verbose=0) 25 | # get image id 26 | image_id = name.split('.')[0] 27 | # store feature 28 | features[image_id] = feature 29 | return features 30 | 31 | # extract descriptions for images 32 | def load_descriptions(filename): 33 | file = open(filename, 'r') 34 | doc = file.read() 35 | file.close() 36 | mapping = dict() 37 | # process lines by line 38 | for line in doc.split('\n'): 39 | # split line by white space 40 | tokens = line.split() 41 | if len(line) < 2: 42 | continue 43 | # take the first token as the image id, the rest as the description 44 | image_id, image_desc = tokens[0], tokens[1:] 45 | # remove filename from image id 46 | image_id = image_id.split('.')[0] 47 | # convert description tokens back to string 48 | image_desc = ' '.join(image_desc) 49 | # create the list if needed 50 | if image_id not in mapping: 51 | mapping[image_id] = list() 52 | # store description 53 | mapping[image_id].append(image_desc) 54 | return mapping 55 | 56 | def clean_descriptions(descriptions): 57 | # prepare translation table for removing punctuation 58 | table = str.maketrans('', '', string.punctuation) 59 | for key, desc_list in descriptions.items(): 60 | for i in range(len(desc_list)): 61 | desc = desc_list[i] 62 | # tokenize 63 | desc = desc.split() 64 | # convert to lower case 65 | desc = [word.lower() for word in desc] 66 | # remove punctuation from each token 67 | desc = [w.translate(table) for w in desc] 68 | # remove hanging 's' and 'a' 69 | desc = [word for word in desc if len(word)>1] 70 | # remove tokens with numbers in them 71 | desc = [word for word in desc if word.isalpha()] 72 | # store as string 73 | desc_list[i] = ' '.join(desc) 74 | 75 | # save descriptions to file, one per line 76 | def save_descriptions(descriptions, filename): 77 | lines = list() 78 | for key, desc_list in descriptions.items(): 79 | for desc in desc_list: 80 | lines.append(key + ' ' + desc) 81 | data = '\n'.join(lines) 82 | file = open(filename, 'w') 83 | file.write(data) 84 | file.close() 85 | 86 | def preprocessData(): 87 | # extract features from all images 88 | path = 'Flicker8k_Dataset' 89 | print('Generating image features...') 90 | features = extract_features(path) 91 | print('Completed. 
Saving now...') 92 | # save to file 93 | dump(features, open('model_data/features.pkl', 'wb')) 94 | print("Save Complete.") 95 | 96 | # load descriptions containing file and parse descriptions 97 | descriptions_path = 'train_val_data/Flickr8k.token.txt' 98 | 99 | descriptions = load_descriptions(descriptions_path) 100 | print('Loaded Descriptions: %d ' % len(descriptions)) 101 | 102 | # clean descriptions 103 | clean_descriptions(descriptions) 104 | 105 | # save descriptions 106 | save_descriptions(descriptions, 'model_data/descriptions.txt') 107 | 108 | 109 | # Now descriptions.txt is of form : 110 | # Example : 2252123185_487f21e336 stadium full of people watch game -------------------------------------------------------------------------------- /images/image3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/images/image3.png -------------------------------------------------------------------------------- /images/image4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/images/image4.png -------------------------------------------------------------------------------- /images/image5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/images/image5.png -------------------------------------------------------------------------------- /images/image6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/images/image6.png -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/images/logo.png -------------------------------------------------------------------------------- /opencv.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/opencv.jpg -------------------------------------------------------------------------------- /protos/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/protos/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /protos/anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/grid_anchor_generator.proto"; 6 | import "object_detection/protos/ssd_anchor_generator.proto"; 7 | 8 | // Configuration proto for the anchor generator to use in the object detection 9 | // pipeline. See core/anchor_generator.py for details. 10 | message AnchorGenerator { 11 | oneof anchor_generator_oneof { 12 | GridAnchorGenerator grid_anchor_generator = 1; 13 | SsdAnchorGenerator ssd_anchor_generator = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /protos/anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.grid_anchor_generator_pb2 17 | import object_detection.protos.ssd_anchor_generator_pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/anchor_generator.proto', 22 | package='object_detection.protos', 23 | serialized_pb=_b('\n.object_detection/protos/anchor_generator.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/grid_anchor_generator.proto\x1a\x32object_detection/protos/ssd_anchor_generator.proto\"\xc7\x01\n\x0f\x41nchorGenerator\x12M\n\x15grid_anchor_generator\x18\x01 \x01(\x0b\x32,.object_detection.protos.GridAnchorGeneratorH\x00\x12K\n\x14ssd_anchor_generator\x18\x02 \x01(\x0b\x32+.object_detection.protos.SsdAnchorGeneratorH\x00\x42\x18\n\x16\x61nchor_generator_oneof') 24 | , 25 | dependencies=[object_detection.protos.grid_anchor_generator_pb2.DESCRIPTOR,object_detection.protos.ssd_anchor_generator_pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _ANCHORGENERATOR = _descriptor.Descriptor( 32 | name='AnchorGenerator', 33 | full_name='object_detection.protos.AnchorGenerator', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='grid_anchor_generator', full_name='object_detection.protos.AnchorGenerator.grid_anchor_generator', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='ssd_anchor_generator', full_name='object_detection.protos.AnchorGenerator.ssd_anchor_generator', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | 
has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | extension_ranges=[], 61 | oneofs=[ 62 | _descriptor.OneofDescriptor( 63 | name='anchor_generator_oneof', full_name='object_detection.protos.AnchorGenerator.anchor_generator_oneof', 64 | index=0, containing_type=None, fields=[]), 65 | ], 66 | serialized_start=181, 67 | serialized_end=380, 68 | ) 69 | 70 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].message_type = object_detection.protos.grid_anchor_generator_pb2._GRIDANCHORGENERATOR 71 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].message_type = object_detection.protos.ssd_anchor_generator_pb2._SSDANCHORGENERATOR 72 | _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append( 73 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator']) 74 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'] 75 | _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append( 76 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator']) 77 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'] 78 | DESCRIPTOR.message_types_by_name['AnchorGenerator'] = _ANCHORGENERATOR 79 | 80 | AnchorGenerator = _reflection.GeneratedProtocolMessageType('AnchorGenerator', (_message.Message,), dict( 81 | DESCRIPTOR = _ANCHORGENERATOR, 82 | __module__ = 'object_detection.protos.anchor_generator_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.AnchorGenerator) 84 | )) 85 | _sym_db.RegisterMessage(AnchorGenerator) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | -------------------------------------------------------------------------------- /protos/argmax_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for ArgMaxMatcher. See 6 | // matchers/argmax_matcher.py for details. 7 | message ArgMaxMatcher { 8 | // Threshold for positive matches. 9 | optional float matched_threshold = 1 [default = 0.5]; 10 | 11 | // Threshold for negative matches. 12 | optional float unmatched_threshold = 2 [default = 0.5]; 13 | 14 | // Whether to construct ArgMaxMatcher without thresholds. 15 | optional bool ignore_thresholds = 3 [default = false]; 16 | 17 | // If True then negative matches are the ones below the unmatched_threshold, 18 | // whereas ignored matches are in between the matched and umatched 19 | // threshold. If False, then negative matches are in between the matched 20 | // and unmatched threshold, and everything lower than unmatched is ignored. 21 | optional bool negatives_lower_than_unmatched = 4 [default = true]; 22 | 23 | // Whether to ensure each row is matched to at least one column. 24 | optional bool force_match_for_each_row = 5 [default = false]; 25 | } 26 | -------------------------------------------------------------------------------- /protos/argmax_matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/argmax_matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/argmax_matcher.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n,object_detection/protos/argmax_matcher.proto\x12\x17object_detection.protos\"\xca\x01\n\rArgMaxMatcher\x12\x1e\n\x11matched_threshold\x18\x01 \x01(\x02:\x03\x30.5\x12 \n\x13unmatched_threshold\x18\x02 \x01(\x02:\x03\x30.5\x12 \n\x11ignore_thresholds\x18\x03 \x01(\x08:\x05\x66\x61lse\x12,\n\x1enegatives_lower_than_unmatched\x18\x04 \x01(\x08:\x04true\x12\'\n\x18\x66orce_match_for_each_row\x18\x05 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _ARGMAXMATCHER = _descriptor.Descriptor( 29 | name='ArgMaxMatcher', 30 | full_name='object_detection.protos.ArgMaxMatcher', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='matched_threshold', full_name='object_detection.protos.ArgMaxMatcher.matched_threshold', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=0.5, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='unmatched_threshold', full_name='object_detection.protos.ArgMaxMatcher.unmatched_threshold', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=0.5, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='ignore_thresholds', full_name='object_detection.protos.ArgMaxMatcher.ignore_thresholds', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=True, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='negatives_lower_than_unmatched', full_name='object_detection.protos.ArgMaxMatcher.negatives_lower_than_unmatched', index=3, 58 | number=4, type=8, cpp_type=7, label=1, 59 | has_default_value=True, default_value=True, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='force_match_for_each_row', full_name='object_detection.protos.ArgMaxMatcher.force_match_for_each_row', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=False, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | extension_ranges=[], 79 | oneofs=[ 80 | ], 81 | serialized_start=74, 82 | serialized_end=276, 83 | ) 84 | 85 | 
DESCRIPTOR.message_types_by_name['ArgMaxMatcher'] = _ARGMAXMATCHER 86 | 87 | ArgMaxMatcher = _reflection.GeneratedProtocolMessageType('ArgMaxMatcher', (_message.Message,), dict( 88 | DESCRIPTOR = _ARGMAXMATCHER, 89 | __module__ = 'object_detection.protos.argmax_matcher_pb2' 90 | # @@protoc_insertion_point(class_scope:object_detection.protos.ArgMaxMatcher) 91 | )) 92 | _sym_db.RegisterMessage(ArgMaxMatcher) 93 | 94 | 95 | # @@protoc_insertion_point(module_scope) 96 | -------------------------------------------------------------------------------- /protos/bipartite_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for bipartite matcher. See 6 | // matchers/bipartite_matcher.py for details. 7 | message BipartiteMatcher { 8 | } 9 | -------------------------------------------------------------------------------- /protos/bipartite_matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/bipartite_matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/bipartite_matcher.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n/object_detection/protos/bipartite_matcher.proto\x12\x17object_detection.protos\"\x12\n\x10\x42ipartiteMatcher') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _BIPARTITEMATCHER = _descriptor.Descriptor( 29 | name='BipartiteMatcher', 30 | full_name='object_detection.protos.BipartiteMatcher', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | ], 36 | extensions=[ 37 | ], 38 | nested_types=[], 39 | enum_types=[ 40 | ], 41 | options=None, 42 | is_extendable=False, 43 | extension_ranges=[], 44 | oneofs=[ 45 | ], 46 | serialized_start=76, 47 | serialized_end=94, 48 | ) 49 | 50 | DESCRIPTOR.message_types_by_name['BipartiteMatcher'] = _BIPARTITEMATCHER 51 | 52 | BipartiteMatcher = _reflection.GeneratedProtocolMessageType('BipartiteMatcher', (_message.Message,), dict( 53 | DESCRIPTOR = _BIPARTITEMATCHER, 54 | __module__ = 'object_detection.protos.bipartite_matcher_pb2' 55 | # @@protoc_insertion_point(class_scope:object_detection.protos.BipartiteMatcher) 56 | )) 57 | _sym_db.RegisterMessage(BipartiteMatcher) 58 | 59 | 60 | # @@protoc_insertion_point(module_scope) 61 | -------------------------------------------------------------------------------- /protos/box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/faster_rcnn_box_coder.proto"; 6 | import "object_detection/protos/mean_stddev_box_coder.proto"; 7 | import "object_detection/protos/square_box_coder.proto"; 8 | 9 | // Configuration proto for the box coder to be used in the object detection 10 | // 
pipeline. See core/box_coder.py for details. 11 | message BoxCoder { 12 | oneof box_coder_oneof { 13 | FasterRcnnBoxCoder faster_rcnn_box_coder = 1; 14 | MeanStddevBoxCoder mean_stddev_box_coder = 2; 15 | SquareBoxCoder square_box_coder = 3; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /protos/box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.faster_rcnn_box_coder_pb2 17 | import object_detection.protos.mean_stddev_box_coder_pb2 18 | import object_detection.protos.square_box_coder_pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='object_detection/protos/box_coder.proto', 23 | package='object_detection.protos', 24 | serialized_pb=_b('\n\'object_detection/protos/box_coder.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/faster_rcnn_box_coder.proto\x1a\x33object_detection/protos/mean_stddev_box_coder.proto\x1a.object_detection/protos/square_box_coder.proto\"\xfe\x01\n\x08\x42oxCoder\x12L\n\x15\x66\x61ster_rcnn_box_coder\x18\x01 \x01(\x0b\x32+.object_detection.protos.FasterRcnnBoxCoderH\x00\x12L\n\x15mean_stddev_box_coder\x18\x02 \x01(\x0b\x32+.object_detection.protos.MeanStddevBoxCoderH\x00\x12\x43\n\x10square_box_coder\x18\x03 \x01(\x0b\x32\'.object_detection.protos.SquareBoxCoderH\x00\x42\x11\n\x0f\x62ox_coder_oneof') 25 | , 26 | dependencies=[object_detection.protos.faster_rcnn_box_coder_pb2.DESCRIPTOR,object_detection.protos.mean_stddev_box_coder_pb2.DESCRIPTOR,object_detection.protos.square_box_coder_pb2.DESCRIPTOR,]) 27 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 28 | 29 | 30 | 31 | 32 | _BOXCODER = _descriptor.Descriptor( 33 | name='BoxCoder', 34 | full_name='object_detection.protos.BoxCoder', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='faster_rcnn_box_coder', full_name='object_detection.protos.BoxCoder.faster_rcnn_box_coder', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='mean_stddev_box_coder', full_name='object_detection.protos.BoxCoder.mean_stddev_box_coder', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | _descriptor.FieldDescriptor( 54 | name='square_box_coder', full_name='object_detection.protos.BoxCoder.square_box_coder', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, 
extension_scope=None, 59 | options=None), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | extension_ranges=[], 69 | oneofs=[ 70 | _descriptor.OneofDescriptor( 71 | name='box_coder_oneof', full_name='object_detection.protos.BoxCoder.box_coder_oneof', 72 | index=0, containing_type=None, fields=[]), 73 | ], 74 | serialized_start=223, 75 | serialized_end=477, 76 | ) 77 | 78 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].message_type = object_detection.protos.faster_rcnn_box_coder_pb2._FASTERRCNNBOXCODER 79 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].message_type = object_detection.protos.mean_stddev_box_coder_pb2._MEANSTDDEVBOXCODER 80 | _BOXCODER.fields_by_name['square_box_coder'].message_type = object_detection.protos.square_box_coder_pb2._SQUAREBOXCODER 81 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 82 | _BOXCODER.fields_by_name['faster_rcnn_box_coder']) 83 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 84 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 85 | _BOXCODER.fields_by_name['mean_stddev_box_coder']) 86 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 87 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 88 | _BOXCODER.fields_by_name['square_box_coder']) 89 | _BOXCODER.fields_by_name['square_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 90 | DESCRIPTOR.message_types_by_name['BoxCoder'] = _BOXCODER 91 | 92 | BoxCoder = _reflection.GeneratedProtocolMessageType('BoxCoder', (_message.Message,), dict( 93 | DESCRIPTOR = _BOXCODER, 94 | __module__ = 'object_detection.protos.box_coder_pb2' 95 | # @@protoc_insertion_point(class_scope:object_detection.protos.BoxCoder) 96 | )) 97 | _sym_db.RegisterMessage(BoxCoder) 98 | 99 | 100 | # @@protoc_insertion_point(module_scope) 101 | -------------------------------------------------------------------------------- /protos/box_predictor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/hyperparams.proto"; 6 | 7 | 8 | // Configuration proto for box predictor. See core/box_predictor.py for details. 9 | message BoxPredictor { 10 | oneof box_predictor_oneof { 11 | ConvolutionalBoxPredictor convolutional_box_predictor = 1; 12 | MaskRCNNBoxPredictor mask_rcnn_box_predictor = 2; 13 | RfcnBoxPredictor rfcn_box_predictor = 3; 14 | } 15 | } 16 | 17 | // Configuration proto for Convolutional box predictor. 18 | message ConvolutionalBoxPredictor { 19 | // Hyperparameters for convolution ops used in the box predictor. 20 | optional Hyperparams conv_hyperparams = 1; 21 | 22 | // Minumum feature depth prior to predicting box encodings and class 23 | // predictions. 24 | optional int32 min_depth = 2 [default = 0]; 25 | 26 | // Maximum feature depth prior to predicting box encodings and class 27 | // predictions. If max_depth is set to 0, no additional feature map will be 28 | // inserted before location and class predictions. 29 | optional int32 max_depth = 3 [default = 0]; 30 | 31 | // Number of the additional conv layers before the predictor. 32 | optional int32 num_layers_before_predictor = 4 [default = 0]; 33 | 34 | // Whether to use dropout for class prediction. 
35 | optional bool use_dropout = 5 [default = true]; 36 | 37 | // Keep probability for dropout 38 | optional float dropout_keep_probability = 6 [default = 0.8]; 39 | 40 | // Size of final convolution kernel. If the spatial resolution of the feature 41 | // map is smaller than the kernel size, then the kernel size is set to 42 | // min(feature_width, feature_height). 43 | optional int32 kernel_size = 7 [default = 1]; 44 | 45 | // Size of the encoding for boxes. 46 | optional int32 box_code_size = 8 [default = 4]; 47 | 48 | // Whether to apply sigmoid to the output of class predictions. 49 | // TODO: Do we need this since we have a post processing module.? 50 | optional bool apply_sigmoid_to_scores = 9 [default = false]; 51 | } 52 | 53 | message MaskRCNNBoxPredictor { 54 | // Hyperparameters for fully connected ops used in the box predictor. 55 | optional Hyperparams fc_hyperparams = 1; 56 | 57 | // Whether to use dropout op prior to the both box and class predictions. 58 | optional bool use_dropout = 2 [default= false]; 59 | 60 | // Keep probability for dropout. This is only used if use_dropout is true. 61 | optional float dropout_keep_probability = 3 [default = 0.5]; 62 | 63 | // Size of the encoding for the boxes. 64 | optional int32 box_code_size = 4 [default = 4]; 65 | 66 | // Hyperparameters for convolution ops used in the box predictor. 67 | optional Hyperparams conv_hyperparams = 5; 68 | 69 | // Whether to predict instance masks inside detection boxes. 70 | optional bool predict_instance_masks = 6 [default = false]; 71 | 72 | // The depth for the first conv2d_transpose op applied to the 73 | // image_features in the mask prediciton branch 74 | optional int32 mask_prediction_conv_depth = 7 [default = 256]; 75 | 76 | // Whether to predict keypoints inside detection boxes. 77 | optional bool predict_keypoints = 8 [default = false]; 78 | } 79 | 80 | message RfcnBoxPredictor { 81 | // Hyperparameters for convolution ops used in the box predictor. 82 | optional Hyperparams conv_hyperparams = 1; 83 | 84 | // Bin sizes for RFCN crops. 85 | optional int32 num_spatial_bins_height = 2 [default = 3]; 86 | 87 | optional int32 num_spatial_bins_width = 3 [default = 3]; 88 | 89 | // Target depth to reduce the input image features to. 90 | optional int32 depth = 4 [default=1024]; 91 | 92 | // Size of the encoding for the boxes. 93 | optional int32 box_code_size = 5 [default = 4]; 94 | 95 | // Size to resize the rfcn crops to. 96 | optional int32 crop_height = 6 [default= 12]; 97 | 98 | optional int32 crop_width = 7 [default=12]; 99 | } 100 | -------------------------------------------------------------------------------- /protos/eval.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring DetectionModel evaluation jobs (eval.py). 6 | message EvalConfig { 7 | // Number of visualization images to generate. 8 | optional uint32 num_visualizations = 1 [default=10]; 9 | 10 | // Number of examples to process of evaluation. 11 | optional uint32 num_examples = 2 [default=5000]; 12 | 13 | // How often to run evaluation. 14 | optional uint32 eval_interval_secs = 3 [default=300]; 15 | 16 | // Maximum number of times to run evaluation. If set to 0, will run forever. 17 | optional uint32 max_evals = 4 [default=0]; 18 | 19 | // Whether the TensorFlow graph used for evaluation should be saved to disk. 
20 | optional bool save_graph = 5 [default=false]; 21 | 22 | // Path to directory to store visualizations in. If empty, visualization 23 | // images are not exported (only shown on Tensorboard). 24 | optional string visualization_export_dir = 6 [default=""]; 25 | 26 | // BNS name of the TensorFlow master. 27 | optional string eval_master = 7 [default=""]; 28 | 29 | // Type of metrics to use for evaluation. Currently supports only Pascal VOC 30 | // detection metrics. 31 | optional string metrics_set = 8 [default="pascal_voc_metrics"]; 32 | 33 | // Path to export detections to COCO compatible JSON format. 34 | optional string export_path = 9 [default='']; 35 | 36 | // Option to not read groundtruth labels and only export detections to 37 | // COCO-compatible JSON file. 38 | optional bool ignore_groundtruth = 10 [default=false]; 39 | 40 | // Use exponential moving averages of variables for evaluation. 41 | // TODO: When this is false make sure the model is constructed 42 | // without moving averages in restore_fn. 43 | optional bool use_moving_averages = 11 [default=false]; 44 | 45 | // Whether to evaluate instance masks. 46 | optional bool eval_instance_masks = 12 [default=false]; 47 | } 48 | -------------------------------------------------------------------------------- /protos/faster_rcnn_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for FasterRCNNBoxCoder. See 6 | // box_coders/faster_rcnn_box_coder.py for details. 7 | message FasterRcnnBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box height. 13 | optional float height_scale = 3 [default = 5.0]; 14 | 15 | // Scale factor for anchor encoded box width. 16 | optional float width_scale = 4 [default = 5.0]; 17 | } 18 | -------------------------------------------------------------------------------- /protos/faster_rcnn_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/faster_rcnn_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/faster_rcnn_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n3object_detection/protos/faster_rcnn_box_coder.proto\x12\x17object_detection.protos\"o\n\x12\x46\x61sterRcnnBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0cheight_scale\x18\x03 \x01(\x02:\x01\x35\x12\x16\n\x0bwidth_scale\x18\x04 \x01(\x02:\x01\x35') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _FASTERRCNNBOXCODER = _descriptor.Descriptor( 29 | name='FasterRcnnBoxCoder', 30 | full_name='object_detection.protos.FasterRcnnBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='y_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.y_scale', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=10, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='x_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.x_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=10, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='height_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.height_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=5, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='width_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.width_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=True, default_value=5, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | ], 64 | extensions=[ 65 | ], 66 | nested_types=[], 67 | enum_types=[ 68 | ], 69 | options=None, 70 | is_extendable=False, 71 | extension_ranges=[], 72 | oneofs=[ 73 | ], 74 | serialized_start=80, 75 | serialized_end=191, 76 | ) 77 | 78 | DESCRIPTOR.message_types_by_name['FasterRcnnBoxCoder'] = _FASTERRCNNBOXCODER 79 | 80 | FasterRcnnBoxCoder = _reflection.GeneratedProtocolMessageType('FasterRcnnBoxCoder', (_message.Message,), dict( 81 | DESCRIPTOR = _FASTERRCNNBOXCODER, 82 | __module__ = 'object_detection.protos.faster_rcnn_box_coder_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.FasterRcnnBoxCoder) 84 | )) 85 | _sym_db.RegisterMessage(FasterRcnnBoxCoder) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | 
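
The generated *_pb2.py modules in protos/ are normally not called directly; the TensorFlow Object Detection API fills them from the text-format blocks of a pipeline .config file. Below is a minimal usage sketch, assuming the generated module is importable as protos.faster_rcnn_box_coder_pb2 (the import path depends on how this repository is placed on PYTHONPATH; upstream TensorFlow imports it as object_detection.protos.faster_rcnn_box_coder_pb2) and assuming a protobuf library version compatible with this older generated code:

from google.protobuf import text_format

from protos import faster_rcnn_box_coder_pb2

# Parse a text-format block like the ones found in pipeline .config files.
config_text = """
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
"""
box_coder = faster_rcnn_box_coder_pb2.FasterRcnnBoxCoder()
text_format.Parse(config_text, box_coder)
print(box_coder.height_scale)  # 5.0

# Fields left unset fall back to the defaults declared in faster_rcnn_box_coder.proto.
empty = faster_rcnn_box_coder_pb2.FasterRcnnBoxCoder()
print(empty.y_scale, empty.width_scale)  # 10.0 5.0

Every other *_pb2.py module in this directory follows the same pattern, and a full training setup is usually parsed in one step into pipeline_pb2.TrainEvalPipelineConfig, whose fields (model, train_config, eval_config and the input readers) are defined in pipeline.proto further below.
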
-------------------------------------------------------------------------------- /protos/grid_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for GridAnchorGenerator. See 6 | // anchor_generators/grid_anchor_generator.py for details. 7 | message GridAnchorGenerator { 8 | // Anchor height in pixels. 9 | optional int32 height = 1 [default = 256]; 10 | 11 | // Anchor width in pixels. 12 | optional int32 width = 2 [default = 256]; 13 | 14 | // Anchor stride in height dimension in pixels. 15 | optional int32 height_stride = 3 [default = 16]; 16 | 17 | // Anchor stride in width dimension in pixels. 18 | optional int32 width_stride = 4 [default = 16]; 19 | 20 | // Anchor height offset in pixels. 21 | optional int32 height_offset = 5 [default = 0]; 22 | 23 | // Anchor width offset in pixels. 24 | optional int32 width_offset = 6 [default = 0]; 25 | 26 | // At any given location, len(scales) * len(aspect_ratios) anchors are 27 | // generated with all possible combinations of scales and aspect ratios. 28 | 29 | // List of scales for the anchors. 30 | repeated float scales = 7; 31 | 32 | // List of aspect ratios for the anchors. 33 | repeated float aspect_ratios = 8; 34 | } 35 | -------------------------------------------------------------------------------- /protos/hyperparams.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for the convolution op hyperparameters to use in the 6 | // object detection pipeline. 7 | message Hyperparams { 8 | 9 | // Operations affected by hyperparameters. 10 | enum Op { 11 | // Convolution, Separable Convolution, Convolution transpose. 12 | CONV = 1; 13 | 14 | // Fully connected 15 | FC = 2; 16 | } 17 | optional Op op = 1 [default = CONV]; 18 | 19 | // Regularizer for the weights of the convolution op. 20 | optional Regularizer regularizer = 2; 21 | 22 | // Initializer for the weights of the convolution op. 23 | optional Initializer initializer = 3; 24 | 25 | // Type of activation to apply after convolution. 26 | enum Activation { 27 | // Use None (no activation) 28 | NONE = 0; 29 | 30 | // Use tf.nn.relu 31 | RELU = 1; 32 | 33 | // Use tf.nn.relu6 34 | RELU_6 = 2; 35 | } 36 | optional Activation activation = 4 [default = RELU]; 37 | 38 | // BatchNorm hyperparameters. If this parameter is NOT set then BatchNorm is 39 | // not applied! 40 | optional BatchNorm batch_norm = 5; 41 | } 42 | 43 | // Proto with one-of field for regularizers. 44 | message Regularizer { 45 | oneof regularizer_oneof { 46 | L1Regularizer l1_regularizer = 1; 47 | L2Regularizer l2_regularizer = 2; 48 | } 49 | } 50 | 51 | // Configuration proto for L1 Regularizer. 52 | // See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l1_regularizer 53 | message L1Regularizer { 54 | optional float weight = 1 [default = 1.0]; 55 | } 56 | 57 | // Configuration proto for L2 Regularizer. 58 | // See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l2_regularizer 59 | message L2Regularizer { 60 | optional float weight = 1 [default = 1.0]; 61 | } 62 | 63 | // Proto with one-of field for initializers. 
64 | message Initializer { 65 | oneof initializer_oneof { 66 | TruncatedNormalInitializer truncated_normal_initializer = 1; 67 | VarianceScalingInitializer variance_scaling_initializer = 2; 68 | } 69 | } 70 | 71 | // Configuration proto for truncated normal initializer. See 72 | // https://www.tensorflow.org/api_docs/python/tf/truncated_normal_initializer 73 | message TruncatedNormalInitializer { 74 | optional float mean = 1 [default = 0.0]; 75 | optional float stddev = 2 [default = 1.0]; 76 | } 77 | 78 | // Configuration proto for variance scaling initializer. See 79 | // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/ 80 | // variance_scaling_initializer 81 | message VarianceScalingInitializer { 82 | optional float factor = 1 [default = 2.0]; 83 | optional bool uniform = 2 [default = false]; 84 | enum Mode { 85 | FAN_IN = 0; 86 | FAN_OUT = 1; 87 | FAN_AVG = 2; 88 | } 89 | optional Mode mode = 3 [default = FAN_IN]; 90 | } 91 | 92 | // Configuration proto for batch norm to apply after convolution op. See 93 | // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm 94 | message BatchNorm { 95 | optional float decay = 1 [default = 0.999]; 96 | optional bool center = 2 [default = true]; 97 | optional bool scale = 3 [default = false]; 98 | optional float epsilon = 4 [default = 0.001]; 99 | // Whether to train the batch norm variables. If this is set to false during 100 | // training, the current value of the batch_norm variables are used for 101 | // forward pass but they are never updated. 102 | optional bool train = 5 [default = true]; 103 | } 104 | -------------------------------------------------------------------------------- /protos/image_resizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for image resizing operations. 6 | // See builders/image_resizer_builder.py for details. 7 | message ImageResizer { 8 | oneof image_resizer_oneof { 9 | KeepAspectRatioResizer keep_aspect_ratio_resizer = 1; 10 | FixedShapeResizer fixed_shape_resizer = 2; 11 | } 12 | } 13 | 14 | 15 | // Configuration proto for image resizer that keeps aspect ratio. 16 | message KeepAspectRatioResizer { 17 | // Desired size of the smaller image dimension in pixels. 18 | optional int32 min_dimension = 1 [default = 600]; 19 | 20 | // Desired size of the larger image dimension in pixels. 21 | optional int32 max_dimension = 2 [default = 1024]; 22 | } 23 | 24 | 25 | // Configuration proto for image resizer that resizes to a fixed shape. 26 | message FixedShapeResizer { 27 | // Desired height of image in pixels. 28 | optional int32 height = 1 [default = 300]; 29 | 30 | // Desired width of image in pixels. 31 | optional int32 width = 2 [default = 300]; 32 | } 33 | -------------------------------------------------------------------------------- /protos/input_reader.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for defining input readers that generate Object Detection 6 | // Examples from input sources. Input readers are expected to generate a 7 | // dictionary of tensors, with the following fields populated: 8 | // 9 | // 'image': an [image_height, image_width, channels] image tensor that detection 10 | // will be run on. 
11 | // 'groundtruth_classes': a [num_boxes] int32 tensor storing the class 12 | // labels of detected boxes in the image. 13 | // 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of 14 | // detected boxes in the image. 15 | // 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height, 16 | // image_width] float tensor storing binary mask of the objects in boxes. 17 | 18 | message InputReader { 19 | // Path to StringIntLabelMap pbtxt file specifying the mapping from string 20 | // labels to integer ids. 21 | optional string label_map_path = 1 [default=""]; 22 | 23 | // Whether data should be processed in the order they are read in, or 24 | // shuffled randomly. 25 | optional bool shuffle = 2 [default=true]; 26 | 27 | // Maximum number of records to keep in reader queue. 28 | optional uint32 queue_capacity = 3 [default=2000]; 29 | 30 | // Minimum number of records to keep in reader queue. A large value is needed 31 | // to generate a good random shuffle. 32 | optional uint32 min_after_dequeue = 4 [default=1000]; 33 | 34 | // The number of times a data source is read. If set to zero, the data source 35 | // will be reused indefinitely. 36 | optional uint32 num_epochs = 5 [default=0]; 37 | 38 | // Number of reader instances to create. 39 | optional uint32 num_readers = 6 [default=8]; 40 | 41 | // Whether to load groundtruth instance masks. 42 | optional bool load_instance_masks = 7 [default = false]; 43 | 44 | oneof input_reader { 45 | TFRecordInputReader tf_record_input_reader = 8; 46 | ExternalInputReader external_input_reader = 9; 47 | } 48 | } 49 | 50 | // An input reader that reads TF Example protos from local TFRecord files. 51 | message TFRecordInputReader { 52 | // Path to TFRecordFile. 53 | optional string input_path = 1 [default=""]; 54 | } 55 | 56 | // An externally defined input reader. Users may define an extension to this 57 | // proto to interface their own input readers. 58 | message ExternalInputReader { 59 | extensions 1 to 999; 60 | } 61 | -------------------------------------------------------------------------------- /protos/losses.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring the localization loss, classification loss and hard 6 | // example miner used for training object detection models. See core/losses.py 7 | // for details 8 | message Loss { 9 | // Localization loss to use. 10 | optional LocalizationLoss localization_loss = 1; 11 | 12 | // Classification loss to use. 13 | optional ClassificationLoss classification_loss = 2; 14 | 15 | // If not left to default, applies hard example mining. 16 | optional HardExampleMiner hard_example_miner = 3; 17 | 18 | // Classification loss weight. 19 | optional float classification_weight = 4 [default=1.0]; 20 | 21 | // Localization loss weight. 22 | optional float localization_weight = 5 [default=1.0]; 23 | } 24 | 25 | // Configuration for bounding box localization loss function. 26 | message LocalizationLoss { 27 | oneof localization_loss { 28 | WeightedL2LocalizationLoss weighted_l2 = 1; 29 | WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2; 30 | WeightedIOULocalizationLoss weighted_iou = 3; 31 | } 32 | } 33 | 34 | // L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2 35 | message WeightedL2LocalizationLoss { 36 | // Output loss per anchor. 
37 | optional bool anchorwise_output = 1 [default=false]; 38 | } 39 | 40 | // SmoothL1 (Huber) location loss: .5 * x ^ 2 if |x| < 1 else |x| - .5 41 | message WeightedSmoothL1LocalizationLoss { 42 | // Output loss per anchor. 43 | optional bool anchorwise_output = 1 [default=false]; 44 | } 45 | 46 | // Intersection over union location loss: 1 - IOU 47 | message WeightedIOULocalizationLoss { 48 | } 49 | 50 | // Configuration for class prediction loss function. 51 | message ClassificationLoss { 52 | oneof classification_loss { 53 | WeightedSigmoidClassificationLoss weighted_sigmoid = 1; 54 | WeightedSoftmaxClassificationLoss weighted_softmax = 2; 55 | BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3; 56 | } 57 | } 58 | 59 | // Classification loss using a sigmoid function over class predictions. 60 | message WeightedSigmoidClassificationLoss { 61 | // Output loss per anchor. 62 | optional bool anchorwise_output = 1 [default=false]; 63 | } 64 | 65 | // Classification loss using a softmax function over class predictions. 66 | message WeightedSoftmaxClassificationLoss { 67 | // Output loss per anchor. 68 | optional bool anchorwise_output = 1 [default=false]; 69 | } 70 | 71 | // Classification loss using a sigmoid function over the class prediction with 72 | // the highest prediction score. 73 | message BootstrappedSigmoidClassificationLoss { 74 | // Interpolation weight between 0 and 1. 75 | optional float alpha = 1; 76 | 77 | // Whether hard boot strapping should be used or not. If true, will only use 78 | // one class favored by model. Othewise, will use all predicted class 79 | // probabilities. 80 | optional bool hard_bootstrap = 2 [default=false]; 81 | 82 | // Output loss per anchor. 83 | optional bool anchorwise_output = 3 [default=false]; 84 | } 85 | 86 | // Configuation for hard example miner. 87 | message HardExampleMiner { 88 | // Maximum number of hard examples to be selected per image (prior to 89 | // enforcing max negative to positive ratio constraint). If set to 0, 90 | // all examples obtained after NMS are considered. 91 | optional int32 num_hard_examples = 1 [default=64]; 92 | 93 | // Minimum intersection over union for an example to be discarded during NMS. 94 | optional float iou_threshold = 2 [default=0.7]; 95 | 96 | // Whether to use classification losses ('cls', default), localization losses 97 | // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and 98 | // loc_loss_weight are used to compute weighted sum of the two losses. 99 | enum LossType { 100 | BOTH = 0; 101 | CLASSIFICATION = 1; 102 | LOCALIZATION = 2; 103 | } 104 | optional LossType loss_type = 3 [default=BOTH]; 105 | 106 | // Maximum number of negatives to retain for each positive anchor. If 107 | // num_negatives_per_positive is 0 no prespecified negative:positive ratio is 108 | // enforced. 109 | optional int32 max_negatives_per_positive = 4 [default=0]; 110 | 111 | // Minimum number of negative anchors to sample for a given image. Setting 112 | // this to a positive number samples negatives in an image without any 113 | // positive anchors and thus not bias the model towards having at least one 114 | // detection per image. 
115 | optional int32 min_negatives_per_image = 5 [default=0]; 116 | } 117 | -------------------------------------------------------------------------------- /protos/matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/argmax_matcher.proto"; 6 | import "object_detection/protos/bipartite_matcher.proto"; 7 | 8 | // Configuration proto for the matcher to be used in the object detection 9 | // pipeline. See core/matcher.py for details. 10 | message Matcher { 11 | oneof matcher_oneof { 12 | ArgMaxMatcher argmax_matcher = 1; 13 | BipartiteMatcher bipartite_matcher = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /protos/matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.argmax_matcher_pb2 17 | import object_detection.protos.bipartite_matcher_pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/matcher.proto', 22 | package='object_detection.protos', 23 | serialized_pb=_b('\n%object_detection/protos/matcher.proto\x12\x17object_detection.protos\x1a,object_detection/protos/argmax_matcher.proto\x1a/object_detection/protos/bipartite_matcher.proto\"\xa4\x01\n\x07Matcher\x12@\n\x0e\x61rgmax_matcher\x18\x01 \x01(\x0b\x32&.object_detection.protos.ArgMaxMatcherH\x00\x12\x46\n\x11\x62ipartite_matcher\x18\x02 \x01(\x0b\x32).object_detection.protos.BipartiteMatcherH\x00\x42\x0f\n\rmatcher_oneof') 24 | , 25 | dependencies=[object_detection.protos.argmax_matcher_pb2.DESCRIPTOR,object_detection.protos.bipartite_matcher_pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _MATCHER = _descriptor.Descriptor( 32 | name='Matcher', 33 | full_name='object_detection.protos.Matcher', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='argmax_matcher', full_name='object_detection.protos.Matcher.argmax_matcher', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='bipartite_matcher', full_name='object_detection.protos.Matcher.bipartite_matcher', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | extension_ranges=[], 61 | oneofs=[ 62 | _descriptor.OneofDescriptor( 63 | 
name='matcher_oneof', full_name='object_detection.protos.Matcher.matcher_oneof', 64 | index=0, containing_type=None, fields=[]), 65 | ], 66 | serialized_start=162, 67 | serialized_end=326, 68 | ) 69 | 70 | _MATCHER.fields_by_name['argmax_matcher'].message_type = object_detection.protos.argmax_matcher_pb2._ARGMAXMATCHER 71 | _MATCHER.fields_by_name['bipartite_matcher'].message_type = object_detection.protos.bipartite_matcher_pb2._BIPARTITEMATCHER 72 | _MATCHER.oneofs_by_name['matcher_oneof'].fields.append( 73 | _MATCHER.fields_by_name['argmax_matcher']) 74 | _MATCHER.fields_by_name['argmax_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof'] 75 | _MATCHER.oneofs_by_name['matcher_oneof'].fields.append( 76 | _MATCHER.fields_by_name['bipartite_matcher']) 77 | _MATCHER.fields_by_name['bipartite_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof'] 78 | DESCRIPTOR.message_types_by_name['Matcher'] = _MATCHER 79 | 80 | Matcher = _reflection.GeneratedProtocolMessageType('Matcher', (_message.Message,), dict( 81 | DESCRIPTOR = _MATCHER, 82 | __module__ = 'object_detection.protos.matcher_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.Matcher) 84 | )) 85 | _sym_db.RegisterMessage(Matcher) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | -------------------------------------------------------------------------------- /protos/mean_stddev_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for MeanStddevBoxCoder. See 6 | // box_coders/mean_stddev_box_coder.py for details. 7 | message MeanStddevBoxCoder { 8 | } 9 | -------------------------------------------------------------------------------- /protos/mean_stddev_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/mean_stddev_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/mean_stddev_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n3object_detection/protos/mean_stddev_box_coder.proto\x12\x17object_detection.protos\"\x14\n\x12MeanStddevBoxCoder') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _MEANSTDDEVBOXCODER = _descriptor.Descriptor( 29 | name='MeanStddevBoxCoder', 30 | full_name='object_detection.protos.MeanStddevBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | ], 36 | extensions=[ 37 | ], 38 | nested_types=[], 39 | enum_types=[ 40 | ], 41 | options=None, 42 | is_extendable=False, 43 | extension_ranges=[], 44 | oneofs=[ 45 | ], 46 | serialized_start=80, 47 | serialized_end=100, 48 | ) 49 | 50 | DESCRIPTOR.message_types_by_name['MeanStddevBoxCoder'] = _MEANSTDDEVBOXCODER 51 | 52 | MeanStddevBoxCoder = _reflection.GeneratedProtocolMessageType('MeanStddevBoxCoder', (_message.Message,), dict( 53 | DESCRIPTOR = _MEANSTDDEVBOXCODER, 54 | __module__ = 'object_detection.protos.mean_stddev_box_coder_pb2' 55 | # @@protoc_insertion_point(class_scope:object_detection.protos.MeanStddevBoxCoder) 56 | )) 57 | _sym_db.RegisterMessage(MeanStddevBoxCoder) 58 | 59 | 60 | # @@protoc_insertion_point(module_scope) 61 | -------------------------------------------------------------------------------- /protos/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/faster_rcnn.proto"; 6 | import "object_detection/protos/ssd.proto"; 7 | 8 | // Top level configuration for DetectionModels. 9 | message DetectionModel { 10 | oneof model { 11 | FasterRcnn faster_rcnn = 1; 12 | Ssd ssd = 2; 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /protos/model_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/model.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.faster_rcnn_pb2 17 | import object_detection.protos.ssd_pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/model.proto', 22 | package='object_detection.protos', 23 | serialized_pb=_b('\n#object_detection/protos/model.proto\x12\x17object_detection.protos\x1a)object_detection/protos/faster_rcnn.proto\x1a!object_detection/protos/ssd.proto\"\x82\x01\n\x0e\x44\x65tectionModel\x12:\n\x0b\x66\x61ster_rcnn\x18\x01 \x01(\x0b\x32#.object_detection.protos.FasterRcnnH\x00\x12+\n\x03ssd\x18\x02 \x01(\x0b\x32\x1c.object_detection.protos.SsdH\x00\x42\x07\n\x05model') 24 | , 25 | dependencies=[object_detection.protos.faster_rcnn_pb2.DESCRIPTOR,object_detection.protos.ssd_pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _DETECTIONMODEL = _descriptor.Descriptor( 32 | name='DetectionModel', 33 | full_name='object_detection.protos.DetectionModel', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='faster_rcnn', full_name='object_detection.protos.DetectionModel.faster_rcnn', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='ssd', full_name='object_detection.protos.DetectionModel.ssd', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | extension_ranges=[], 61 | oneofs=[ 62 | _descriptor.OneofDescriptor( 63 | name='model', full_name='object_detection.protos.DetectionModel.model', 64 | index=0, containing_type=None, fields=[]), 65 | ], 66 | serialized_start=143, 67 | serialized_end=273, 68 | ) 69 | 70 | _DETECTIONMODEL.fields_by_name['faster_rcnn'].message_type = object_detection.protos.faster_rcnn_pb2._FASTERRCNN 71 | _DETECTIONMODEL.fields_by_name['ssd'].message_type = object_detection.protos.ssd_pb2._SSD 72 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 73 | _DETECTIONMODEL.fields_by_name['faster_rcnn']) 74 | _DETECTIONMODEL.fields_by_name['faster_rcnn'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 75 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 76 | _DETECTIONMODEL.fields_by_name['ssd']) 77 | _DETECTIONMODEL.fields_by_name['ssd'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 78 | DESCRIPTOR.message_types_by_name['DetectionModel'] = _DETECTIONMODEL 79 | 80 | DetectionModel = _reflection.GeneratedProtocolMessageType('DetectionModel', (_message.Message,), dict( 81 | DESCRIPTOR = _DETECTIONMODEL, 82 | __module__ = 
'object_detection.protos.model_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.DetectionModel) 84 | )) 85 | _sym_db.RegisterMessage(DetectionModel) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | -------------------------------------------------------------------------------- /protos/optimizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Messages for configuring the optimizing strategy for training object 6 | // detection models. 7 | 8 | // Top level optimizer message. 9 | message Optimizer { 10 | oneof optimizer { 11 | RMSPropOptimizer rms_prop_optimizer = 1; 12 | MomentumOptimizer momentum_optimizer = 2; 13 | AdamOptimizer adam_optimizer = 3; 14 | } 15 | optional bool use_moving_average = 4 [default=true]; 16 | optional float moving_average_decay = 5 [default=0.9999]; 17 | } 18 | 19 | // Configuration message for the RMSPropOptimizer 20 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer 21 | message RMSPropOptimizer { 22 | optional LearningRate learning_rate = 1; 23 | optional float momentum_optimizer_value = 2 [default=0.9]; 24 | optional float decay = 3 [default=0.9]; 25 | optional float epsilon = 4 [default=1.0]; 26 | } 27 | 28 | // Configuration message for the MomentumOptimizer 29 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer 30 | message MomentumOptimizer { 31 | optional LearningRate learning_rate = 1; 32 | optional float momentum_optimizer_value = 2 [default=0.9]; 33 | } 34 | 35 | // Configuration message for the AdamOptimizer 36 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 37 | message AdamOptimizer { 38 | optional LearningRate learning_rate = 1; 39 | } 40 | 41 | // Configuration message for optimizer learning rate. 42 | message LearningRate { 43 | oneof learning_rate { 44 | ConstantLearningRate constant_learning_rate = 1; 45 | ExponentialDecayLearningRate exponential_decay_learning_rate = 2; 46 | ManualStepLearningRate manual_step_learning_rate = 3; 47 | } 48 | } 49 | 50 | // Configuration message for a constant learning rate. 51 | message ConstantLearningRate { 52 | optional float learning_rate = 1 [default=0.002]; 53 | } 54 | 55 | // Configuration message for an exponentially decaying learning rate. 56 | // See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ 57 | // decaying_the_learning_rate#exponential_decay 58 | message ExponentialDecayLearningRate { 59 | optional float initial_learning_rate = 1 [default=0.002]; 60 | optional uint32 decay_steps = 2 [default=4000000]; 61 | optional float decay_factor = 3 [default=0.95]; 62 | optional bool staircase = 4 [default=true]; 63 | } 64 | 65 | // Configuration message for a manually defined learning rate schedule. 
66 | message ManualStepLearningRate { 67 | optional float initial_learning_rate = 1 [default=0.002]; 68 | message LearningRateSchedule { 69 | optional uint32 step = 1; 70 | optional float learning_rate = 2 [default=0.002]; 71 | } 72 | repeated LearningRateSchedule schedule = 2; 73 | } 74 | -------------------------------------------------------------------------------- /protos/pipeline.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/eval.proto"; 6 | import "object_detection/protos/input_reader.proto"; 7 | import "object_detection/protos/model.proto"; 8 | import "object_detection/protos/train.proto"; 9 | 10 | // Convenience message for configuring a training and eval pipeline. Allows all 11 | // of the pipeline parameters to be configured from one file. 12 | message TrainEvalPipelineConfig { 13 | optional DetectionModel model = 1; 14 | optional TrainConfig train_config = 2; 15 | optional InputReader train_input_reader = 3; 16 | optional EvalConfig eval_config = 4; 17 | optional InputReader eval_input_reader = 5; 18 | } 19 | -------------------------------------------------------------------------------- /protos/post_processing.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for non-max-suppression operation on a batch of 6 | // detections. 7 | message BatchNonMaxSuppression { 8 | // Scalar threshold for score (low scoring boxes are removed). 9 | optional float score_threshold = 1 [default = 0.0]; 10 | 11 | // Scalar threshold for IOU (boxes that have high IOU overlap 12 | // with previously selected boxes are removed). 13 | optional float iou_threshold = 2 [default = 0.6]; 14 | 15 | // Maximum number of detections to retain per class. 16 | optional int32 max_detections_per_class = 3 [default = 100]; 17 | 18 | // Maximum number of detections to retain across all classes. 19 | optional int32 max_total_detections = 5 [default = 100]; 20 | } 21 | 22 | // Configuration proto for post-processing predicted boxes and 23 | // scores. 24 | message PostProcessing { 25 | // Non max suppression parameters. 26 | optional BatchNonMaxSuppression batch_non_max_suppression = 1; 27 | 28 | // Enum to specify how to convert the detection scores. 29 | enum ScoreConverter { 30 | // Input scores equals output scores. 31 | IDENTITY = 0; 32 | 33 | // Applies a sigmoid on input scores. 34 | SIGMOID = 1; 35 | 36 | // Applies a softmax on input scores 37 | SOFTMAX = 2; 38 | } 39 | 40 | // Score converter to use. 41 | optional ScoreConverter score_converter = 2 [default = IDENTITY]; 42 | } 43 | -------------------------------------------------------------------------------- /protos/region_similarity_calculator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for region similarity calculators. See 6 | // core/region_similarity_calculator.py for details. 7 | message RegionSimilarityCalculator { 8 | oneof region_similarity { 9 | NegSqDistSimilarity neg_sq_dist_similarity = 1; 10 | IouSimilarity iou_similarity = 2; 11 | IoaSimilarity ioa_similarity = 3; 12 | } 13 | } 14 | 15 | // Configuration for negative squared distance similarity calculator. 
16 | message NegSqDistSimilarity { 17 | } 18 | 19 | // Configuration for intersection-over-union (IOU) similarity calculator. 20 | message IouSimilarity { 21 | } 22 | 23 | // Configuration for intersection-over-area (IOA) similarity calculator. 24 | message IoaSimilarity { 25 | } 26 | -------------------------------------------------------------------------------- /protos/square_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SquareBoxCoder. See 6 | // box_coders/square_box_coder.py for details. 7 | message SquareBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box length. 13 | optional float length_scale = 3 [default = 5.0]; 14 | } 15 | -------------------------------------------------------------------------------- /protos/square_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/square_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/square_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n.object_detection/protos/square_box_coder.proto\x12\x17object_detection.protos\"S\n\x0eSquareBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0clength_scale\x18\x03 \x01(\x02:\x01\x35') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _SQUAREBOXCODER = _descriptor.Descriptor( 29 | name='SquareBoxCoder', 30 | full_name='object_detection.protos.SquareBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='y_scale', full_name='object_detection.protos.SquareBoxCoder.y_scale', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=10, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='x_scale', full_name='object_detection.protos.SquareBoxCoder.x_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=10, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='length_scale', full_name='object_detection.protos.SquareBoxCoder.length_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=5, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | 
], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=75, 68 | serialized_end=158, 69 | ) 70 | 71 | DESCRIPTOR.message_types_by_name['SquareBoxCoder'] = _SQUAREBOXCODER 72 | 73 | SquareBoxCoder = _reflection.GeneratedProtocolMessageType('SquareBoxCoder', (_message.Message,), dict( 74 | DESCRIPTOR = _SQUAREBOXCODER, 75 | __module__ = 'object_detection.protos.square_box_coder_pb2' 76 | # @@protoc_insertion_point(class_scope:object_detection.protos.SquareBoxCoder) 77 | )) 78 | _sym_db.RegisterMessage(SquareBoxCoder) 79 | 80 | 81 | # @@protoc_insertion_point(module_scope) 82 | -------------------------------------------------------------------------------- /protos/ssd.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package object_detection.protos; 3 | 4 | import "object_detection/protos/anchor_generator.proto"; 5 | import "object_detection/protos/box_coder.proto"; 6 | import "object_detection/protos/box_predictor.proto"; 7 | import "object_detection/protos/hyperparams.proto"; 8 | import "object_detection/protos/image_resizer.proto"; 9 | import "object_detection/protos/matcher.proto"; 10 | import "object_detection/protos/losses.proto"; 11 | import "object_detection/protos/post_processing.proto"; 12 | import "object_detection/protos/region_similarity_calculator.proto"; 13 | 14 | // Configuration for Single Shot Detection (SSD) models. 15 | message Ssd { 16 | 17 | // Number of classes to predict. 18 | optional int32 num_classes = 1; 19 | 20 | // Image resizer for preprocessing the input image. 21 | optional ImageResizer image_resizer = 2; 22 | 23 | // Feature extractor config. 24 | optional SsdFeatureExtractor feature_extractor = 3; 25 | 26 | // Box coder to encode the boxes. 27 | optional BoxCoder box_coder = 4; 28 | 29 | // Matcher to match groundtruth with anchors. 30 | optional Matcher matcher = 5; 31 | 32 | // Region similarity calculator to compute similarity of boxes. 33 | optional RegionSimilarityCalculator similarity_calculator = 6; 34 | 35 | // Box predictor to attach to the features. 36 | optional BoxPredictor box_predictor = 7; 37 | 38 | // Anchor generator to compute anchors. 39 | optional AnchorGenerator anchor_generator = 8; 40 | 41 | // Post processing to apply on the predictions. 42 | optional PostProcessing post_processing = 9; 43 | 44 | // Whether to normalize the loss by number of groundtruth boxes that match to 45 | // the anchors. 46 | optional bool normalize_loss_by_num_matches = 10 [default=true]; 47 | 48 | // Loss configuration for training. 49 | optional Loss loss = 11; 50 | } 51 | 52 | 53 | message SsdFeatureExtractor { 54 | // Type of ssd feature extractor. 55 | optional string type = 1; 56 | 57 | // The factor to alter the depth of the channels in the feature extractor. 58 | optional float depth_multiplier = 2 [default=1.0]; 59 | 60 | // Minimum number of the channels in the feature extractor. 61 | optional int32 min_depth = 3 [default=16]; 62 | 63 | // Hyperparameters for the feature extractor. 
64 | optional Hyperparams conv_hyperparams = 4; 65 | } 66 | -------------------------------------------------------------------------------- /protos/ssd_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SSD anchor generator described in 6 | // https://arxiv.org/abs/1512.02325. See 7 | // anchor_generators/multiple_grid_anchor_generator.py for details. 8 | message SsdAnchorGenerator { 9 | // Number of grid layers to create anchors for. 10 | optional int32 num_layers = 1 [default = 6]; 11 | 12 | // Scale of anchors corresponding to finest resolution. 13 | optional float min_scale = 2 [default = 0.2]; 14 | 15 | // Scale of anchors corresponding to coarsest resolution 16 | optional float max_scale = 3 [default = 0.95]; 17 | 18 | // Aspect ratios for anchors at each grid point. 19 | repeated float aspect_ratios = 4; 20 | 21 | // Whether to use the following aspect ratio and scale combination for the 22 | // layer with the finest resolution : (scale=0.1, aspect_ratio=1.0), 23 | // (scale=min_scale, aspect_ration=2.0), (scale=min_scale, aspect_ratio=0.5). 24 | optional bool reduce_boxes_in_lowest_layer = 5 [default = true]; 25 | } 26 | -------------------------------------------------------------------------------- /protos/ssd_anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/ssd_anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/ssd_anchor_generator.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n2object_detection/protos/ssd_anchor_generator.proto\x12\x17object_detection.protos\"\x9f\x01\n\x12SsdAnchorGenerator\x12\x15\n\nnum_layers\x18\x01 \x01(\x05:\x01\x36\x12\x16\n\tmin_scale\x18\x02 \x01(\x02:\x03\x30.2\x12\x17\n\tmax_scale\x18\x03 \x01(\x02:\x04\x30.95\x12\x15\n\raspect_ratios\x18\x04 \x03(\x02\x12*\n\x1creduce_boxes_in_lowest_layer\x18\x05 \x01(\x08:\x04true') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _SSDANCHORGENERATOR = _descriptor.Descriptor( 29 | name='SsdAnchorGenerator', 30 | full_name='object_detection.protos.SsdAnchorGenerator', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='num_layers', full_name='object_detection.protos.SsdAnchorGenerator.num_layers', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=True, default_value=6, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='min_scale', full_name='object_detection.protos.SsdAnchorGenerator.min_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=0.2, 46 | 
message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='max_scale', full_name='object_detection.protos.SsdAnchorGenerator.max_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=0.95, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='aspect_ratios', full_name='object_detection.protos.SsdAnchorGenerator.aspect_ratios', index=3, 58 | number=4, type=2, cpp_type=6, label=3, 59 | has_default_value=False, default_value=[], 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='reduce_boxes_in_lowest_layer', full_name='object_detection.protos.SsdAnchorGenerator.reduce_boxes_in_lowest_layer', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=True, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | extension_ranges=[], 79 | oneofs=[ 80 | ], 81 | serialized_start=80, 82 | serialized_end=239, 83 | ) 84 | 85 | DESCRIPTOR.message_types_by_name['SsdAnchorGenerator'] = _SSDANCHORGENERATOR 86 | 87 | SsdAnchorGenerator = _reflection.GeneratedProtocolMessageType('SsdAnchorGenerator', (_message.Message,), dict( 88 | DESCRIPTOR = _SSDANCHORGENERATOR, 89 | __module__ = 'object_detection.protos.ssd_anchor_generator_pb2' 90 | # @@protoc_insertion_point(class_scope:object_detection.protos.SsdAnchorGenerator) 91 | )) 92 | _sym_db.RegisterMessage(SsdAnchorGenerator) 93 | 94 | 95 | # @@protoc_insertion_point(module_scope) 96 | -------------------------------------------------------------------------------- /protos/string_int_label_map.proto: -------------------------------------------------------------------------------- 1 | // Message to store the mapping from class label strings to class id. Datasets 2 | // use string labels to represent classes while the object detection framework 3 | // works with class ids. This message maps them so they can be converted back 4 | // and forth as needed. 5 | syntax = "proto2"; 6 | 7 | package object_detection.protos; 8 | 9 | message StringIntLabelMapItem { 10 | // String name. The most common practice is to set this to a MID or synsets 11 | // id. 12 | optional string name = 1; 13 | 14 | // Integer id that maps to the string name above. Label ids should start from 15 | // 1. 16 | optional int32 id = 2; 17 | 18 | // Human readable string label. 19 | optional string display_name = 3; 20 | }; 21 | 22 | message StringIntLabelMap { 23 | repeated StringIntLabelMapItem item = 1; 24 | }; 25 | -------------------------------------------------------------------------------- /protos/string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/string_int_label_map.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/string_int_label_map.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 29 | name='StringIntLabelMapItem', 30 | full_name='object_detection.protos.StringIntLabelMapItem', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=79, 68 | serialized_end=150, 69 | ) 70 | 71 | 72 | _STRINGINTLABELMAP = _descriptor.Descriptor( 73 | name='StringIntLabelMap', 74 | full_name='object_detection.protos.StringIntLabelMap', 75 | filename=None, 76 | file=DESCRIPTOR, 77 | containing_type=None, 78 | fields=[ 79 | _descriptor.FieldDescriptor( 80 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 81 | number=1, type=11, cpp_type=10, label=3, 82 | has_default_value=False, default_value=[], 83 | message_type=None, enum_type=None, containing_type=None, 84 | is_extension=False, extension_scope=None, 85 | options=None), 86 | ], 87 | extensions=[ 88 | ], 89 | nested_types=[], 90 | enum_types=[ 91 | ], 92 | options=None, 93 | is_extendable=False, 94 | extension_ranges=[], 95 | oneofs=[ 96 | ], 97 | serialized_start=152, 98 | serialized_end=233, 99 | ) 100 | 101 | 
_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 102 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 103 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 104 | 105 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 106 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 107 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 108 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 109 | )) 110 | _sym_db.RegisterMessage(StringIntLabelMapItem) 111 | 112 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 113 | DESCRIPTOR = _STRINGINTLABELMAP, 114 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 115 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 116 | )) 117 | _sym_db.RegisterMessage(StringIntLabelMap) 118 | 119 | 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /protos/train.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/optimizer.proto"; 6 | import "object_detection/protos/preprocessor.proto"; 7 | 8 | // Message for configuring DetectionModel training jobs (train.py). 9 | message TrainConfig { 10 | // Input queue batch size. 11 | optional uint32 batch_size = 1 [default=32]; 12 | 13 | // Data augmentation options. 14 | repeated PreprocessingStep data_augmentation_options = 2; 15 | 16 | // Whether to synchronize replicas during training. 17 | optional bool sync_replicas = 3 [default=false]; 18 | 19 | // How frequently to keep checkpoints. 20 | optional uint32 keep_checkpoint_every_n_hours = 4 [default=1000]; 21 | 22 | // Optimizer used to train the DetectionModel. 23 | optional Optimizer optimizer = 5; 24 | 25 | // If greater than 0, clips gradients by this value. 26 | optional float gradient_clipping_by_norm = 6 [default=0.0]; 27 | 28 | // Checkpoint to restore variables from. Typically used to load feature 29 | // extractor variables trained outside of object detection. 30 | optional string fine_tune_checkpoint = 7 [default=""]; 31 | 32 | // Specifies if the finetune checkpoint is from an object detection model. 33 | // If from an object detection model, the model being trained should have 34 | // the same parameters with the exception of the num_classes parameter. 35 | // If false, it assumes the checkpoint was a object classification model. 36 | optional bool from_detection_checkpoint = 8 [default=false]; 37 | 38 | // Number of steps to train the DetectionModel for. If 0, will train the model 39 | // indefinitely. 40 | optional uint32 num_steps = 9 [default=0]; 41 | 42 | // Number of training steps between replica startup. 43 | // This flag must be set to 0 if sync_replicas is set to true. 44 | optional float startup_delay_steps = 10 [default=15]; 45 | 46 | // If greater than 0, multiplies the gradient of bias variables by this 47 | // amount. 48 | optional float bias_grad_multiplier = 11 [default=0]; 49 | 50 | // Variables that should not be updated during training. 51 | repeated string freeze_variables = 12; 52 | 53 | // Number of replicas to aggregate before making parameter updates. 
54 | optional int32 replicas_to_aggregate = 13 [default=1]; 55 | 56 | // Maximum number of elements to store within a queue. 57 | optional int32 batch_queue_capacity = 14 [default=600]; 58 | 59 | // Number of threads to use for batching. 60 | optional int32 num_batch_queue_threads = 15 [default=8]; 61 | 62 | // Maximum capacity of the queue used to prefetch assembled batches. 63 | optional int32 prefetch_queue_capacity = 16 [default=10]; 64 | } 65 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## AUTOMATIC IMAGE CAPTIONING USING CNN-LSTM DEEP NEURAL NETWORKS AND FLASK [![](https://img.shields.io/github/license/sourcerer-io/hall-of-fame.svg?colorB=ff0000)](https://github.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/blob/master/LICENSE) 2 | 3 | ### Description 4 | 5 | Image caption generation has emerged as a challenging and important research area following advances in statistical language modelling and image recognition. Generating captions from images has various practical benefits, ranging from aiding the visually impaired to enabling automatic, cost-saving labelling of the millions of images uploaded to the Internet every day. The field also brings together state-of-the-art models in Natural Language Processing and Computer Vision, two of the major fields in Artificial Intelligence. 6 | In this project, we have used a CNN and an LSTM to generate captions for images and deployed the model using Flask. 7 | 8 | 9 | ### Deployment Procedure 10 | 11 | ## 1. Download and install Python 3.x and make sure to set the path (most installers do this automatically). 12 | ## 2. Download the Anaconda IDE and Visual Studio Code. 13 | ## 3. Download the Flickr8k dataset from the link below: 14 | 15 | https://illinois.edu/fb/sec/1713398 16 | 17 | Place the dataset files in image caption/train_val_data 18 | 19 | 20 | ## 4. Install the libraries required by the project with pip, using the following format: 21 | ## pip install <<library name>> 22 | • tensorflow 23 | • keras 24 | • numpy 25 | • pandas 26 | • opencv-python 27 | • flask 28 | • flask-caption 29 | • scikit-learn 30 | • nltk 31 | • pytorch 32 | • theano 33 | • corpus 34 | • textblob 35 | • scipy 36 | • matplotlib 37 | 38 | 39 | ## 5. Download the code from the following GitHub repository: 40 | https://github.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask 41 | ## 6. Run app.py and cap.py in the terminal. 42 | ## 7. Open a browser and go to localhost:3000.
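To verify the deployment, you can POST a test image to the running server from Python. The snippet below is a minimal sketch only: the port comes from step 7 above, while the endpoint path and the multipart field name (`/play`, `file`) are assumptions modelled on the upload handler in speech.py, so adjust them to whatever routes app.py actually defines.

```python
# check_deployment.py -- minimal sketch; the endpoint and field name are assumptions.
import requests

URL = "http://localhost:3000/play"  # assumed route; adjust to match app.py

# POST one of the sample images as a multipart upload. The handler in speech.py
# reads the upload from request.files['file'], hence the field name 'file'.
with open("tests/comedor.jpg", "rb") as img:
    resp = requests.post(URL, files={"file": img})

print(resp.status_code)
print(resp.text)  # expected to contain the generated caption
```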
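The Description above mentions generating captions with a CNN encoder and an LSTM decoder; the sketch below shows the greedy decoding loop typically used with that kind of merge model. It is a sketch under stated assumptions rather than the project's exact code: the `startseq`/`endseq` tokens, the maximum caption length, and the two-input `model.predict([photo, sequence])` signature are assumptions about how the model under image caption/ was trained.

```python
# caption_sketch.py -- illustrative greedy decoder; token names, max length and the
# model's two-input signature are assumptions, not this project's exact code.
import numpy as np
from keras.preprocessing.sequence import pad_sequences

def greedy_caption(model, tokenizer, photo_features, max_length=34):
    """Build the caption one word at a time: feed the image features plus the
    words generated so far, append the most probable next word, and stop at the
    assumed 'endseq' token or when max_length is reached."""
    index_to_word = {i: w for w, i in tokenizer.word_index.items()}
    caption = "startseq"
    for _ in range(max_length):
        seq = tokenizer.texts_to_sequences([caption])[0]
        seq = pad_sequences([seq], maxlen=max_length)
        yhat = model.predict([photo_features, seq], verbose=0)
        word = index_to_word.get(int(np.argmax(yhat)))
        if word is None or word == "endseq":
            break
        caption += " " + word
    return caption.replace("startseq", "").strip()
```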
43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.weakref==1.0rc1 2 | bleach==1.5.0 3 | certifi==2017.7.27.1 4 | click==6.7 5 | cycler==0.10.0 6 | decorator==4.1.2 7 | entrypoints==0.2.3 8 | Flask==0.12.2 9 | html5lib==0.9999999 10 | ipykernel==4.6.1 11 | ipython==6.1.0 12 | ipython-genutils==0.2.0 13 | ipywidgets==7.0.0 14 | itsdangerous==0.24 15 | jedi==0.10.3 16 | Jinja2==2.9.6 17 | jsonschema==2.6.0 18 | lxml==3.8.0 19 | Markdown==2.2.0 20 | MarkupSafe==1.0 21 | matplotlib==2.0.2 22 | mistune==0.7.4 23 | nbconvert==5.2.1 24 | nbformat==4.4.0 25 | numpy==1.13.0 26 | olefile==0.44 27 | opencv-python==3.3.0.9 28 | pandocfilters==1.4.2 29 | pexpect==4.2.1 30 | pickleshare==0.7.4 31 | Pillow==4.2.1 32 | prompt-toolkit==1.0.15 33 | protobuf==3.3.0 34 | ptyprocess==0.5.2 35 | Pygments==2.2.0 36 | pyparsing==2.2.0 37 | python-dateutil==2.6.1 38 | pytz==2017.2 39 | pyzmq==16.0.2 40 | qtconsole==4.3.1 41 | simplegeneric==0.8.1 42 | six==1.10.0 43 | tensorflow==1.2.0 44 | terminado==0.6 45 | testpath==0.3.1 46 | tornado==4.5.1 47 | traitlets==4.3.2 48 | urllib3==1.22 49 | wcwidth==0.1.7 50 | Werkzeug==0.12.2 51 | widgetsnbextension==3.0.2 52 | -------------------------------------------------------------------------------- /speech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import soco 5 | from gtts import gTTS 6 | from mutagen.mp3 import MP3 7 | from captionbot import CaptionBot 8 | from soco.snapshot import Snapshot 9 | from flask import Flask, request, render_template 10 | 11 | language = 'en' 12 | c = CaptionBot() 13 | 14 | app = Flask(__name__) 15 | app.config['UPLOAD_FOLDER'] = 'uploads' 16 | 17 | @app.route('/') 18 | def index(): 19 | return render_template('index.html') 20 | 21 | @app.route('/play', methods=['POST']) 22 | def play(): 23 | # Save the uploaded image, caption it with CaptionBot, and synthesise the caption as speech. 24 | file = request.files['file'] 25 | f = os.path.join(app.config['UPLOAD_FOLDER'], file.filename) 26 | file.save(f) 27 | print(f) 28 | text = c.file_caption(f) 29 | tts = gTTS(text=text, lang=language) 30 | tts.save("sound.mp3") 31 | audio = MP3("sound.mp3") 32 | 33 | # Optionally announce the caption on a Sonos speaker: 34 | # for speaker in soco.discover(timeout=5): 35 | # snap = Snapshot(speaker) 36 | # snap.snapshot() 37 | # speaker.volume = 50 38 | # speaker.play_uri("http://192.168.0.3:8080/sound.mp3") 39 | # time.sleep(audio.info.length) 40 | # snap.restore() 41 | 42 | return text 43 | 44 | if __name__ == '__main__': 45 | app.run(debug=True, port=2000) 46 | -------------------------------------------------------------------------------- /static.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/static.zip -------------------------------------------------------------------------------- /templates.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/templates.zip -------------------------------------------------------------------------------- /tests/comedor.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/tests/comedor.jpg -------------------------------------------------------------------------------- /tests/familia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/tests/familia.jpg -------------------------------------------------------------------------------- /tests/fiesta.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/tests/fiesta.jpg -------------------------------------------------------------------------------- /tests/mascotas.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/tests/mascotas.jpg -------------------------------------------------------------------------------- /uploads/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/1.jpg -------------------------------------------------------------------------------- /uploads/1002674143_1b742ab4b8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/1002674143_1b742ab4b8.jpg -------------------------------------------------------------------------------- /uploads/1015584366_dfcec3c85a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/1015584366_dfcec3c85a.jpg -------------------------------------------------------------------------------- /uploads/1019077836_6fc9b15408.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/1019077836_6fc9b15408.jpg -------------------------------------------------------------------------------- /uploads/10815824_2997e03d76.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/10815824_2997e03d76.jpg -------------------------------------------------------------------------------- /uploads/1110208841_5bb6806afe.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/1110208841_5bb6806afe.jpg -------------------------------------------------------------------------------- /uploads/1260816604_570fc35836.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/1260816604_570fc35836.jpg -------------------------------------------------------------------------------- /uploads/12830823_87d2654e31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/12830823_87d2654e31.jpg -------------------------------------------------------------------------------- /uploads/17273391_55cfc7d3d4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/17273391_55cfc7d3d4.jpg -------------------------------------------------------------------------------- /uploads/2018-06-sample-gallery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/2018-06-sample-gallery.png -------------------------------------------------------------------------------- /uploads/2043427251_83b746da8e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/2043427251_83b746da8e.jpg -------------------------------------------------------------------------------- /uploads/23445819_3a458716c1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/23445819_3a458716c1.jpg -------------------------------------------------------------------------------- /uploads/24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/24.jpg -------------------------------------------------------------------------------- /uploads/2658009523_b49d611db8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/2658009523_b49d611db8.jpg -------------------------------------------------------------------------------- /uploads/2661567396_cbe4c2e5be.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/2661567396_cbe4c2e5be.jpg -------------------------------------------------------------------------------- /uploads/27782020_4dab210360.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/27782020_4dab210360.jpg -------------------------------------------------------------------------------- /uploads/29871656_938329389659068_5796396155926156498_o.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/29871656_938329389659068_5796396155926156498_o.jpg -------------------------------------------------------------------------------- /uploads/3047751696_78c2efe5e6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3047751696_78c2efe5e6.jpg -------------------------------------------------------------------------------- /uploads/3134092148_151154139a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3134092148_151154139a.jpg -------------------------------------------------------------------------------- /uploads/3262793378_773b21ec19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3262793378_773b21ec19.jpg -------------------------------------------------------------------------------- /uploads/3301754574_465af5bf6d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3301754574_465af5bf6d.jpg -------------------------------------------------------------------------------- /uploads/33108590_d685bfe51c.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/33108590_d685bfe51c.jpg -------------------------------------------------------------------------------- /uploads/3354883962_170d19bfe4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3354883962_170d19bfe4.jpg -------------------------------------------------------------------------------- /uploads/3422458549_f3f3878dbf.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3422458549_f3f3878dbf.jpg -------------------------------------------------------------------------------- /uploads/3518687038_964c523958.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3518687038_964c523958.jpg -------------------------------------------------------------------------------- /uploads/35506150_cbdb630f4f.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/35506150_cbdb630f4f.jpg -------------------------------------------------------------------------------- /uploads/3595216998_0a19efebd0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3595216998_0a19efebd0.jpg -------------------------------------------------------------------------------- /uploads/3597210806_95b07bb968.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3597210806_95b07bb968.jpg -------------------------------------------------------------------------------- /uploads/3601508034_5a3bfc905e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3601508034_5a3bfc905e.jpg -------------------------------------------------------------------------------- /uploads/3601843201_4809e66909.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3601843201_4809e66909.jpg -------------------------------------------------------------------------------- /uploads/3607489370_92683861f7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3607489370_92683861f7.jpg -------------------------------------------------------------------------------- /uploads/3613375729_d0b3c41556.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3613375729_d0b3c41556.jpg -------------------------------------------------------------------------------- 
/uploads/3637013_c675de7705.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3637013_c675de7705.jpg -------------------------------------------------------------------------------- /uploads/3640743904_d14eea0a0b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3640743904_d14eea0a0b.jpg -------------------------------------------------------------------------------- /uploads/36422830_55c844bc2d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/36422830_55c844bc2d.jpg -------------------------------------------------------------------------------- /uploads/3710971182_cb01c97d15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/3710971182_cb01c97d15.jpg -------------------------------------------------------------------------------- /uploads/37375513_2262615307097058_8585448275321028608_o.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/37375513_2262615307097058_8585448275321028608_o.jpg -------------------------------------------------------------------------------- /uploads/403523132_73b9a1a4b3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/403523132_73b9a1a4b3.jpg -------------------------------------------------------------------------------- /uploads/41999070_838089137e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/41999070_838089137e.jpg -------------------------------------------------------------------------------- /uploads/42637986_135a9786a6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/42637986_135a9786a6.jpg -------------------------------------------------------------------------------- /uploads/42637987_866635edf6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/42637987_866635edf6.jpg 
-------------------------------------------------------------------------------- /uploads/44129946_9eeb385d77.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/44129946_9eeb385d77.jpg -------------------------------------------------------------------------------- /uploads/460195978_fc522a4979.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/460195978_fc522a4979.jpg -------------------------------------------------------------------------------- /uploads/47871819_db55ac4699.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/47871819_db55ac4699.jpg -------------------------------------------------------------------------------- /uploads/54501196_a9ac9d66f2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/54501196_a9ac9d66f2.jpg -------------------------------------------------------------------------------- /uploads/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/6.jpg -------------------------------------------------------------------------------- /uploads/667626_18933d713e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/667626_18933d713e.jpg -------------------------------------------------------------------------------- /uploads/72218201_e0e9c7d65b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/72218201_e0e9c7d65b.jpg -------------------------------------------------------------------------------- /uploads/818340833_7b963c0ee3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/818340833_7b963c0ee3.jpg -------------------------------------------------------------------------------- /uploads/93922153_8d831f7f01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/93922153_8d831f7f01.jpg 
-------------------------------------------------------------------------------- /uploads/99679241_adc853a5c0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/99679241_adc853a5c0.jpg -------------------------------------------------------------------------------- /uploads/CR3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/CR3.jpg -------------------------------------------------------------------------------- /uploads/DSC00221.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/DSC00221.JPG -------------------------------------------------------------------------------- /uploads/IMG_20180406_162945.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/IMG_20180406_162945.jpg -------------------------------------------------------------------------------- /uploads/IMG_20181003_193709_mh1538575655201.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/IMG_20181003_193709_mh1538575655201.jpg -------------------------------------------------------------------------------- /uploads/IMG_20190409_112950_mh1554795435105.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/IMG_20190409_112950_mh1554795435105.jpg -------------------------------------------------------------------------------- /uploads/Penguins.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/Penguins.jpg -------------------------------------------------------------------------------- /uploads/WhatsApp Image 2018-07-16 at 7.32.07 PM.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/WhatsApp Image 2018-07-16 at 7.32.07 PM.jpeg -------------------------------------------------------------------------------- /uploads/WhatsApp_Image_2018-07-16_at_7.32.07_PM_1.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/WhatsApp_Image_2018-07-16_at_7.32.07_PM_1.jpeg -------------------------------------------------------------------------------- /uploads/ai3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/ai3.jpg -------------------------------------------------------------------------------- /uploads/canon-eos-sample-photo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/canon-eos-sample-photo.jpg -------------------------------------------------------------------------------- /uploads/college-grad-temping-feature.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/college-grad-temping-feature.jpg -------------------------------------------------------------------------------- /uploads/comedor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/comedor.jpg -------------------------------------------------------------------------------- /uploads/crop635w_accomlishedstudent0Small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/crop635w_accomlishedstudent0Small.jpg -------------------------------------------------------------------------------- /uploads/deep.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/deep.jpg -------------------------------------------------------------------------------- /uploads/depositphotos_14060460-stock-illustration-3d-heart-protection-vector-icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/depositphotos_14060460-stock-illustration-3d-heart-protection-vector-icon.jpg -------------------------------------------------------------------------------- /uploads/download (1).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/download (1).jpg -------------------------------------------------------------------------------- /uploads/download.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/download.jpg -------------------------------------------------------------------------------- /uploads/e020fc9a4def5c66ba435e27109a0890.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/e020fc9a4def5c66ba435e27109a0890.jpg -------------------------------------------------------------------------------- /uploads/familia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/familia.jpg -------------------------------------------------------------------------------- /uploads/fiesta.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/fiesta.jpg -------------------------------------------------------------------------------- /uploads/geek.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/geek.png -------------------------------------------------------------------------------- /uploads/images (1).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/images (1).jpg -------------------------------------------------------------------------------- /uploads/images (2).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/images (2).jpg -------------------------------------------------------------------------------- /uploads/images.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/images.jpg -------------------------------------------------------------------------------- /uploads/latest_mobile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/latest_mobile.png -------------------------------------------------------------------------------- /uploads/mascotas.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/mascotas.jpg -------------------------------------------------------------------------------- /uploads/modern_class.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/modern_class.jpg -------------------------------------------------------------------------------- /uploads/molecular-model-colorful_f57072f6-2ab1-11e9-b115-35431bcc9744.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/molecular-model-colorful_f57072f6-2ab1-11e9-b115-35431bcc9744.jpg -------------------------------------------------------------------------------- /uploads/opencv.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/opencv.jpg -------------------------------------------------------------------------------- /uploads/sai.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/sai.jpg -------------------------------------------------------------------------------- /uploads/sample-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/sample-5.jpg -------------------------------------------------------------------------------- /uploads/trail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/trail.jpg -------------------------------------------------------------------------------- /uploads/trail2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/trail2.jpg -------------------------------------------------------------------------------- /uploads/yash4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/uploads/yash4.jpg -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/label_map_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/utils/__pycache__/label_map_util.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visualization_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaswanthpalaghat/Automatic-Image-Captioning-using-CNN-LSTM-deep-neural-networks-and-flask/2dd4be83633ae827cff4818625b57312218bec1f/utils/__pycache__/visualization_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/category_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for importing/exporting Object Detection categories.""" 17 | import csv 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def load_categories_from_csv_file(csv_path): 23 | """Loads categories from a csv file. 24 | 25 | The CSV file should have one comma delimited numeric category id and string 26 | category name pair per line. For example: 27 | 28 | 0,"cat" 29 | 1,"dog" 30 | 2,"bird" 31 | ... 32 | 33 | Args: 34 | csv_path: Path to the csv file to be parsed into categories. 35 | Returns: 36 | categories: A list of dictionaries representing all possible categories. 37 | The categories will contain an integer 'id' field and a string 38 | 'name' field. 39 | Raises: 40 | ValueError: If the csv file is incorrectly formatted. 41 | """ 42 | categories = [] 43 | 44 | with tf.gfile.Open(csv_path, 'r') as csvfile: 45 | reader = csv.reader(csvfile, delimiter=',', quotechar='"') 46 | for row in reader: 47 | if not row: 48 | continue 49 | 50 | if len(row) != 2: 51 | raise ValueError('Expected 2 fields per row in csv: %s' % ','.join(row)) 52 | 53 | category_id = int(row[0]) 54 | category_name = row[1] 55 | categories.append({'id': category_id, 'name': category_name}) 56 | 57 | return categories 58 | 59 | 60 | def save_categories_to_csv_file(categories, csv_path): 61 | """Saves categories to a csv file. 62 | 63 | Args: 64 | categories: A list of dictionaries representing categories to save to file. 65 | Each category must contain an 'id' and 'name' field. 66 | csv_path: Path to the csv file to be parsed into categories. 
67 | """ 68 | categories.sort(key=lambda x: x['id']) 69 | with tf.gfile.Open(csv_path, 'w') as csvfile: 70 | writer = csv.writer(csvfile, delimiter=',', quotechar='"') 71 | for category in categories: 72 | writer.writerow([category['id'], category['name']]) 73 | -------------------------------------------------------------------------------- /utils/category_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.category_util.""" 17 | import os 18 | 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import category_util 22 | 23 | 24 | class EvalUtilTest(tf.test.TestCase): 25 | 26 | def test_load_categories_from_csv_file(self): 27 | csv_data = """ 28 | 0,"cat" 29 | 1,"dog" 30 | 2,"bird" 31 | """.strip(' ') 32 | csv_path = os.path.join(self.get_temp_dir(), 'test.csv') 33 | with tf.gfile.Open(csv_path, 'wb') as f: 34 | f.write(csv_data) 35 | 36 | categories = category_util.load_categories_from_csv_file(csv_path) 37 | self.assertTrue({'id': 0, 'name': 'cat'} in categories) 38 | self.assertTrue({'id': 1, 'name': 'dog'} in categories) 39 | self.assertTrue({'id': 2, 'name': 'bird'} in categories) 40 | 41 | def test_save_categories_to_csv_file(self): 42 | categories = [ 43 | {'id': 0, 'name': 'cat'}, 44 | {'id': 1, 'name': 'dog'}, 45 | {'id': 2, 'name': 'bird'}, 46 | ] 47 | csv_path = os.path.join(self.get_temp_dir(), 'test.csv') 48 | category_util.save_categories_to_csv_file(categories, csv_path) 49 | saved_categories = category_util.load_categories_from_csv_file(csv_path) 50 | self.assertEqual(saved_categories, categories) 51 | 52 | 53 | if __name__ == '__main__': 54 | tf.test.main() 55 | -------------------------------------------------------------------------------- /utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def read_examples_list(path): 42 | """Read list of training or validation examples. 43 | 44 | The file is assumed to contain a single example per line where the first 45 | token in the line is an identifier that allows us to find the image and 46 | annotation xml for that example. 47 | 48 | For example, the line: 49 | xyz 3 50 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 51 | 52 | Args: 53 | path: absolute path to examples list file. 54 | 55 | Returns: 56 | list of example identifiers (strings). 57 | """ 58 | with tf.gfile.GFile(path) as fid: 59 | lines = fid.readlines() 60 | return [line.strip().split(' ')[0] for line in lines] 61 | 62 | 63 | def recursive_parse_xml_to_dict(xml): 64 | """Recursively parses XML contents to python dict. 65 | 66 | We assume that `object` tags are the only ones that can appear 67 | multiple times at the same level of a tree. 68 | 69 | Args: 70 | xml: xml tree obtained by parsing XML file contents using lxml.etree 71 | 72 | Returns: 73 | Python dictionary holding XML contents. 74 | """ 75 | if not xml: 76 | return {xml.tag: xml.text} 77 | result = {} 78 | for child in xml: 79 | child_result = recursive_parse_xml_to_dict(child) 80 | if child.tag != 'object': 81 | result[child.tag] = child_result[child.tag] 82 | else: 83 | if child.tag not in result: 84 | result[child.tag] = [] 85 | result[child.tag].append(child_result[child.tag]) 86 | return {xml.tag: result} 87 | -------------------------------------------------------------------------------- /utils/dataset_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.dataset_util.""" 17 | 18 | import os 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import dataset_util 22 | 23 | 24 | class DatasetUtilTest(tf.test.TestCase): 25 | 26 | def test_read_examples_list(self): 27 | example_list_data = """example1 1\nexample2 2""" 28 | example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt') 29 | with tf.gfile.Open(example_list_path, 'wb') as f: 30 | f.write(example_list_data) 31 | 32 | examples = dataset_util.read_examples_list(example_list_path) 33 | self.assertListEqual(['example1', 'example2'], examples) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /utils/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Label map utility functions.""" 17 | 18 | import logging 19 | 20 | import tensorflow as tf 21 | from google.protobuf import text_format 22 | from finalproject.protos import string_int_label_map_pb2 23 | 24 | 25 | def _validate_label_map(label_map): 26 | """Checks if a label map is valid. 27 | 28 | Args: 29 | label_map: StringIntLabelMap to validate. 30 | 31 | Raises: 32 | ValueError: if label map is invalid. 33 | """ 34 | for item in label_map.item: 35 | if item.id < 1: 36 | raise ValueError('Label map ids should be >= 1.') 37 | 38 | 39 | def create_category_index(categories): 40 | """Creates dictionary of COCO compatible categories keyed by category id. 41 | 42 | Args: 43 | categories: a list of dicts, each of which has the following keys: 44 | 'id': (required) an integer id uniquely identifying this category. 45 | 'name': (required) string representing category name 46 | e.g., 'cat', 'dog', 'pizza'. 47 | 48 | Returns: 49 | category_index: a dict containing the same entries as categories, but keyed 50 | by the 'id' field of each category. 51 | """ 52 | category_index = {} 53 | for cat in categories: 54 | category_index[cat['id']] = cat 55 | return category_index 56 | 57 | 58 | def convert_label_map_to_categories(label_map, 59 | max_num_classes, 60 | use_display_name=True): 61 | """Loads label map proto and returns categories list compatible with eval. 62 | 63 | This function loads a label map and returns a list of dicts, each of which 64 | has the following keys: 65 | 'id': (required) an integer id uniquely identifying this category. 66 | 'name': (required) string representing category name 67 | e.g., 'cat', 'dog', 'pizza'. 68 | We only allow class into the list if its id-label_id_offset is 69 | between 0 (inclusive) and max_num_classes (exclusive). 
70 | If there are several items mapping to the same id in the label map, 71 | we will only keep the first one in the categories list. 72 | 73 | Args: 74 | label_map: a StringIntLabelMapProto or None. If None, a default categories 75 | list is created with max_num_classes categories. 76 | max_num_classes: maximum number of (consecutive) label indices to include. 77 | use_display_name: (boolean) choose whether to load 'display_name' field 78 | as category name. If False or if the display_name field does not exist, 79 | uses 'name' field as category names instead. 80 | Returns: 81 | categories: a list of dictionaries representing all possible categories. 82 | """ 83 | categories = [] 84 | list_of_ids_already_added = [] 85 | if not label_map: 86 | label_id_offset = 1 87 | for class_id in range(max_num_classes): 88 | categories.append({ 89 | 'id': class_id + label_id_offset, 90 | 'name': 'category_{}'.format(class_id + label_id_offset) 91 | }) 92 | return categories 93 | for item in label_map.item: 94 | if not 0 < item.id <= max_num_classes: 95 | logging.info('Ignore item %d since it falls outside of requested ' 96 | 'label range.', item.id) 97 | continue 98 | if use_display_name and item.HasField('display_name'): 99 | name = item.display_name 100 | else: 101 | name = item.name 102 | if item.id not in list_of_ids_already_added: 103 | list_of_ids_already_added.append(item.id) 104 | categories.append({'id': item.id, 'name': name}) 105 | return categories 106 | 107 | 108 | def load_labelmap(path): 109 | """Loads label map proto. 110 | 111 | Args: 112 | path: path to StringIntLabelMap proto text file. 113 | Returns: 114 | a StringIntLabelMapProto 115 | """ 116 | with tf.gfile.GFile(path, 'r') as fid: 117 | label_map_string = fid.read() 118 | label_map = string_int_label_map_pb2.StringIntLabelMap() 119 | try: 120 | text_format.Merge(label_map_string, label_map) 121 | except text_format.ParseError: 122 | label_map.ParseFromString(label_map_string) 123 | _validate_label_map(label_map) 124 | return label_map 125 | 126 | 127 | def get_label_map_dict(label_map_path): 128 | """Reads a label map and returns a dictionary of label names to id. 129 | 130 | Args: 131 | label_map_path: path to label_map. 132 | 133 | Returns: 134 | A dictionary mapping label names to id. 135 | """ 136 | label_map = load_labelmap(label_map_path) 137 | label_map_dict = {} 138 | for item in label_map.item: 139 | label_map_dict[item.name] = item.id 140 | return label_map_dict 141 | -------------------------------------------------------------------------------- /utils/learning_schedules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Library of common learning rate schedules.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def exponential_decay_with_burnin(global_step, 22 | learning_rate_base, 23 | learning_rate_decay_steps, 24 | learning_rate_decay_factor, 25 | burnin_learning_rate=0.0, 26 | burnin_steps=0): 27 | """Exponential decay schedule with burn-in period. 28 | 29 | In this schedule, learning rate is fixed at burnin_learning_rate 30 | for a fixed period, before transitioning to a regular exponential 31 | decay schedule. 32 | 33 | Args: 34 | global_step: int tensor representing global step. 35 | learning_rate_base: base learning rate. 36 | learning_rate_decay_steps: steps to take between decaying the learning rate. 37 | Note that this includes the number of burn-in steps. 38 | learning_rate_decay_factor: multiplicative factor by which to decay 39 | learning rate. 40 | burnin_learning_rate: initial learning rate during burn-in period. If 41 | 0.0 (which is the default), then the burn-in learning rate is simply 42 | set to learning_rate_base. 43 | burnin_steps: number of steps to use burnin learning rate. 44 | 45 | Returns: 46 | a (scalar) float tensor representing learning rate 47 | """ 48 | if burnin_learning_rate == 0: 49 | burnin_learning_rate = learning_rate_base 50 | post_burnin_learning_rate = tf.train.exponential_decay( 51 | learning_rate_base, 52 | global_step, 53 | learning_rate_decay_steps, 54 | learning_rate_decay_factor, 55 | staircase=True) 56 | return tf.cond( 57 | tf.less(global_step, burnin_steps), 58 | lambda: tf.convert_to_tensor(burnin_learning_rate), 59 | lambda: post_burnin_learning_rate) 60 | 61 | 62 | def manual_stepping(global_step, boundaries, rates): 63 | """Manually stepped learning rate schedule. 64 | 65 | This function provides fine grained control over learning rates. One must 66 | specify a sequence of learning rates as well as a set of integer steps 67 | at which the current learning rate must transition to the next. For example, 68 | if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning 69 | rate returned by this function is .1 for global_step=0,...,4, .01 for 70 | global_step=5...9, and .001 for global_step=10 and onward. 71 | 72 | Args: 73 | global_step: int64 (scalar) tensor representing global step. 74 | boundaries: a list of global steps at which to switch learning 75 | rates. This list is assumed to consist of increasing positive integers. 76 | rates: a list of (float) learning rates corresponding to intervals between 77 | the boundaries. The length of this list must be exactly 78 | len(boundaries) + 1. 79 | 80 | Returns: 81 | a (scalar) float tensor representing learning rate 82 | Raises: 83 | ValueError: if one of the following checks fails: 84 | 1. boundaries is a strictly increasing list of positive integers 85 | 2. 
len(rates) == len(boundaries) + 1 86 | """ 87 | if any([b < 0 for b in boundaries]) or any( 88 | [not isinstance(b, int) for b in boundaries]): 89 | raise ValueError('boundaries must be a list of positive integers') 90 | if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]): 91 | raise ValueError('Entries in boundaries must be strictly increasing.') 92 | if any([not isinstance(r, float) for r in rates]): 93 | raise ValueError('Learning rates must be floats') 94 | if len(rates) != len(boundaries) + 1: 95 | raise ValueError('Number of provided learning rates must exceed ' 96 | 'number of boundary points by exactly 1.') 97 | step_boundaries = tf.constant(boundaries, tf.int64) 98 | learning_rates = tf.constant(rates, tf.float32) 99 | unreached_boundaries = tf.reshape(tf.where( 100 | tf.greater(step_boundaries, global_step)), [-1]) 101 | unreached_boundaries = tf.concat([unreached_boundaries, [len(boundaries)]], 0) 102 | index = tf.reshape(tf.reduce_min(unreached_boundaries), [1]) 103 | return tf.reshape(tf.slice(learning_rates, index, [1]), []) 104 | -------------------------------------------------------------------------------- /utils/learning_schedules_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.learning_schedules.""" 17 | import tensorflow as tf 18 | 19 | from object_detection.utils import learning_schedules 20 | 21 | 22 | class LearningSchedulesTest(tf.test.TestCase): 23 | 24 | def testExponentialDecayWithBurnin(self): 25 | global_step = tf.placeholder(tf.int32, []) 26 | learning_rate_base = 1.0 27 | learning_rate_decay_steps = 3 28 | learning_rate_decay_factor = .1 29 | burnin_learning_rate = .5 30 | burnin_steps = 2 31 | exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01] 32 | learning_rate = learning_schedules.exponential_decay_with_burnin( 33 | global_step, learning_rate_base, learning_rate_decay_steps, 34 | learning_rate_decay_factor, burnin_learning_rate, burnin_steps) 35 | with self.test_session() as sess: 36 | output_rates = [] 37 | for input_global_step in range(8): 38 | output_rate = sess.run(learning_rate, 39 | feed_dict={global_step: input_global_step}) 40 | output_rates.append(output_rate) 41 | self.assertAllClose(output_rates, exp_rates) 42 | 43 | def testManualStepping(self): 44 | global_step = tf.placeholder(tf.int64, []) 45 | boundaries = [2, 3, 7] 46 | rates = [1.0, 2.0, 3.0, 4.0] 47 | exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0] 48 | learning_rate = learning_schedules.manual_stepping(global_step, boundaries, 49 | rates) 50 | with self.test_session() as sess: 51 | output_rates = [] 52 | for input_global_step in range(10): 53 | output_rate = sess.run(learning_rate, 54 | feed_dict={global_step: input_global_step}) 55 | output_rates.append(output_rate) 56 | self.assertAllClose(output_rates, exp_rates) 57 | 58 | if __name__ == '__main__': 59 | tf.test.main() 60 | -------------------------------------------------------------------------------- /utils/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.metrics.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import metrics 22 | 23 | 24 | class MetricsTest(tf.test.TestCase): 25 | 26 | def test_compute_cor_loc(self): 27 | num_gt_imgs_per_class = np.array([100, 1, 5, 1, 1], dtype=int) 28 | num_images_correctly_detected_per_class = np.array([10, 0, 1, 0, 0], 29 | dtype=int) 30 | corloc = metrics.compute_cor_loc(num_gt_imgs_per_class, 31 | num_images_correctly_detected_per_class) 32 | expected_corloc = np.array([0.1, 0, 0.2, 0, 0], dtype=float) 33 | self.assertTrue(np.allclose(corloc, expected_corloc)) 34 | 35 | def test_compute_cor_loc_nans(self): 36 | num_gt_imgs_per_class = np.array([100, 0, 0, 1, 1], dtype=int) 37 | num_images_correctly_detected_per_class = np.array([10, 0, 1, 0, 0], 38 | dtype=int) 39 | corloc = metrics.compute_cor_loc(num_gt_imgs_per_class, 40 | num_images_correctly_detected_per_class) 41 | expected_corloc = np.array([0.1, np.nan, np.nan, 0, 0], dtype=float) 42 | self.assertAllClose(corloc, expected_corloc) 43 | 44 | def test_compute_precision_recall(self): 45 | num_gt = 10 46 | scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float) 47 | labels = np.array([0, 1, 1, 0, 0, 1], dtype=bool) 48 | accumulated_tp_count = np.array([0, 1, 1, 2, 2, 3], dtype=float) 49 | expected_precision = accumulated_tp_count / np.array([1, 2, 3, 4, 5, 6]) 50 | expected_recall = accumulated_tp_count / num_gt 51 | precision, recall = metrics.compute_precision_recall(scores, labels, num_gt) 52 | self.assertAllClose(precision, expected_precision) 53 | self.assertAllClose(recall, expected_recall) 54 | 55 | def test_compute_average_precision(self): 56 | precision = np.array([0.8, 0.76, 0.9, 0.65, 0.7, 0.5, 0.55, 0], dtype=float) 57 | recall = np.array([0.3, 0.3, 0.4, 0.4, 0.45, 0.45, 0.5, 0.5], dtype=float) 58 | processed_precision = np.array([0.9, 0.9, 0.9, 0.7, 0.7, 0.55, 0.55, 0], 59 | dtype=float) 60 | recall_interval = np.array([0.3, 0, 0.1, 0, 0.05, 0, 0.05, 0], dtype=float) 61 | expected_mean_ap = np.sum(recall_interval * processed_precision) 62 | mean_ap = metrics.compute_average_precision(precision, recall) 63 | self.assertAlmostEqual(expected_mean_ap, mean_ap) 64 | 65 | def test_compute_precision_recall_and_ap_no_groundtruth(self): 66 | num_gt = 0 67 | scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float) 68 | labels = np.array([0, 0, 0, 0, 0, 0], dtype=bool) 69 | expected_precision = None 70 | expected_recall = None 71 | precision, recall = metrics.compute_precision_recall(scores, labels, num_gt) 72 | self.assertEquals(precision, expected_precision) 73 | self.assertEquals(recall, expected_recall) 74 | ap = metrics.compute_average_precision(precision, recall) 75 | self.assertTrue(np.isnan(ap)) 76 | 77 | 78 | if __name__ == '__main__': 79 | tf.test.main() 80 | -------------------------------------------------------------------------------- /utils/np_box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxList classes and functions.""" 17 | 18 | import numpy as np 19 | from six import moves 20 | 21 | 22 | class BoxList(object): 23 | """Box collection. 24 | 25 | BoxList represents a list of bounding boxes as numpy array, where each 26 | bounding box is represented as a row of 4 numbers, 27 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a 28 | given list correspond to a single image. 29 | 30 | Optionally, users can add additional related fields (such as 31 | objectness/classification scores). 32 | """ 33 | 34 | def __init__(self, data): 35 | """Constructs box collection. 36 | 37 | Args: 38 | data: a numpy array of shape [N, 4] representing box coordinates 39 | 40 | Raises: 41 | ValueError: if bbox data is not a numpy array 42 | ValueError: if invalid dimensions for bbox data 43 | """ 44 | if not isinstance(data, np.ndarray): 45 | raise ValueError('data must be a numpy array.') 46 | if len(data.shape) != 2 or data.shape[1] != 4: 47 | raise ValueError('Invalid dimensions for box data.') 48 | if data.dtype != np.float32 and data.dtype != np.float64: 49 | raise ValueError('Invalid data type for box data: float is required.') 50 | if not self._is_valid_boxes(data): 51 | raise ValueError('Invalid box data. data must be a numpy array of ' 52 | 'N*[y_min, x_min, y_max, x_max]') 53 | self.data = {'boxes': data} 54 | 55 | def num_boxes(self): 56 | """Return number of boxes held in collections.""" 57 | return self.data['boxes'].shape[0] 58 | 59 | def get_extra_fields(self): 60 | """Return all non-box fields.""" 61 | return [k for k in self.data.keys() if k != 'boxes'] 62 | 63 | def has_field(self, field): 64 | return field in self.data 65 | 66 | def add_field(self, field, field_data): 67 | """Add data to a specified field. 68 | 69 | Args: 70 | field: a string parameter used to specify a related field to be accessed. 71 | field_data: a numpy array of [N, ...] representing the data associated 72 | with the field. 73 | Raises: 74 | ValueError: if the field already exists or the dimension of the field 75 | data does not match the number of boxes. 76 | """ 77 | if self.has_field(field): 78 | raise ValueError('Field ' + field + ' already exists') 79 | if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes(): 80 | raise ValueError('Invalid dimensions for field data') 81 | self.data[field] = field_data 82 | 83 | def get(self): 84 | """Convenience function for accessing box coordinates. 85 | 86 | Returns: 87 | a numpy array of shape [N, 4] representing box corners 88 | """ 89 | return self.get_field('boxes') 90 | 91 | def get_field(self, field): 92 | """Accesses data associated with the specified field in the box collection. 93 | 94 | Args: 95 | field: a string parameter used to specify a related field to be accessed.
96 | 97 | Returns: 98 | a numpy 1-d array representing data of an associated field 99 | 100 | Raises: 101 | ValueError: if invalid field 102 | """ 103 | if not self.has_field(field): 104 | raise ValueError('field {} does not exist'.format(field)) 105 | return self.data[field] 106 | 107 | def get_coordinates(self): 108 | """Get corner coordinates of boxes. 109 | 110 | Returns: 111 | a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max] 112 | """ 113 | box_coordinates = self.get() 114 | y_min = box_coordinates[:, 0] 115 | x_min = box_coordinates[:, 1] 116 | y_max = box_coordinates[:, 2] 117 | x_max = box_coordinates[:, 3] 118 | return [y_min, x_min, y_max, x_max] 119 | 120 | def _is_valid_boxes(self, data): 121 | """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin]. 122 | 123 | Args: 124 | data: a numpy array of shape [N, 4] representing box coordinates 125 | 126 | Returns: 127 | a boolean indicating whether all ymax of boxes are equal or greater than 128 | ymin, and all xmax of boxes are equal or greater than xmin. 129 | """ 130 | if data.shape[0] > 0: 131 | for i in moves.range(data.shape[0]): 132 | if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: 133 | return False 134 | return True 135 | -------------------------------------------------------------------------------- /utils/np_box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, 4] numpy arrays representing bounding boxes. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | import numpy as np 23 | 24 | 25 | def area(boxes): 26 | """Computes area of boxes. 27 | 28 | Args: 29 | boxes: Numpy array with shape [N, 4] holding N boxes 30 | 31 | Returns: 32 | a numpy array with shape [N*1] representing box areas 33 | """ 34 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 35 | 36 | 37 | def intersection(boxes1, boxes2): 38 | """Compute pairwise intersection areas between boxes. 
39 | 40 | Args: 41 | boxes1: a numpy array with shape [N, 4] holding N boxes 42 | boxes2: a numpy array with shape [M, 4] holding M boxes 43 | 44 | Returns: 45 | a numpy array with shape [N*M] representing pairwise intersection area 46 | """ 47 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 48 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 49 | 50 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 51 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 52 | intersect_heights = np.maximum( 53 | np.zeros(all_pairs_max_ymin.shape), 54 | all_pairs_min_ymax - all_pairs_max_ymin) 55 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 56 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 57 | intersect_widths = np.maximum( 58 | np.zeros(all_pairs_max_xmin.shape), 59 | all_pairs_min_xmax - all_pairs_max_xmin) 60 | return intersect_heights * intersect_widths 61 | 62 | 63 | def iou(boxes1, boxes2): 64 | """Computes pairwise intersection-over-union between box collections. 65 | 66 | Args: 67 | boxes1: a numpy array with shape [N, 4] holding N boxes. 68 | boxes2: a numpy array with shape [M, 4] holding N boxes. 69 | 70 | Returns: 71 | a numpy array with shape [N, M] representing pairwise iou scores. 72 | """ 73 | intersect = intersection(boxes1, boxes2) 74 | area1 = area(boxes1) 75 | area2 = area(boxes2) 76 | union = np.expand_dims(area1, axis=1) + np.expand_dims( 77 | area2, axis=0) - intersect 78 | return intersect / union 79 | 80 | 81 | def ioa(boxes1, boxes2): 82 | """Computes pairwise intersection-over-area between box collections. 83 | 84 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 85 | their intersection area over box2's area. Note that ioa is not symmetric, 86 | that is, IOA(box1, box2) != IOA(box2, box1). 87 | 88 | Args: 89 | boxes1: a numpy array with shape [N, 4] holding N boxes. 90 | boxes2: a numpy array with shape [M, 4] holding N boxes. 91 | 92 | Returns: 93 | a numpy array with shape [N, M] representing pairwise ioa scores. 94 | """ 95 | intersect = intersection(boxes1, boxes2) 96 | areas = np.expand_dims(area(boxes2), axis=0) 97 | return intersect / areas 98 | -------------------------------------------------------------------------------- /utils/np_box_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.np_box_ops.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import np_box_ops 22 | 23 | 24 | class BoxOpsTests(tf.test.TestCase): 25 | 26 | def setUp(self): 27 | boxes1 = np.array([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], 28 | dtype=float) 29 | boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 30 | [0.0, 0.0, 20.0, 20.0]], 31 | dtype=float) 32 | self.boxes1 = boxes1 33 | self.boxes2 = boxes2 34 | 35 | def testArea(self): 36 | areas = np_box_ops.area(self.boxes1) 37 | expected_areas = np.array([6.0, 5.0], dtype=float) 38 | self.assertAllClose(expected_areas, areas) 39 | 40 | def testIntersection(self): 41 | intersection = np_box_ops.intersection(self.boxes1, self.boxes2) 42 | expected_intersection = np.array([[2.0, 0.0, 6.0], [1.0, 0.0, 5.0]], 43 | dtype=float) 44 | self.assertAllClose(intersection, expected_intersection) 45 | 46 | def testIOU(self): 47 | iou = np_box_ops.iou(self.boxes1, self.boxes2) 48 | expected_iou = np.array([[2.0 / 16.0, 0.0, 6.0 / 400.0], 49 | [1.0 / 16.0, 0.0, 5.0 / 400.0]], 50 | dtype=float) 51 | self.assertAllClose(iou, expected_iou) 52 | 53 | def testIOA(self): 54 | boxes1 = np.array([[0.25, 0.25, 0.75, 0.75], 55 | [0.0, 0.0, 0.5, 0.75]], 56 | dtype=np.float32) 57 | boxes2 = np.array([[0.5, 0.25, 1.0, 1.0], 58 | [0.0, 0.0, 1.0, 1.0]], 59 | dtype=np.float32) 60 | ioa21 = np_box_ops.ioa(boxes2, boxes1) 61 | expected_ioa21 = np.array([[0.5, 0.0], 62 | [1.0, 1.0]], 63 | dtype=np.float32) 64 | self.assertAllClose(ioa21, expected_ioa21) 65 | 66 | 67 | if __name__ == '__main__': 68 | tf.test.main() 69 | -------------------------------------------------------------------------------- /utils/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utils used to manipulate tensor shapes.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def _is_tensor(t): 22 | """Returns a boolean indicating whether the input is a tensor. 23 | 24 | Args: 25 | t: the input to be tested. 26 | 27 | Returns: 28 | a boolean that indicates whether t is a tensor. 29 | """ 30 | return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable)) 31 | 32 | 33 | def _set_dim_0(t, d0): 34 | """Sets the 0-th dimension of the input tensor. 35 | 36 | Args: 37 | t: the input tensor, assuming the rank is at least 1. 38 | d0: an integer indicating the 0-th dimension of the input tensor. 39 | 40 | Returns: 41 | the tensor t with the 0-th dimension set. 
42 | """ 43 | t_shape = t.get_shape().as_list() 44 | t_shape[0] = d0 45 | t.set_shape(t_shape) 46 | return t 47 | 48 | 49 | def pad_tensor(t, length): 50 | """Pads the input tensor with 0s along the first dimension up to the length. 51 | 52 | Args: 53 | t: the input tensor, assuming the rank is at least 1. 54 | length: a tensor of shape [1] or an integer, indicating the first dimension 55 | of the input tensor t after padding, assuming length <= t.shape[0]. 56 | 57 | Returns: 58 | padded_t: the padded tensor, whose first dimension is length. If the length 59 | is an integer, the first dimension of padded_t is set to length 60 | statically. 61 | """ 62 | t_rank = tf.rank(t) 63 | t_shape = tf.shape(t) 64 | t_d0 = t_shape[0] 65 | pad_d0 = tf.expand_dims(length - t_d0, 0) 66 | pad_shape = tf.cond( 67 | tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), 68 | lambda: tf.expand_dims(length - t_d0, 0)) 69 | padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0) 70 | if not _is_tensor(length): 71 | padded_t = _set_dim_0(padded_t, length) 72 | return padded_t 73 | 74 | 75 | def clip_tensor(t, length): 76 | """Clips the input tensor along the first dimension up to the length. 77 | 78 | Args: 79 | t: the input tensor, assuming the rank is at least 1. 80 | length: a tensor of shape [1] or an integer, indicating the first dimension 81 | of the input tensor t after clipping, assuming length <= t.shape[0]. 82 | 83 | Returns: 84 | clipped_t: the clipped tensor, whose first dimension is length. If the 85 | length is an integer, the first dimension of clipped_t is set to length 86 | statically. 87 | """ 88 | clipped_t = tf.gather(t, tf.range(length)) 89 | if not _is_tensor(length): 90 | clipped_t = _set_dim_0(clipped_t, length) 91 | return clipped_t 92 | 93 | 94 | def pad_or_clip_tensor(t, length): 95 | """Pad or clip the input tensor along the first dimension. 96 | 97 | Args: 98 | t: the input tensor, assuming the rank is at least 1. 99 | length: a tensor of shape [1] or an integer, indicating the first dimension 100 | of the input tensor t after processing. 101 | 102 | Returns: 103 | processed_t: the processed tensor, whose first dimension is length. If the 104 | length is an integer, the first dimension of the processed tensor is set 105 | to length statically. 106 | """ 107 | processed_t = tf.cond( 108 | tf.greater(tf.shape(t)[0], length), 109 | lambda: clip_tensor(t, length), 110 | lambda: pad_tensor(t, length)) 111 | if not _is_tensor(length): 112 | processed_t = _set_dim_0(processed_t, length) 113 | return processed_t 114 | 115 | 116 | def combined_static_and_dynamic_shape(tensor): 117 | """Returns a list containing static and dynamic values for the dimensions. 118 | 119 | Returns a list of static and dynamic values for shape dimensions. This is 120 | useful to preserve static shapes when available in reshape operation. 121 | 122 | Args: 123 | tensor: A tensor of any type. 124 | 125 | Returns: 126 | A list of size tensor.shape.ndims containing integers or a scalar tensor. 
127 | """ 128 | static_shape = tensor.shape.as_list() 129 | dynamic_shape = tf.shape(tensor) 130 | combined_shape = [] 131 | for index, dim in enumerate(static_shape): 132 | if dim is not None: 133 | combined_shape.append(dim) 134 | else: 135 | combined_shape.append(dynamic_shape[index]) 136 | return combined_shape 137 | -------------------------------------------------------------------------------- /utils/static_shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions to access TensorShape values. 17 | 18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 19 | """ 20 | 21 | 22 | def get_batch_size(tensor_shape): 23 | """Returns batch size from the tensor shape. 24 | 25 | Args: 26 | tensor_shape: A rank 4 TensorShape. 27 | 28 | Returns: 29 | An integer representing the batch size of the tensor. 30 | """ 31 | tensor_shape.assert_has_rank(rank=4) 32 | return tensor_shape[0].value 33 | 34 | 35 | def get_height(tensor_shape): 36 | """Returns height from the tensor shape. 37 | 38 | Args: 39 | tensor_shape: A rank 4 TensorShape. 40 | 41 | Returns: 42 | An integer representing the height of the tensor. 43 | """ 44 | tensor_shape.assert_has_rank(rank=4) 45 | return tensor_shape[1].value 46 | 47 | 48 | def get_width(tensor_shape): 49 | """Returns width from the tensor shape. 50 | 51 | Args: 52 | tensor_shape: A rank 4 TensorShape. 53 | 54 | Returns: 55 | An integer representing the width of the tensor. 56 | """ 57 | tensor_shape.assert_has_rank(rank=4) 58 | return tensor_shape[2].value 59 | 60 | 61 | def get_depth(tensor_shape): 62 | """Returns depth from the tensor shape. 63 | 64 | Args: 65 | tensor_shape: A rank 4 TensorShape. 66 | 67 | Returns: 68 | An integer representing the depth of the tensor. 69 | """ 70 | tensor_shape.assert_has_rank(rank=4) 71 | return tensor_shape[3].value 72 | -------------------------------------------------------------------------------- /utils/static_shape_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.static_shape.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.utils import static_shape 21 | 22 | 23 | class StaticShapeTest(tf.test.TestCase): 24 | 25 | def test_return_correct_batchSize(self): 26 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 27 | self.assertEqual(32, static_shape.get_batch_size(tensor_shape)) 28 | 29 | def test_return_correct_height(self): 30 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 31 | self.assertEqual(299, static_shape.get_height(tensor_shape)) 32 | 33 | def test_return_correct_width(self): 34 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 35 | self.assertEqual(384, static_shape.get_width(tensor_shape)) 36 | 37 | def test_return_correct_depth(self): 38 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 39 | self.assertEqual(3, static_shape.get_depth(tensor_shape)) 40 | 41 | def test_die_on_tensor_shape_with_rank_three(self): 42 | tensor_shape = tf.TensorShape(dims=[32, 299, 384]) 43 | with self.assertRaises(ValueError): 44 | static_shape.get_batch_size(tensor_shape) 45 | static_shape.get_height(tensor_shape) 46 | static_shape.get_width(tensor_shape) 47 | static_shape.get_depth(tensor_shape) 48 | 49 | if __name__ == '__main__': 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /utils/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Contains functions which are convenient for unit testing.""" 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from object_detection.core import anchor_generator 21 | from object_detection.core import box_coder 22 | from object_detection.core import box_list 23 | from object_detection.core import box_predictor 24 | from object_detection.core import matcher 25 | from object_detection.utils import shape_utils 26 | 27 | 28 | class MockBoxCoder(box_coder.BoxCoder): 29 | """Simple `difference` BoxCoder.""" 30 | 31 | @property 32 | def code_size(self): 33 | return 4 34 | 35 | def _encode(self, boxes, anchors): 36 | return boxes.get() - anchors.get() 37 | 38 | def _decode(self, rel_codes, anchors): 39 | return box_list.BoxList(rel_codes + anchors.get()) 40 | 41 | 42 | class MockBoxPredictor(box_predictor.BoxPredictor): 43 | """Simple box predictor that ignores inputs and outputs all zeros.""" 44 | 45 | def __init__(self, is_training, num_classes): 46 | super(MockBoxPredictor, self).__init__(is_training, num_classes) 47 | 48 | def _predict(self, image_features, num_predictions_per_location): 49 | combined_feature_shape = shape_utils.combined_static_and_dynamic_shape( 50 | image_features) 51 | batch_size = combined_feature_shape[0] 52 | num_anchors = (combined_feature_shape[1] * combined_feature_shape[2]) 53 | code_size = 4 54 | zero = tf.reduce_sum(0 * image_features) 55 | box_encodings = zero + tf.zeros( 56 | (batch_size, num_anchors, 1, code_size), dtype=tf.float32) 57 | class_predictions_with_background = zero + tf.zeros( 58 | (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32) 59 | return {box_predictor.BOX_ENCODINGS: box_encodings, 60 | box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND: 61 | class_predictions_with_background} 62 | 63 | 64 | class MockAnchorGenerator(anchor_generator.AnchorGenerator): 65 | """Mock anchor generator.""" 66 | 67 | def name_scope(self): 68 | return 'MockAnchorGenerator' 69 | 70 | def num_anchors_per_location(self): 71 | return [1] 72 | 73 | def _generate(self, feature_map_shape_list): 74 | num_anchors = sum([shape[0] * shape[1] for shape in feature_map_shape_list]) 75 | return box_list.BoxList(tf.zeros((num_anchors, 4), dtype=tf.float32)) 76 | 77 | 78 | class MockMatcher(matcher.Matcher): 79 | """Simple matcher that matches first anchor to first groundtruth box.""" 80 | 81 | def _match(self, similarity_matrix): 82 | return tf.constant([0, -1, -1, -1], dtype=tf.int32) 83 | 84 | 85 | def create_diagonal_gradient_image(height, width, depth): 86 | """Creates pyramid image. Useful for testing. 87 | 88 | For example, pyramid_image(5, 6, 1) looks like: 89 | # [[[ 5. 4. 3. 2. 1. 0.] 90 | # [ 6. 5. 4. 3. 2. 1.] 91 | # [ 7. 6. 5. 4. 3. 2.] 92 | # [ 8. 7. 6. 5. 4. 3.] 93 | # [ 9. 8. 7. 6. 5. 4.]]] 94 | 95 | Args: 96 | height: height of image 97 | width: width of image 98 | depth: depth of image 99 | 100 | Returns: 101 | pyramid image 102 | """ 103 | row = np.arange(height) 104 | col = np.arange(width)[::-1] 105 | image_layer = np.expand_dims(row, 1) + col 106 | image_layer = np.expand_dims(image_layer, 2) 107 | 108 | image = image_layer 109 | for i in range(1, depth): 110 | image = np.concatenate((image, image_layer * pow(10, i)), 2) 111 | 112 | return image.astype(np.float32) 113 | 114 | 115 | def create_random_boxes(num_boxes, max_height, max_width): 116 | """Creates random bounding boxes of specific maximum height and width. 
117 | 118 | Args: 119 | num_boxes: number of boxes. 120 | max_height: maximum height of boxes. 121 | max_width: maximum width of boxes. 122 | 123 | Returns: 124 | boxes: numpy array of shape [num_boxes, 4]. Each row is in form 125 | [y_min, x_min, y_max, x_max]. 126 | """ 127 | 128 | y_1 = np.random.uniform(size=(1, num_boxes)) * max_height 129 | y_2 = np.random.uniform(size=(1, num_boxes)) * max_height 130 | x_1 = np.random.uniform(size=(1, num_boxes)) * max_width 131 | x_2 = np.random.uniform(size=(1, num_boxes)) * max_width 132 | 133 | boxes = np.zeros(shape=(num_boxes, 4)) 134 | boxes[:, 0] = np.minimum(y_1, y_2) 135 | boxes[:, 1] = np.minimum(x_1, x_2) 136 | boxes[:, 2] = np.maximum(y_1, y_2) 137 | boxes[:, 3] = np.maximum(x_1, x_2) 138 | 139 | return boxes.astype(np.float32) 140 | -------------------------------------------------------------------------------- /utils/test_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.test_utils.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import test_utils 22 | 23 | 24 | class TestUtilsTest(tf.test.TestCase): 25 | 26 | def test_diagonal_gradient_image(self): 27 | """Tests if a good pyramid image is created.""" 28 | pyramid_image = test_utils.create_diagonal_gradient_image(3, 4, 2) 29 | 30 | # Test which is easy to understand. 31 | expected_first_channel = np.array([[3, 2, 1, 0], 32 | [4, 3, 2, 1], 33 | [5, 4, 3, 2]], dtype=np.float32) 34 | self.assertAllEqual(np.squeeze(pyramid_image[:, :, 0]), 35 | expected_first_channel) 36 | 37 | # Actual test. 
38 | expected_image = np.array([[[3, 30], 39 | [2, 20], 40 | [1, 10], 41 | [0, 0]], 42 | [[4, 40], 43 | [3, 30], 44 | [2, 20], 45 | [1, 10]], 46 | [[5, 50], 47 | [4, 40], 48 | [3, 30], 49 | [2, 20]]], dtype=np.float32) 50 | 51 | self.assertAllEqual(pyramid_image, expected_image) 52 | 53 | def test_random_boxes(self): 54 | """Tests if valid random boxes are created.""" 55 | num_boxes = 1000 56 | max_height = 3 57 | max_width = 5 58 | boxes = test_utils.create_random_boxes(num_boxes, 59 | max_height, 60 | max_width) 61 | 62 | true_column = np.ones(shape=(num_boxes)) == 1 63 | self.assertAllEqual(boxes[:, 0] < boxes[:, 2], true_column) 64 | self.assertAllEqual(boxes[:, 1] < boxes[:, 3], true_column) 65 | 66 | self.assertTrue(boxes[:, 0].min() >= 0) 67 | self.assertTrue(boxes[:, 1].min() >= 0) 68 | self.assertTrue(boxes[:, 2].max() <= max_height) 69 | self.assertTrue(boxes[:, 3].max() <= max_width) 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /utils/variables_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions for manipulating collections of variables during training. 17 | """ 18 | import logging 19 | import re 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | # TODO: Consider replacing with tf.contrib.filter_variables in 27 | # tensorflow/contrib/framework/python/ops/variables.py 28 | def filter_variables(variables, filter_regex_list, invert=False): 29 | """Filters out the variables matching the filter_regex. 30 | 31 | Filter out the variables whose name matches the any of the regular 32 | expressions in filter_regex_list and returns the remaining variables. 33 | Optionally, if invert=True, the complement set is returned. 34 | 35 | Args: 36 | variables: a list of tensorflow variables. 37 | filter_regex_list: a list of string regular expressions. 38 | invert: (boolean). If True, returns the complement of the filter set; that 39 | is, all variables matching filter_regex are kept and all others discarded. 40 | 41 | Returns: 42 | a list of filtered variables. 43 | """ 44 | kept_vars = [] 45 | variables_to_ignore_patterns = filter(None, filter_regex_list) 46 | for var in variables: 47 | add = True 48 | for pattern in variables_to_ignore_patterns: 49 | if re.match(pattern, var.op.name): 50 | add = False 51 | break 52 | if add != invert: 53 | kept_vars.append(var) 54 | return kept_vars 55 | 56 | 57 | def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier): 58 | """Multiply gradients whose variable names match a regular expression. 59 | 60 | Args: 61 | grads_and_vars: A list of gradient to variable pairs (tuples). 
62 | regex_list: A list of string regular expressions. 63 | multiplier: A (float) multiplier to apply to each gradient matching the 64 | regular expression. 65 | 66 | Returns: 67 | grads_and_vars: A list of gradient to variable pairs (tuples). 68 | """ 69 | variables = [pair[1] for pair in grads_and_vars] 70 | matching_vars = filter_variables(variables, regex_list, invert=True) 71 | for var in matching_vars: 72 | logging.info('Applying multiplier %f to variable [%s]', 73 | multiplier, var.op.name) 74 | grad_multipliers = {var: float(multiplier) for var in matching_vars} 75 | return slim.learning.multiply_gradients(grads_and_vars, 76 | grad_multipliers) 77 | 78 | 79 | def freeze_gradients_matching_regex(grads_and_vars, regex_list): 80 | """Freeze gradients whose variable names match a regular expression. 81 | 82 | Args: 83 | grads_and_vars: A list of gradient to variable pairs (tuples). 84 | regex_list: A list of string regular expressions. 85 | 86 | Returns: 87 | grads_and_vars: A list of gradient to variable pairs (tuples) that do not 88 | contain the variables and gradients matching the regex. 89 | """ 90 | variables = [pair[1] for pair in grads_and_vars] 91 | matching_vars = filter_variables(variables, regex_list, invert=True) 92 | kept_grads_and_vars = [pair for pair in grads_and_vars 93 | if pair[1] not in matching_vars] 94 | for var in matching_vars: 95 | logging.info('Freezing variable [%s]', var.op.name) 96 | return kept_grads_and_vars 97 | 98 | 99 | def get_variables_available_in_checkpoint(variables, checkpoint_path): 100 | """Returns the subset of variables available in the checkpoint. 101 | 102 | Inspects given checkpoint and returns the subset of variables that are 103 | available in it. 104 | 105 | TODO: force input and output to be a dictionary. 106 | 107 | Args: 108 | variables: a list or dictionary of variables to find in checkpoint. 109 | checkpoint_path: path to the checkpoint to restore variables from. 110 | 111 | Returns: 112 | A list or dictionary of variables. 113 | Raises: 114 | ValueError: if `variables` is not a list or dict. 115 | """ 116 | if isinstance(variables, list): 117 | variable_names_map = {variable.op.name: variable for variable in variables} 118 | elif isinstance(variables, dict): 119 | variable_names_map = variables 120 | else: 121 | raise ValueError('`variables` is expected to be a list or dict.') 122 | ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path) 123 | ckpt_vars = ckpt_reader.get_variable_to_shape_map().keys() 124 | vars_in_ckpt = {} 125 | for variable_name, variable in sorted(variable_names_map.items()): 126 | if variable_name in ckpt_vars: 127 | vars_in_ckpt[variable_name] = variable 128 | else: 129 | logging.warning('Variable [%s] not available in checkpoint', 130 | variable_name) 131 | if isinstance(variables, list): 132 | return vars_in_ckpt.values() 133 | return vars_in_ckpt 134 | --------------------------------------------------------------------------------
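A minimal usage sketch for the NumPy box utilities listed above (illustrative only, not a file in the repository; it assumes the repository root is on the Python path so that utils imports as a package, and that numpy is installed):

import numpy as np
from utils import np_box_ops

# Boxes are rows of [y_min, x_min, y_max, x_max], matching the convention in np_box_list.py.
boxes1 = np.array([[0.0, 0.0, 2.0, 2.0]], dtype=float)                        # one 2x2 box
boxes2 = np.array([[1.0, 1.0, 3.0, 3.0], [4.0, 4.0, 5.0, 5.0]], dtype=float)  # one overlapping, one disjoint box

print(np_box_ops.area(boxes1))         # [4.]
print(np_box_ops.iou(boxes1, boxes2))  # approximately [[0.1429 0.]]

iou() divides each pairwise intersection area by the corresponding union area, so identical boxes score 1.0 and disjoint boxes score 0.0; the 1/7 value above comes from an intersection of 1.0 over a union of 4.0 + 4.0 - 1.0.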