├── .gitignore ├── models_hash └── model_hash.json ├── src └── main │ ├── inference │ ├── __init__.py │ ├── ConfigurationSchema.json │ ├── inference_engines_factory.py │ ├── base_error.py │ ├── errors.py │ ├── exceptions.py │ ├── base_inference_engine.py │ └── tensorflow_detection.py │ ├── .gitignore │ ├── object_detection │ ├── core │ │ ├── __init__.py │ │ ├── box_list.py │ │ └── standard_fields.py │ ├── image1.jpg │ ├── utils │ │ ├── static_shape.py │ │ ├── label_map_util.py │ │ ├── shape_utils.py │ │ ├── visualization_utils.py │ │ └── ops.py │ └── protos │ │ └── string_int_label_map_pb2.py │ ├── fonts │ └── DejaVuSans.ttf │ ├── models.py │ ├── ocr.py │ ├── deep_learning_service.py │ └── start.py ├── docs ├── 1.gif ├── 2.gif ├── 3.gif ├── 4.gif ├── 5.gif ├── tcpu.png ├── tcpu2.png ├── TCPU20req.png ├── TCPU40req.png ├── nvidia-smi.gif ├── swagger_endpoints.png └── uml │ ├── InferenceClassDiagram.png │ ├── InferenceSequenceDiagram.png │ ├── InferenceSequenceDiagram.xml │ └── InferenceClassDiagram.drawio ├── models └── .gitignore ├── docker ├── requirements.txt └── dockerfile ├── cpu-inference.yaml ├── install_prerequisites.sh ├── README-docker_swarm.md ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | -------------------------------------------------------------------------------- /models_hash/model_hash.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /src/main/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/main/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.log 3 | -------------------------------------------------------------------------------- /src/main/object_detection/core/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/1.gif -------------------------------------------------------------------------------- /docs/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/2.gif -------------------------------------------------------------------------------- /docs/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/3.gif -------------------------------------------------------------------------------- /docs/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/4.gif -------------------------------------------------------------------------------- /docs/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/5.gif 
-------------------------------------------------------------------------------- /docs/tcpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/tcpu.png -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | 6 | -------------------------------------------------------------------------------- /docs/tcpu2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/tcpu2.png -------------------------------------------------------------------------------- /docs/TCPU20req.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/TCPU20req.png -------------------------------------------------------------------------------- /docs/TCPU40req.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/TCPU40req.png -------------------------------------------------------------------------------- /docs/nvidia-smi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/nvidia-smi.gif -------------------------------------------------------------------------------- /docs/swagger_endpoints.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/swagger_endpoints.png -------------------------------------------------------------------------------- /src/main/fonts/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/src/main/fonts/DejaVuSans.ttf -------------------------------------------------------------------------------- /docs/uml/InferenceClassDiagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/uml/InferenceClassDiagram.png -------------------------------------------------------------------------------- /src/main/object_detection/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/src/main/object_detection/image1.jpg -------------------------------------------------------------------------------- /docs/uml/InferenceSequenceDiagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/uml/InferenceSequenceDiagram.png -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles 2 | celery 3 | fastapi 4 | 
h5py 5 | matplotlib 6 | numpy 7 | opencv-python 8 | python-multipart 9 | pandas 10 | Pillow 11 | python-socketio 12 | requests 13 | scipy 14 | sklearn 15 | socketIO-client-nexus 16 | tensorflow==1.13.1 17 | uvicorn 18 | jsonschema 19 | pytz 20 | pytesseract 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docker/dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | 3 | LABEL maintainer="antoine.charbel@inmind.ai" 4 | 5 | COPY docker/requirements.txt . 6 | COPY src/main /main 7 | 8 | RUN apt-get update && apt-get install -y tesseract-ocr 9 | 10 | RUN pip install -r requirements.txt 11 | 12 | WORKDIR /main 13 | 14 | CMD ["uvicorn", "start:app", "--host", "0.0.0.0", "--port", "4343"] 15 | -------------------------------------------------------------------------------- /cpu-inference.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | api: 5 | ports: 6 | - "4343:4343" 7 | image: tensorflow_inference_api_cpu 8 | volumes: 9 | - "/mnt/models:/models" 10 | deploy: 11 | replicas: 1 12 | update_config: 13 | parallelism: 2 14 | delay: 10s 15 | restart_policy: 16 | condition: on-failure 17 | -------------------------------------------------------------------------------- /src/main/models.py: -------------------------------------------------------------------------------- 1 | class ApiResponse: 2 | 3 | def __init__(self, success=True, data=None, error=None): 4 | """ 5 | Defines the response shape. 6 | :param success: A boolean indicating whether the request succeeded 7 | :param data: The model's response 8 | :param error: The error message in case an exception was raised 9 | """ 10 | self.data = data 11 | self.error = error.__str__() if error is not None else '' 12 | self.success = success 13 | -------------------------------------------------------------------------------- /install_prerequisites.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This will install docker following [https://docs.docker.com/install/linux/docker-ce/ubuntu/] 4 | sudo apt-get remove docker docker-engine docker.io 5 | sudo apt-get update 6 | 7 | sudo apt-get install \ 8 | apt-transport-https \ 9 | ca-certificates \ 10 | curl \ 11 | software-properties-common 12 | 13 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 14 | sudo apt-key fingerprint 0EBFCD88 15 | 16 | sudo add-apt-repository \ 17 | "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ 18 | $(lsb_release -cs) \ 19 | stable" 20 | 21 | sudo apt-get update 22 | sudo apt-get install -y docker-ce 23 | sudo groupadd docker 24 | sudo usermod -aG docker ${USER} 25 | docker run hello-world 26 | -------------------------------------------------------------------------------- /src/main/inference/ConfigurationSchema.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "properties": { 4 | "inference_engine_name": { 5 | "type": "string" 6 | }, 7 | "confidence": { 8 | "type": "number", 9 | "minimum": 0, 10 | "maximum": 100 11 | }, 12 | "predictions": { 13 | "type": "number", 14 | "minimum": 0 15 | }, 16 | "number_of_classes": { 17 | "type": "number" 18 | }, 19 | "framework": { 20 | "type": "string" 21 | }, 22 | "type": { 23 | "type": "string" 24 | }, 25 | "network": { 26 | "type": "string" 27 | } 28 | }, 29 | "required": [
30 | "inference_engine_name", 31 | "confidence", 32 | "predictions", 33 | "number_of_classes", 34 | "framework", 35 | "type", 36 | "network" 37 | ] 38 | } -------------------------------------------------------------------------------- /src/main/inference/inference_engines_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from inference.exceptions import ModelNotFound, ApplicationError, InvalidModelConfiguration, InferenceEngineNotFound, ModelNotLoaded 4 | 5 | 6 | class InferenceEngineFactory: 7 | 8 | @staticmethod 9 | def get_engine(path_to_model): 10 | """ 11 | Reads the model's inference engine from the model's configuration and calls the right inference engine class. 12 | :param path_to_model: Model's path 13 | :return: The model's instance 14 | """ 15 | if not os.path.exists(path_to_model): 16 | raise ModelNotFound() 17 | try: 18 | configuration = json.loads(open(os.path.join(path_to_model, 'config.json')).read()) 19 | except Exception: 20 | raise InvalidModelConfiguration('config.json not found or corrupted') 21 | try: 22 | inference_engine_name = configuration['inference_engine_name'] 23 | except Exception: 24 | raise InvalidModelConfiguration('missing inference engine name in config.json') 25 | try: 26 | # import one of the available inference engine class (in this project there's only one), and return a 27 | # model instance 28 | return getattr(__import__(inference_engine_name), 'InferenceEngine')(path_to_model) 29 | except ApplicationError as e: 30 | print(e) 31 | raise e 32 | except Exception as e: 33 | print(e) 34 | raise InferenceEngineNotFound(inference_engine_name) 35 | -------------------------------------------------------------------------------- /src/main/inference/base_error.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class AbstractError(ABC): 7 | 8 | def __init__(self): 9 | """ 10 | Sets the logger file, level, and format. 11 | The logging file will contain the logging level, request date, request status, and model response. 12 | """ 13 | self.logger = logging.getLogger('logger') 14 | date = datetime.now().strftime('%Y-%m-%d') 15 | file_path = 'logs/' + date + '.log' 16 | self.handler = logging.FileHandler(file_path) 17 | self.handler.setLevel(logging.INFO) 18 | self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s")) 19 | self.logger.addHandler(self.handler) 20 | 21 | @abstractmethod 22 | def info(self, message): 23 | """ 24 | Logs an info message to the logging file. 25 | :param message: Containing the request status and the model response 26 | :return: 27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def warning(self, message): 32 | """ 33 | Logs a warning message to the logging file. 34 | :param message: Containing the request status and the model response 35 | :return: 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def error(self, message): 41 | """ 42 | Logs an Error message to the logging file. 
-------------------------------------------------------------------------------- /src/main/inference/base_error.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class AbstractError(ABC): 7 | 8 | def __init__(self): 9 | """ 10 | Sets the logger file, level, and format. 11 | The logging file will contain the logging level, request date, request status, and model response. 12 | """ 13 | self.logger = logging.getLogger('logger') 14 | date = datetime.now().strftime('%Y-%m-%d') 15 | file_path = 'logs/' + date + '.log' 16 | self.handler = logging.FileHandler(file_path) 17 | self.handler.setLevel(logging.INFO) 18 | self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s")) 19 | self.logger.addHandler(self.handler) 20 | 21 | @abstractmethod 22 | def info(self, message): 23 | """ 24 | Logs an info message to the logging file. 25 | :param message: Containing the request status and the model response 26 | :return: 27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def warning(self, message): 32 | """ 33 | Logs a warning message to the logging file. 34 | :param message: Containing the request status and the model response 35 | :return: 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def error(self, message): 41 | """ 42 | Logs an error message to the logging file. 43 | :param message: Containing the request status and the model response 44 | :return: 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /src/main/inference/errors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import datetime 4 | from inference.base_error import AbstractError 5 | 6 | 7 | class Error(AbstractError): 8 | 9 | def __init__(self): 10 | if 'logs' not in os.listdir(): 11 | os.mkdir('logs') 12 | self.date = None 13 | super().__init__() 14 | 15 | def info(self, message): 16 | self.check_date() 17 | self.logger.info(message) 18 | 19 | def warning(self, message): 20 | self.check_date() 21 | self.logger.warning(message) 22 | 23 | def error(self, message): 24 | self.check_date() 25 | self.logger.error(message) 26 | 27 | def check_date(self): 28 | """ 29 | Divides logging per day. Each logging file corresponds to a specific day. 30 | It also removes logging files older than one year. 31 | :return: 32 | """ 33 | self.date = datetime.now().strftime('%Y-%m-%d') 34 | file_path = self.date + '.log' 35 | if file_path not in os.listdir('logs'): 36 | self.logger.removeHandler(self.handler) 37 | self.handler = logging.FileHandler('logs/' + file_path) 38 | self.handler.setLevel(logging.INFO) 39 | self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s")) 40 | self.logger.addHandler(self.handler) 41 | oldest_log_file = sorted(os.listdir('logs'))[0]  # listdir order is arbitrary; sort to find the oldest file 42 | oldest_date = oldest_log_file.split('.')[0] 43 | a = datetime.strptime(datetime.now().strftime('%Y-%m-%d'), '%Y-%m-%d') 44 | b = datetime.strptime(oldest_date, '%Y-%m-%d') 45 | delta = a - b 46 | if delta.days > 365: 47 | os.remove('logs/' + oldest_log_file) 48 | -------------------------------------------------------------------------------- /src/main/inference/exceptions.py: -------------------------------------------------------------------------------- 1 | __metaclass__ = type 2 | 3 | 4 | class ApplicationError(Exception): 5 | """Base class for other exceptions""" 6 | 7 | def __init__(self, default_message, additional_message=''): 8 | self.default_message = default_message 9 | self.additional_message = additional_message 10 | 11 | def __str__(self): 12 | return self.get_message() 13 | 14 | def get_message(self): 15 | return self.default_message if self.additional_message == '' else "{}: {}".format(self.default_message, 16 | self.additional_message) 17 | 18 | 19 | class InvalidModelConfiguration(ApplicationError): 20 | """Raised when the model's configuration is corrupted""" 21 | 22 | def __init__(self, additional_message=''): 23 | # super('Invalid model configuration', additional_message) 24 | super().__init__('Invalid model configuration', additional_message) 25 | 26 | 27 | class ModelNotFound(ApplicationError): 28 | """Raised when the model is not found""" 29 | 30 | def __init__(self, additional_message=''): 31 | # super('Model not found', additional_message) 32 | super().__init__('Model not found', additional_message) 33 | 34 | 35 | class ModelNotLoaded(ApplicationError): 36 | """Raised when the model is not loaded""" 37 | 38 | def __init__(self, additional_message=''): 39 | # super('Error loading model', additional_message) 40 | super().__init__('Error loading model', additional_message) 41 | 42 | 43 | class InvalidInputData(ApplicationError): 44 | """Raised when the input data is corrupted""" 45 | 46 | def __init__(self, additional_message=''): 47 | # super('Invalid input
data', additional_message) 48 | super().__init__('Invalid input data', additional_message) 49 | 50 | 51 | class InferenceEngineNotFound(ApplicationError): 52 | """Raised when the Inference Engine is not found""" 53 | 54 | def __init__(self, additional_message=''): 55 | # super('Inference engine not found', additional_message) 56 | super().__init__('Inference engine not found', additional_message) 57 | -------------------------------------------------------------------------------- /src/main/object_detection/utils/static_shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions to access TensorShape values. 17 | 18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 19 | """ 20 | 21 | 22 | def get_batch_size(tensor_shape): 23 | """Returns batch size from the tensor shape. 24 | 25 | Args: 26 | tensor_shape: A rank 4 TensorShape. 27 | 28 | Returns: 29 | An integer representing the batch size of the tensor. 30 | """ 31 | tensor_shape.assert_has_rank(rank=4) 32 | return tensor_shape[0].value 33 | 34 | 35 | def get_height(tensor_shape): 36 | """Returns height from the tensor shape. 37 | 38 | Args: 39 | tensor_shape: A rank 4 TensorShape. 40 | 41 | Returns: 42 | An integer representing the height of the tensor. 43 | """ 44 | tensor_shape.assert_has_rank(rank=4) 45 | return tensor_shape[1].value 46 | 47 | 48 | def get_width(tensor_shape): 49 | """Returns width from the tensor shape. 50 | 51 | Args: 52 | tensor_shape: A rank 4 TensorShape. 53 | 54 | Returns: 55 | An integer representing the width of the tensor. 56 | """ 57 | tensor_shape.assert_has_rank(rank=4) 58 | return tensor_shape[2].value 59 | 60 | 61 | def get_depth(tensor_shape): 62 | """Returns depth from the tensor shape. 63 | 64 | Args: 65 | tensor_shape: A rank 4 TensorShape. 66 | 67 | Returns: 68 | An integer representing the depth of the tensor. 69 | """ 70 | tensor_shape.assert_has_rank(rank=4) 71 | return tensor_shape[3].value 72 | -------------------------------------------------------------------------------- /src/main/inference/base_inference_engine.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from inference.exceptions import InvalidModelConfiguration, ModelNotLoaded, ApplicationError 3 | 4 | 5 | class AbstractInferenceEngine(ABC): 6 | 7 | def __init__(self, model_path): 8 | """ 9 | Takes a model path and calls the load function. 
10 | :param model_path: The model's path 11 | :return: 12 | """ 13 | self.labels = [] 14 | self.configuration = {} 15 | self.model_path = model_path 16 | try: 17 | self.validate_configuration() 18 | except ApplicationError as e: 19 | raise e 20 | try: 21 | self.load() 22 | except ApplicationError as e: 23 | raise e 24 | except Exception as e: 25 | print(e) 26 | raise ModelNotLoaded() 27 | 28 | @abstractmethod 29 | def load(self): 30 | """ 31 | Loads the model based on the underlying implementation. 32 | """ 33 | pass 34 | 35 | @abstractmethod 36 | def free(self): 37 | """ 38 | Performs any manual memory cleanup required when unloading a model. 39 | Will be called when the class's destructor is called. 40 | """ 41 | pass 42 | 43 | @abstractmethod 44 | async def infer(self, input_data, draw, predict_batch): 45 | """ 46 | Performs the required inference based on the underlying implementation of this class. 47 | Could be used to return classification predictions, object detection coordinates... 48 | :param predict_batch: Boolean 49 | :param input_data: A single image 50 | :param draw: Used to draw bounding boxes on the image instead of returning them 51 | :return: A bounding-box 52 | """ 53 | pass 54 | 55 | @abstractmethod 56 | async def run_batch(self, input_data, draw, predict_batch): 57 | """ 58 | Iterates over images and returns a prediction for each one. 59 | :param predict_batch: Boolean 60 | :param input_data: List of images 61 | :param draw: Used to draw bounding boxes on the image instead of returning them 62 | :return: List of bounding-boxes 63 | """ 64 | pass 65 | 66 | @abstractmethod 67 | def validate_configuration(self): 68 | """ 69 | Validates that the model and its files are valid based on the underlying implementation's requirements. 70 | Can check for configuration values, folder structure... 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def set_model_configuration(self, data): 76 | """ 77 | Takes the configuration from the config.json file 78 | :param data: Json data 79 | :return: 80 | """ 81 | pass 82 | 83 | @abstractmethod 84 | def validate_json_configuration(self, data): 85 | """ 86 | Validates the configuration of the config.json file.
87 | :param data: Json data 88 | :return: 89 | """ 90 | pass 91 | 92 | def __del__(self): 93 | self.free() 94 | -------------------------------------------------------------------------------- /docs/uml/InferenceSequenceDiagram.xml: -------------------------------------------------------------------------------- 1 | 7VxZd6M2FP41fkyOxM5jlknb02nPdDKnnT5ikG06GLkCZ+mvrwRik4SNEwROTvIwAxchgfTdT3fDC/Nm+/QTCXab33CEkoUBoqeFebswDAgNn/7HJM+lxPZgKViTOOKNGsF9/B/iQsCl+zhCWadhjnGSx7uuMMRpisK8IwsIwY/dZiucdEfdBWskCe7DIJGlf8VRvqneC4Dmws8oXm/40J7NLyyD8Mea4H3Kx1sY5qr4Ky9vg6ov3j7bBBF+bInMTwvzhmCcl0fbpxuUsLmtpq28767nav3cBKX5kBvcZbAMYGgtQ8NdWUtwYZQ9PATJns/Fwry6RWj3GQUkjdP1PSIPcYj44+fP1ZTRN9mxw/02+RyvUBKn9Ox6h0i8RTki9ErCxV8a2fXjJs7R/S4I2a2PFExUtsm3CT2D9JCubx7QW0h9niTBLouXxaiASggK9ySLH9BXlJUwYlK8z9lINzU8iqZsWVDEu6pnHhT9buOQHyfBEiXX9Tre4ASz4VNcvFCWE/wDVUK6vKD4q69UcGFDrOIkabW8K/6YnL7VXbCNE6YdfyISBWnAxVwVoMHPVQMFSbxOqSyka1xMorzoHAcPiOToqSXiIPgJYboA5Jk24VcthwOSK6zHTx9b6K8wvmkB37G4MOAat667blBHDzjwBoKwooE2CkW8tVCyw3GaF+Pb1wv7VoAdJvkGr3EaJG3gNWAA7x0MvTo+GB1uFxw1DbbRoQCHMQY4vv/x5Sr+zc5urn7/+rP/a/Trt1+eL0xbO0AmmEfLswbNIwSOBiUzVUz/S7pC9I1D9CldUwK9C8Ic83E+yP5M9ftksnfBcbJ3FDg0XS1krwDiB9mPRvbmyeiwu2QP7WFkX7UbneyhJeEBRdRc56fNkn9qpMKSN20+Y7zji/QPyvNnPu3BPsddNho+tRnekxAdQXcekDXq66hqxF7r4BIRlAQ55b+O4qsmvLj1ipDgudWAq07T8xcmaFYeAmHpXcF5ENr7XRoRfA16UD5As/L1m7wCDKqd30nozF5H8QM9XLPDryiIKjEdpnVF0ZjNUHpS83CDcYbYqm7Yv1WjJRFvO9pZiAlhmxZd+2LD7btRwD9V5bwLV8J2xKDZKmWalMiDUUJMfd4rfmEbR1GhOz082PiyKu04qLoS+9SuO3/kRdv9VbESuORdnaYCEmatLsQvTLPbA16tMqqnImedhlxJwT2VrUUNFrpxJAldjA/76oy31JPtK8P3jtpXYDJn2v8wrzSaV96p4BAIyBtoXdm6rCvT0o6PCaaxjl8dc6X1xKugNId3mDwGJKptBIL+3aMsX9SBX+XsilvyEuc53tILKI2uWFybyRIc/jiqTgNVoW1Io2SJH9s2dCGgF6qnKp6DztV3rtDFyd/s5BLQfZQLbp/al2+f22ctVBTCXrI9bC4DLeayZCuYttm1hz330vF8x7aha1jQrMI3VY+lI8A7OWA324YQT/O8bkfl20sdjWGJQFndb4LilhKnlRkKqD/EJpDZ0eCfDKelybaK12PhtwdKcyG7F4qH3TswBK963DsJVy4UHDIABgF0NN9M3jlmcdQVC6d8XnciGrFEdReWpUfdT3XfTYlWDrvvR9pr8t/laN89Yt43SHDBNfVG+VZd3r6swMku7wW4hIZrdNbognf1Skgal4IbDIUMwyhu8CGlO2Ru9ni/AVu7eBcU00+vXbHcxHxusYilnPHVdUbHidP1t4K8TGcYtJQedi+vvTjf5Mo2sqlyQ+0RbGTl2ivDIELKaSgaPmIhbyIW4nldDKpiIYauWIgShB/xkBHxcFDNB4PE7u5Gdd3AEWdetKPGi4koiEp7TETLVFr+zFMJfWnmJvAJat8SXgKjjkyUoQrXMsYJVRxasKO+B58WwffoOpL9N47vOnZZ2hQ9x5FcFE9Q9CMJRrG57wt41OCgWHBmwEIPioD1zwCwFSeejtjyTt2QtUWrVQ9kbfMkyFZPpdendmaFLPVTgWWL8WB31Hjw2fCksL6m4Xa7GBoD9oXgi2kIieiR4Cs/sDkBIGUnnweZeZBntU/DPGZB5bca5XF6DLMXRXl8o5touBin1qF6yMr16t6vL8RTQbu1/J/LdS9zDNvye5Exlp6yEvfOYQsJCVrlwyI0dKZHcXN9wc01lflIVWJXl+FtzG3HAMfr7gm+PZLhPTS+78tbxbmW3nliaTiYwG4wVG6uopYu35N0ISUI4zTLgzREr6h+K/SDqscJ7cuHYNsIihoqeVdFc73s84KiOWjCbi1S9dHCK7cW6te7Qr/TbS/zmruU2aDgofljeWiDUs6DKor1lEicSmuO3UWfBbwJaM2dGx+2K+DDnxIfQ0oS1OGnqeFhCR+iWIcqyOX8taO+W0MZjRpnciah3ivv92GIsux9GbkWgELe2FSGl6c0c6sE5hnWnPQp3dHQhZ66lJO/BrEPKuep7fVwvSlXXtaqB4pz8I3s5aTym9ZDKEbrVVqorCLWpoXGrFsui6GYVmfPnTn6WFZVjx5+9CyBgG0xTT40/igWpUkdjRR/NDz1OH3PdaS9JhbpLUrbERTFxfdZb70urVLRj7q0nvmZxascZEgc/G5Dd/VqlfCqFVL8CHwkougbR6/iy4mK3g833pMFYXpO30ZyrDBLnw0xb6VIwWpO129vCaaJWFe/a3I2SXZf1EpNteuOaI5MUYtejSmnKZttv8lUyo7mm9n6q0DFGFs/cJxu5HecXKUQpa6q6PXv+5YhgeBLufpjrvq5EL+Yp3SgTPvulKxvzfLjHu1yK1sI5nvQmpj0vaEcr94dpg4M+QJXO4dqnGQ7zxXvFp5Pc9TWkn9ApI7a7tqKT7V6h9PsncWOHHH6TYXdp/rxJ30MMHe6BroiA9gvStegpzjv1sNRwd/FEACY/LzpkJ08t07GtyIHRqlmJxTR9XOcwzEiq9qz1e2PpY18cbSJCUiuuqgJaFzOmcHUrJR5jNoFAzhijcEopiYUY0x1VmsCa3MWJ/cVVGINpJLzyFmJ0WPnyAcGR9rrcTtt2e2Ug07v0faAPYZjx/ZQfZH6AtuDnja/ol2uXPNT5ean/wE= -------------------------------------------------------------------------------- /docs/uml/InferenceClassDiagram.drawio: 
-------------------------------------------------------------------------------- 1 | 7V1bk6O2Ev41rjp58BZ348fxXHZzMptsZU7tyebFJYNsk8WICDyeya+PBBIGJF8mHiObaGtrFzWSkLo/Wt2tFh7Yt6uXjxiky88ohPHAMsKXgX03sMgfwyX/UcprSTF9xyspCxyFjLYlPEV/QUY0GHUdhTBrVMwRivMobRIDlCQwyBs0gDHaNKvNUdx8agoWUCA8BSAWqf+PwnxZUn3X2NI/wWix5E82DXZnBoLvC4zWCXvewLLnxZ/y9grwvlj9bAlCtKmR7PuBfYsRysur1cstjClzOdvKdg877lbjxjDJj2owAsEMuCPHh6499mZD1sMziNeMFz8mc0i6CyAbcf7KuUQGn9LL9Sp+wGBFLiebZZTDpxQElL4h6CC0Zb6KSckklwVnIH22QUrV5GkhQKsoYNcxmMF4UrHyFsUIk1sJSugzshyj75VcaLdzlOQPYBXFFG9fIQ5BAhiZYcuk3YI4WiSkEBDmQNLhROQWnz7EOXypkRj3PkK0gjl+JVXY3aHJMcuwPnTHjLDZIsesBL6swWbMmwIG10XV/VZk5IJJ7UgJ8ifVRHhLGIFRHJNJCzLcRKsYFHytyangHKtUCGcZxeEjeEVrOvQsJ6LhpckS4egvUh/wxuQ25ny3PNpbFMctGW4bPdHO2GMwzEizL1weZkV6BFle4SSOQZpFs2JwtMoK4EWUTFCeoxVHFpvVQ+3J21exjzh0Wyi0JCi0DAkILecsIBT1SBhPM4ifI6oabsitOwjTRwhwEiWLJ3ajjU4y+byJzJLpLUlIAMZ5HMM57YEyMiL6/YaRc0Q1U0YUFXn6Y1HnztlSfmX8cXaoNET6m8cFZpZRGMKkwFQOcjCr3pkURUle8NSdkL+Ey7fGB3fgknndkrK5LZO/tDrOyXtK5geiAgWQYH4DKe5l+Nj51h9GDIOI5R0JEO8c+HAEfECybuNpjBbkZV6UELmnJA2KDkHhWt2B4jl5fUx+h/Dr0Ln5/BO+T3/xbiX2x8CaxAiE//mhxMRNGv0Ks5QwRauLLpEx8hUjw5IiI8ry6Yo6IJkGiFqAmIajGCG2DCF4nZQA0fhQjA/7WIv0XPgQbY46PqYzkAdLjRLFKHE7NEulKHH3rzPTwinUq41qnPiqLVXvAE4ClMyjhcaJYi/XUG23jgScHBP50HG5fsTlbIvpqbcH5t4jOixFpC8gcli6UNMwopspRXSOXmlFdaSi2vviX2RETjrisQiMGcggW9DCCJfYeMoxDdFpdHSKDvWhOXFjicXmmg72BKEYgkTjo1t8KA/QmfLYrRiheyQkjY5u0aE+OmdKA7jMpyZKBFYhfr2+KEGI8vicKQ3gLmDeirlo/aEAHcrjcqY0fEvR0Yy0aM9FATqUR+NMMRwnzx7QgZV+BFbcMdMHbw6sVAvd+4NQjPUNQ5BDppjolVZMpyim6i2/npiKKYm20SQnyIIpjyzhSQOjW2CoD6dIom1LkISxRoZaZCgPpEgyuImlGyxh8H1KlxNu6X5FpBuNjm7RoT6QYssCKV5MZ502sOD9uabnSJi9WGLGSF/IvwWDjJI+LARL7zm1exRGQ4YIeo+ZblWf5GrB/i+eHNUIYEUFHoulH2kfcwoLTm42afRIeBO1nzLDAoUTJiCDzOKv2s/atQktbdOWmDKNH0Hi8zN3T/UfMbl8qSout/sekhtSq2DXvGUzOcvA5EtSt4Mg2o++FMkcNVWfyBxab1PurIsVux1ukc8uG29zGC31Lde6LdVc06VUj7c8LqvtmH2C8TOkvQ7OfxjrdMfO9444ymKOZY6ddS7HzvcFQcFwAXnEgKxDS7RACYjvt9QWf7d1HhGVaSGCP2CevzK5gXWOmgKCL1H+G21O1ryy9K125+6F9VwUXnkhIfP9rV74tu2BFrfNitK2XXhDD05uJUsoDwW47gzpesoEm6E1DuAe3tksoS4n7w7ctxZ7zHumnN0LFAxjkEfPsDEOmcyLpmRa4LVWgRkZ256/UMIWf2brKBULbT3sqD50Tqo+dEZuC57lgLdgrWZ+gskgJjVWxzrvE6LS4AMIcsTeFB0u62G4rAm7I51f71wK1ZZmT9JgPizgyNfNFkq1t3OSt1OpgesJnznirs8hTGiV1Q+VZR+rpfxzaSlXuuWYYhTALCNaQG85vptqqt7z61FNrvSkSIjBZjqj7xaRBbl4od8x0dE6RSBRHuX3XSWeI/cCq4JqL9C1r8wLNPymxTxufUjnUP1uHDtXljKlY8Egg4KVeK1R4SL1rZ74pjwkXOZbrTF5/VBSt4A6HkeZO5oW39Sq542qiPbWP1GxLziN10nTaux6sGXfxAxIhIGVJ6BrQ2rW6pytRKlFdL9z2gRc9wPJqiRDdWOomPFHhpJjRqM3FoSMsbFgDMr3Ec7nUIphr30Hg3Ukox+RjFbwtdqnOpipeLYjoK48UxHoeMY7uaru9WUqupJMxWITW5/KUQgL5REMV8xTpCbROqBBUH3iUyU2lGcqemII9CZNYyIqapfqAxh9NmqEL1u4jnKzxhPt62EI52Ad59MVUVfFZ7/1YvZOCsu7vu1kT2L3gjCMqLoCscbIBWBEucHjiXYwy01h8NDn1JUCRLnVM7IEmXeYDcqvazt615QNOjhuG5B/LeL9tgFPE7mjUuTGVYvcOlLkPP6gfOe3adaaI7sOn0P1Tb+Dbd+RmqyCXmggfu7yIB7tC1NBo+uT+QfDbGWgeIcFT0pfII4I16ireu68FOdY7WRfFhrGekH6xxrgWCOEZ6krXpGqA1DVksQesWtJEhrY+5OX7JOqO43BnGfB4zuy16T8LgXuHMWHVZx7USqOf1NIy/x8Njc3hi5Owx0wukUNd0Bnue1TpW9uYLL0kfPqOVtj/gjM7/w9sWtJL7ZG7UPOlrsXjkKDTtKLfTVhjzNlr4vbZyfYkMZ1Ac429quzww1s/4DCdKwTG3QDaa9PkN4J4GvBpWM1QWAdWJaFcxZvbrAfZgcHyHcf+ARLTcEatSzet8H1/v5/P8P16M+R735bfJpN/vtsBZLfQ+Qaat+Gk7Df395KWkVhWCKbZjvs3ACiezy0O4JndhZir9Y8PnlA+L1ak+OxsatTkhqbOu+QOiBltHiQpheMtg3zwhgtnqntBaNFRPMPlKpitJib1QtGu8w13JrO3IlVxWgx6cj88OGYbMiL57XZVB6mJ0La6ZLTYvZOLyDdjr+aY1stosUMmF7wWdDRlqNYR4u51UZPVIfltwwPZ+yqVR6S3y/qC7Ndy24zW8y07JbZPXVcbKMVtLBMWTpal5y2+srp9qqonNO99REbjJaeTuySzf8SD1E5n3vqIJotPHuWYj6L/mFf7A6LfxiXp4Io9g8lP8DSF1a7fG67tUe3rO6pj+i2v/skO7LUJZ9FF7EXfB754wafR4pXQ8mPbvSCzw5LLGh9n1odn3tqdYyE3yNTHZa2JHHpPnDab39CvTKs35/TpIgR/R
RTde8jBunyM/3yGCH+DQ== -------------------------------------------------------------------------------- /src/main/ocr.py: -------------------------------------------------------------------------------- 1 | import pytesseract 2 | import unicodedata 3 | import re 4 | import numpy as np 5 | 6 | 7 | 8 | # Module-level constant: the key order used to serialize bounding boxes 9 | 10 | bounding_box_order = ["left", "top", "right", "bottom"] 11 | 12 | # Takes the model's bounding-box predictions and returns the text extracted inside each box 13 | def one_shot_ocr_service(image, output): 14 | # iterate over detections 15 | response = [] 16 | detections = output['bounding-boxes'] 17 | 18 | for i in range(0, len(detections)): 19 | 20 | # crop image for every detection: 21 | coordinates = (detections[i]["coordinates"]) 22 | cropped = image.crop((float(coordinates["left"]), float( 23 | coordinates["top"]), float(coordinates["right"]), float(coordinates["bottom"]))) 24 | 25 | # convert image to grayscale for better accuracy 26 | processed_img = cropped.convert('L') 27 | 28 | # extract text with positive confidence from cropped image 29 | df = pytesseract.image_to_data(processed_img, output_type='data.frame') 30 | valid_df = df[df["conf"] > 0] 31 | extracted_text = " ".join(valid_df["text"].values) 32 | 33 | # process text 34 | extracted_text = str(unicodedata.normalize('NFKD', extracted_text).encode('ascii', 'ignore').decode()).strip().replace("\n", " ").replace( 35 | "...", ".").replace("..", ".").replace('”', ' ').replace('“', ' ').replace("'", ' ').replace('\"', '').replace("alt/1m", "").strip() 36 | extracted_text = re.sub( 37 | '[^A-Za-z0-9.!?,;%:=()\[\]$€&/\- ]+', '', extracted_text) 38 | extracted_text = " ".join(extracted_text.split()) 39 | 40 | # wrap each prediction inside a dictionary 41 | if len(extracted_text) != 0:  # 'is not 0' would test identity, not equality 42 | prediction = dict() 43 | prediction["text"] = extracted_text 44 | bounding_box = [coordinates[el] for el in bounding_box_order] 45 | prediction["box"] = bounding_box 46 | prediction["score"] = valid_df["conf"].mean()/100.0 47 | 48 | response.append(prediction) 49 | 50 | return response 51 | 52 | # Takes an image and returns the text extracted from the whole image 53 | def ocr_service(image): 54 | # convert image to grayscale for better accuracy 55 | processed_img = image.convert('L') 56 | 57 | # Get data including boxes, confidences, line and page numbers 58 | df = pytesseract.image_to_data(processed_img, output_type='data.frame') 59 | valid_df = df[df["conf"] > 0] 60 | 61 | # process text 62 | extracted_text = " ".join(valid_df["text"].values) 63 | extracted_text = str(unicodedata.normalize('NFKD', extracted_text).encode('ascii', 'ignore').decode()).strip().replace("\n", " ").replace( 64 | "...", ".").replace("..", ".").replace('”', ' ').replace('“', ' ').replace("'", ' ').replace('\"', '').replace("alt/1m", "").strip() 65 | extracted_text = re.sub( 66 | '[^A-Za-z0-9.!?,;%:=()\[\]$€&/\- ]+', '', extracted_text) 67 | extracted_text = " ".join(extracted_text.split()) 68 | 69 | # calculate the bounding box data based on pytesseract results 70 | coordinates = {} 71 | index = valid_df.index.values 72 | coordinates["left"] = valid_df.loc[index[0], "left"] 73 | coordinates["top"] = valid_df.loc[index[0], "top"] 74 | coordinates["bottom"] = valid_df.loc[index[-1], 75 | "top"] + valid_df.loc[index[-1], "height"] 76 | coordinates["right"] = valid_df.loc[index[-1], 77 | "left"] + valid_df.loc[index[-1], "width"] 78 | bounding_box = [coordinates[el].item() for el in bounding_box_order] 79 | 80 | # wrap the prediction inside a dictionary 81 | response = {} 82 | response["text"] = extracted_text 83 | response["box"] = bounding_box 84 | response["score"] = valid_df["conf"].mean()/100.0 85 | 86 | return [response] 87 |
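Both helpers above return the same response shape: dictionaries with `text`, `box` (ordered per `bounding_box_order`), and `score` keys. A minimal usage sketch follows, assuming Tesseract is installed (the dockerfile installs `tesseract-ocr`) and the image contains recognizable text; the file name and the detection output are placeholders, not files shipped with this repository:

```python
# Illustrative usage of the OCR helpers above; 'document.jpg' and the
# detection output below are placeholder inputs, not repository files.
from PIL import Image
from ocr import ocr_service, one_shot_ocr_service

image = Image.open('document.jpg')

# Full-image OCR: [{'text': ..., 'box': [left, top, right, bottom], 'score': ...}]
print(ocr_service(image))

# OCR restricted to regions found by a detection model:
detection_output = {'bounding-boxes': [
    {'coordinates': {'left': 10, 'top': 20, 'right': 200, 'bottom': 60}}
]}
print(one_shot_ocr_service(image, detection_output))
```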
-------------------------------------------------------------------------------- /README-docker_swarm.md: -------------------------------------------------------------------------------- 1 | # Tensorflow CPU Inference API for Windows and Linux with Docker Swarm 2 | Please use **docker swarm** only if you need to: 3 | 4 | * Provide redundancy in terms of API containers: if a container goes down, incoming requests are redirected to another running instance. 5 | 6 | * Coordinate between the containers: swarm orchestrates the API replicas and routes each incoming request to one of them. 7 | 8 | * Scale up the inference service for faster predictions, especially when the service is under heavy traffic. 9 | 10 | ## Run the docker container 11 | 12 | Docker swarm can scale the API up to multiple replicas and can be used on one or multiple hosts (Linux users only). In both cases, a docker swarm setup is required for all hosts. 13 | 14 | #### Docker swarm setup 15 | 16 | 1- Initialize Swarm: 17 | 18 | ```sh 19 | docker swarm init 20 | ``` 21 | 22 | 2- On the manager host, open the cpu-inference.yaml file and specify the number of replicas needed. If you are using multiple hosts (see the "With multiple hosts" section below), the replicas are divided across all hosts. 23 | 24 | ```yaml 25 | version: "3" 26 | 27 | services: 28 | api: 29 | ports: 30 | - "4343:4343" 31 | image: tensorflow_inference_api_cpu 32 | volumes: 33 | - "/mnt/models:/models" 34 | deploy: 35 | replicas: 1 36 | update_config: 37 | parallelism: 2 38 | delay: 10s 39 | restart_policy: 40 | condition: on-failure 41 | ``` 42 | 43 | **Notes about cpu-inference.yaml:** 44 | 45 | * the volumes field on the left of ":" must be an absolute path; it can be changed by the user and represents the models directory on your operating system 46 | * the ":/models" part of the volumes field should never be changed 47 | 48 | #### With one host 49 | 50 | Deploy the API: 51 | 52 | ```sh 53 | docker stack deploy -c cpu-inference.yaml tensorflow-cpu 54 | ``` 55 | 56 | ![onehost](./docs/tcpu.png) 57 |
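Once the stack is deployed, you can smoke-test the service from any machine that can reach port 4343. The routes and the form-field name in this sketch are assumptions for illustration; the actual endpoints are listed in the Swagger UI served by the API (http://<host>:4343/docs):

```python
# Hypothetical smoke test against the deployed API; the prediction route and
# the 'input_data' field name are assumptions -- check the Swagger UI for the
# exact paths exposed by start.py.
import requests

base_url = 'http://localhost:4343'

# List the available models (hypothetical route)
print(requests.get(base_url + '/models').json())

# Run a prediction on a local image (hypothetical route and field name)
with open('test.jpg', 'rb') as f:
    response = requests.post(base_url + '/models/my_model/predict_image',
                             files={'input_data': f})
print(response.json())
```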
58 | #### With multiple hosts (Linux users only) 59 | 60 | 1- **Make sure hosts are reachable on the same network**. 61 | 62 | 2- Choose a host to be the manager and run the following command on the chosen host to generate a token so the other hosts can join: 63 | 64 | ```sh 65 | docker swarm join-token worker 66 | ``` 67 | 68 | A join command will appear in your terminal; copy and run it on the other hosts, as shown in the image below. 69 | 70 | 3- Deploy your application using: 71 | 72 | ```sh 73 | docker stack deploy -c cpu-inference.yaml tensorflow-cpu 74 | ``` 75 | 76 | ![multhost](./docs/tcpu2.png) 77 | 78 | #### Useful Commands 79 | 80 | 1- To scale the service up to 4 replicas, for example, use this command: 81 | 82 | ```sh 83 | docker service scale tensorflow-cpu_api=4 84 | ``` 85 | 86 | 2- To check the available workers: 87 | 88 | ```sh 89 | docker node ls 90 | ``` 91 | 92 | 3- To check on which node the container is running: 93 | 94 | ```sh 95 | docker service ps tensorflow-cpu_api 96 | ``` 97 | 98 | 4- To check the number of replicas: 99 | 100 | ```sh 101 | docker service ls 102 | ``` 103 | 104 | ## Benchmarking 105 | 106 | Here are two graphs showing the prediction time for different numbers of simultaneous requests. 107 | 108 | 109 | ![CPU 20 req](./docs/TCPU20req.png) 110 | 111 | 112 | ![CPU 40 req](./docs/TCPU40req.png) 113 | 114 | 115 | Both graphs show the same pattern regardless of the number of requests received at the same time: increasing the number of workers (hosts) speeds up inference by at least 2x. For example, the last column shows that 40 requests were processed in: 116 | 117 | - 17.5 seconds with 20 replicas in 1 machine 118 | - 8.8 seconds with 20 replicas in each of the 2 machines 119 | 120 | Moreover, if one of the machines goes down, the others are always ready to receive requests. 121 | 122 | Finally, since we are predicting on CPU, scaling to more replicas does not always mean faster prediction: 4 containers were faster than 20. 123 | -------------------------------------------------------------------------------- /src/main/object_detection/protos/string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/string_int_label_map.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/string_int_label_map.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 23 | ) 24 | 25 | 26 | 27 | 28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 29 | name='StringIntLabelMapItem', 30 | full_name='object_detection.protos.StringIntLabelMapItem', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto2', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=79, 69 | serialized_end=150, 70 | ) 71 | 72 | 73 | _STRINGINTLABELMAP = _descriptor.Descriptor( 74 | name='StringIntLabelMap', 75 | full_name='object_detection.protos.StringIntLabelMap', 76 | filename=None, 77 | file=DESCRIPTOR, 78 | containing_type=None, 79 | fields=[ 80 | _descriptor.FieldDescriptor( 81 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 82 | number=1, type=11, cpp_type=10, label=3, 83 | has_default_value=False, default_value=[], 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | options=None, file=DESCRIPTOR), 87 | ], 88 | extensions=[ 89 | ], 90 | nested_types=[], 91 | enum_types=[ 92 | ], 93 | options=None, 94 | is_extendable=False, 95 | syntax='proto2', 96 | extension_ranges=[], 97 | oneofs=[ 98 | ], 99 | 
serialized_start=152, 100 | serialized_end=233, 101 | ) 102 | 103 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 104 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 105 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 106 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 107 | 108 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 109 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 110 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 111 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 112 | )) 113 | _sym_db.RegisterMessage(StringIntLabelMapItem) 114 | 115 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 116 | DESCRIPTOR = _STRINGINTLABELMAP, 117 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 118 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 119 | )) 120 | _sym_db.RegisterMessage(StringIntLabelMap) 121 | 122 | 123 | # @@protoc_insertion_point(module_scope) 124 | -------------------------------------------------------------------------------- /src/main/object_detection/utils/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Label map utility functions.""" 17 | 18 | import logging 19 | 20 | import tensorflow as tf 21 | from google.protobuf import text_format 22 | from object_detection.protos import string_int_label_map_pb2 23 | 24 | 25 | def _validate_label_map(label_map): 26 | """Checks if a label map is valid. 27 | 28 | Args: 29 | label_map: StringIntLabelMap to validate. 30 | 31 | Raises: 32 | ValueError: if label map is invalid. 33 | """ 34 | for item in label_map.item: 35 | if item.id < 0: 36 | raise ValueError('Label map ids should be >= 0.') 37 | if (item.id == 0 and item.name != 'background' and 38 | item.display_name != 'background'): 39 | raise ValueError('Label map id 0 is reserved for the background label') 40 | 41 | 42 | def create_category_index(categories): 43 | """Creates dictionary of COCO compatible categories keyed by category id. 44 | 45 | Args: 46 | categories: a list of dicts, each of which has the following keys: 47 | 'id': (required) an integer id uniquely identifying this category. 48 | 'name': (required) string representing category name 49 | e.g., 'cat', 'dog', 'pizza'. 50 | 51 | Returns: 52 | category_index: a dict containing the same entries as categories, but keyed 53 | by the 'id' field of each category. 
54 | """ 55 | category_index = {} 56 | for cat in categories: 57 | category_index[cat['id']] = cat 58 | return category_index 59 | 60 | 61 | def get_max_label_map_index(label_map): 62 | """Get maximum index in label map. 63 | 64 | Args: 65 | label_map: a StringIntLabelMapProto 66 | 67 | Returns: 68 | an integer 69 | """ 70 | return max([item.id for item in label_map.item]) 71 | 72 | 73 | def convert_label_map_to_categories(label_map, 74 | max_num_classes, 75 | use_display_name=True): 76 | """Loads label map proto and returns categories list compatible with eval. 77 | 78 | This function loads a label map and returns a list of dicts, each of which 79 | has the following keys: 80 | 'id': (required) an integer id uniquely identifying this category. 81 | 'name': (required) string representing category name 82 | e.g., 'cat', 'dog', 'pizza'. 83 | We only allow class into the list if its id-label_id_offset is 84 | between 0 (inclusive) and max_num_classes (exclusive). 85 | If there are several items mapping to the same id in the label map, 86 | we will only keep the first one in the categories list. 87 | 88 | Args: 89 | label_map: a StringIntLabelMapProto or None. If None, a default categories 90 | list is created with max_num_classes categories. 91 | max_num_classes: maximum number of (consecutive) label indices to include. 92 | use_display_name: (boolean) choose whether to load 'display_name' field 93 | as category name. If False or if the display_name field does not exist, 94 | uses 'name' field as category names instead. 95 | Returns: 96 | categories: a list of dictionaries representing all possible categories. 97 | """ 98 | categories = [] 99 | list_of_ids_already_added = [] 100 | if not label_map: 101 | label_id_offset = 1 102 | for class_id in range(max_num_classes): 103 | categories.append({ 104 | 'id': class_id + label_id_offset, 105 | 'name': 'category_{}'.format(class_id + label_id_offset) 106 | }) 107 | return categories 108 | for item in label_map.item: 109 | if not 0 < item.id <= max_num_classes: 110 | logging.info('Ignore item %d since it falls outside of requested ' 111 | 'label range.', item.id) 112 | continue 113 | if use_display_name and item.HasField('display_name'): 114 | name = item.display_name 115 | else: 116 | name = item.name 117 | if item.id not in list_of_ids_already_added: 118 | list_of_ids_already_added.append(item.id) 119 | categories.append({'id': item.id, 'name': name}) 120 | return categories 121 | 122 | 123 | def load_labelmap(path): 124 | """Loads label map proto. 125 | 126 | Args: 127 | path: path to StringIntLabelMap proto text file. 128 | Returns: 129 | a StringIntLabelMapProto 130 | """ 131 | with tf.gfile.GFile(path, 'r') as fid: 132 | label_map_string = fid.read() 133 | label_map = string_int_label_map_pb2.StringIntLabelMap() 134 | try: 135 | text_format.Merge(label_map_string, label_map) 136 | except text_format.ParseError: 137 | label_map.ParseFromString(label_map_string) 138 | _validate_label_map(label_map) 139 | return label_map 140 | 141 | 142 | def get_label_map_dict(label_map_path, use_display_name=False): 143 | """Reads a label map and returns a dictionary of label names to id. 144 | 145 | Args: 146 | label_map_path: path to label_map. 147 | use_display_name: whether to use the label map items' display names as keys. 148 | 149 | Returns: 150 | A dictionary mapping label names to id. 
151 | """ 152 | label_map = load_labelmap(label_map_path) 153 | label_map_dict = {} 154 | for item in label_map.item: 155 | if use_display_name: 156 | label_map_dict[item.display_name] = item.id 157 | else: 158 | label_map_dict[item.name] = item.id 159 | return label_map_dict 160 | 161 | 162 | def create_category_index_from_labelmap(label_map_path): 163 | """Reads a label map and returns a category index. 164 | 165 | Args: 166 | label_map_path: Path to `StringIntLabelMap` proto text file. 167 | 168 | Returns: 169 | A category index, which is a dictionary that maps integer ids to dicts 170 | containing categories, e.g. 171 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 172 | """ 173 | label_map = load_labelmap(label_map_path) 174 | max_num_classes = max(item.id for item in label_map.item) 175 | categories = convert_label_map_to_categories(label_map, max_num_classes) 176 | return create_category_index(categories) 177 | 178 | 179 | def create_class_agnostic_category_index(): 180 | """Creates a category index with a single `object` class.""" 181 | return {1: {'id': 1, 'name': 'object'}} 182 | -------------------------------------------------------------------------------- /src/main/deep_learning_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import uuid 5 | from inference.inference_engines_factory import InferenceEngineFactory 6 | from inference.exceptions import ModelNotFound, InvalidModelConfiguration, ModelNotLoaded, InferenceEngineNotFound, \ 7 | InvalidInputData, ApplicationError 8 | 9 | 10 | class DeepLearningService: 11 | 12 | def __init__(self): 13 | """ 14 | Sets the models base directory, and initializes some dictionaries. 15 | Saves the loaded model's hashes to a json file, so the values are saved even though the API went down. 16 | """ 17 | # dictionary to hold the model instances (model_name: string -> model_instance: AbstractInferenceEngine) 18 | self.models_dict = {} 19 | # read from json file and append to dict 20 | file_name = '/models_hash/model_hash.json' 21 | file_exists = os.path.exists(file_name) 22 | if file_exists: 23 | try: 24 | with open(file_name) as json_file: 25 | self.models_hash_dict = json.load(json_file) 26 | except: 27 | self.models_hash_dict = {} 28 | else: 29 | with open('/models_hash/model_hash.json', 'w'): 30 | self.models_hash_dict = {} 31 | self.labels_hash_dict = {} 32 | self.base_models_dir = '/models' 33 | 34 | def load_model(self, model_name, force_reload=False): 35 | """ 36 | Loads a model by passing the model path to the factory. 37 | The factory will return a loaded model instance that will be stored in a dictionary. 38 | :param model_name: Model name 39 | :param force_reload: Boolean to specify if we need to reload a model on each call 40 | :return: Boolean 41 | """ 42 | if not force_reload and self.model_loaded(model_name): 43 | return True 44 | model_path = os.path.join(self.base_models_dir, model_name) 45 | try: 46 | self.models_dict[model_name] = InferenceEngineFactory.get_engine(model_path) 47 | return True 48 | except ApplicationError as e: 49 | raise e 50 | 51 | def load_all_models(self): 52 | """ 53 | Loads all the available models. 
54 | :return: Returns a List of all models and their respective hashed values 55 | """ 56 | self.load_models(self.list_models()) 57 | models = self.list_models() 58 | for model in models: 59 | if model not in self.models_hash_dict: 60 | self.models_hash_dict[model] = str(uuid.uuid4()) 61 | for key in list(self.models_hash_dict): 62 | if key not in models: 63 | del self.models_hash_dict[key] 64 | # append to json file 65 | with open('/models_hash/model_hash.json', "w") as fp: 66 | json.dump(self.models_hash_dict, fp) 67 | return self.models_hash_dict 68 | 69 | def load_models(self, model_names): 70 | """ 71 | Loads a set of available models. 72 | :param model_names: List of available models 73 | :return: 74 | """ 75 | for model in model_names: 76 | self.load_model(model) 77 | 78 | async def run_model(self, model_name, input_data, draw, predict_batch): 79 | """ 80 | Loads the model in case it was never loaded and calls the inference engine class to get a prediction. 81 | :param model_name: Model name 82 | :param input_data: Batch of images or a single image 83 | :param draw: Boolean to specify if we need to draw the response on the input image 84 | :param predict_batch: Boolean to specify if there is a batch of images in a request or not 85 | :return: Model response in case draw was set to False, else an actual image 86 | """ 87 | if re.match(r'[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12}', model_name, 88 | flags=0): 89 | for key, value in self.models_hash_dict.items(): 90 | if value == model_name: 91 | model_name = key 92 | if self.model_loaded(model_name): 93 | try: 94 | if predict_batch: 95 | return await self.models_dict[model_name].run_batch(input_data, draw, predict_batch) 96 | else: 97 | if not draw: 98 | return await self.models_dict[model_name].infer(input_data, draw, predict_batch) 99 | else: 100 | await self.models_dict[model_name].infer(input_data, draw, predict_batch) 101 | except ApplicationError as e: 102 | raise e 103 | else: 104 | try: 105 | self.load_model(model_name) 106 | return await self.run_model(model_name, input_data, draw, predict_batch) 107 | except ApplicationError as e: 108 | raise e 109 | 110 | def list_models(self): 111 | """ 112 | Lists all the available models. 113 | :return: List of models 114 | """ 115 | return [folder for folder in os.listdir(self.base_models_dir) if 116 | os.path.isdir(os.path.join(self.base_models_dir, folder))] 117 | 118 | def model_loaded(self, model_name): 119 | """ 120 | Returns the model in case it was loaded. 121 | :param model_name: Model name 122 | :return: Model name 123 | """ 124 | return model_name in self.models_dict.keys() 125 | 126 | def get_labels(self, model_name): 127 | """ 128 | Loads the model in case it's not loaded. 129 | Returns the model's labels. 130 | :param model_name: Model name 131 | :return: List of model labels 132 | """ 133 | if not self.model_loaded(model_name): 134 | self.load_model(model_name) 135 | return self.models_dict[model_name].labels 136 | 137 | def get_labels_custom(self, model_name): 138 | """ 139 | Hashes every label of a specific model. 
140 | :param model_name: Model name 141 | :return: A list of the model's labels with their hashed values 142 | """ 143 | if re.match(r'[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12}', model_name, 144 | flags=0): 145 | for key, value in self.models_hash_dict.items(): 146 | if value == model_name: 147 | model_name = key 148 | models = self.list_models() 149 | if model_name not in self.labels_hash_dict: 150 | model_dict = {} 151 | for label in self.models_dict[model_name].labels: 152 | model_dict[label] = str(uuid.uuid4()) 153 | self.labels_hash_dict[model_name] = model_dict 154 | for key in list(self.labels_hash_dict): 155 | if key not in models: 156 | del self.labels_hash_dict[key] 157 | return self.labels_hash_dict[model_name] 158 | 159 | def get_config(self, model_name): 160 | """ 161 | Returns the model's configuration. 162 | :param model_name: Model name 163 | :return: The model's configuration 164 | """ 165 | if not self.model_loaded(model_name): 166 | self.load_model(model_name) 167 | return self.models_dict[model_name].configuration 168 | -------------------------------------------------------------------------------- /src/main/object_detection/core/box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Bounding Box List definition. 17 | 18 | BoxList represents a list of bounding boxes as tensorflow 19 | tensors, where each bounding box is represented as a row of 4 numbers, 20 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes 21 | within a given list correspond to a single image. See also 22 | box_list_ops.py for common box related operations (such as area, iou, etc). 23 | 24 | Optionally, users can add additional related fields (such as weights). 25 | We assume the following things to be true about fields: 26 | * they correspond to boxes in the box_list along the 0th dimension 27 | * they have inferrable rank at graph construction time 28 | * all dimensions except for possibly the 0th can be inferred 29 | (i.e., not None) at graph construction time. 30 | 31 | Some other notes: 32 | * Following tensorflow conventions, we use height, width ordering, 33 | and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering 34 | * Tensors are always provided as (flat) [N, 4] tensors. 35 | """ 36 | 37 | import tensorflow as tf 38 | 39 | 40 | class BoxList(object): 41 | """Box collection.""" 42 | 43 | def __init__(self, boxes): 44 | """Constructs box collection. 45 | 46 | Args: 47 | boxes: a tensor of shape [N, 4] representing box corners 48 | 49 | Raises: 50 | ValueError: if invalid dimensions for bbox data or if bbox data is not in 51 | float32 format.
52 | """ 53 | if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4: 54 | raise ValueError('Invalid dimensions for box data.') 55 | if boxes.dtype != tf.float32: 56 | raise ValueError('Invalid tensor type: should be tf.float32') 57 | self.data = {'boxes': boxes} 58 | 59 | def num_boxes(self): 60 | """Returns number of boxes held in collection. 61 | 62 | Returns: 63 | a tensor representing the number of boxes held in the collection. 64 | """ 65 | return tf.shape(self.data['boxes'])[0] 66 | 67 | def num_boxes_static(self): 68 | """Returns number of boxes held in collection. 69 | 70 | This number is inferred at graph construction time rather than run-time. 71 | 72 | Returns: 73 | Number of boxes held in collection (integer) or None if this is not 74 | inferrable at graph construction time. 75 | """ 76 | return self.data['boxes'].get_shape()[0].value 77 | 78 | def get_all_fields(self): 79 | """Returns all fields.""" 80 | return self.data.keys() 81 | 82 | def get_extra_fields(self): 83 | """Returns all non-box fields (i.e., everything not named 'boxes').""" 84 | return [k for k in self.data.keys() if k != 'boxes'] 85 | 86 | def add_field(self, field, field_data): 87 | """Add field to box list. 88 | 89 | This method can be used to add related box data such as 90 | weights/labels, etc. 91 | 92 | Args: 93 | field: a string key to access the data via `get` 94 | field_data: a tensor containing the data to store in the BoxList 95 | """ 96 | self.data[field] = field_data 97 | 98 | def has_field(self, field): 99 | return field in self.data 100 | 101 | def get(self): 102 | """Convenience function for accessing box coordinates. 103 | 104 | Returns: 105 | a tensor with shape [N, 4] representing box coordinates. 106 | """ 107 | return self.get_field('boxes') 108 | 109 | def set(self, boxes): 110 | """Convenience function for setting box coordinates. 111 | 112 | Args: 113 | boxes: a tensor of shape [N, 4] representing box corners 114 | 115 | Raises: 116 | ValueError: if invalid dimensions for bbox data 117 | """ 118 | if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4: 119 | raise ValueError('Invalid dimensions for box data.') 120 | self.data['boxes'] = boxes 121 | 122 | def get_field(self, field): 123 | """Accesses a box collection and associated fields. 124 | 125 | This function returns specified field with object; if no field is specified, 126 | it returns the box coordinates. 127 | 128 | Args: 129 | field: this optional string parameter can be used to specify 130 | a related field to be accessed. 131 | 132 | Returns: 133 | a tensor representing the box collection or an associated field. 134 | 135 | Raises: 136 | ValueError: if invalid field 137 | """ 138 | if not self.has_field(field): 139 | raise ValueError('field ' + str(field) + ' does not exist') 140 | return self.data[field] 141 | 142 | def set_field(self, field, value): 143 | """Sets the value of a field. 144 | 145 | Updates the field of a box_list with a given value. 146 | 147 | Args: 148 | field: (string) name of the field to set value. 149 | value: the value to assign to the field. 150 | 151 | Raises: 152 | ValueError: if the box_list does not have specified field. 153 | """ 154 | if not self.has_field(field): 155 | raise ValueError('field %s does not exist' % field) 156 | self.data[field] = value 157 | 158 | def get_center_coordinates_and_sizes(self, scope=None): 159 | """Computes the center coordinates, height and width of the boxes. 160 | 161 | Args: 162 | scope: name scope of the function. 
163 | 
164 |     Returns:
165 |       a list of 4 1-D tensors [ycenter, xcenter, height, width].
166 |     """
167 |     with tf.name_scope(scope, 'get_center_coordinates_and_sizes'):
168 |       box_corners = self.get()
169 |       ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners))
170 |       width = xmax - xmin
171 |       height = ymax - ymin
172 |       ycenter = ymin + height / 2.
173 |       xcenter = xmin + width / 2.
174 |       return [ycenter, xcenter, height, width]
175 | 
176 |   def transpose_coordinates(self, scope=None):
177 |     """Transpose the coordinate representation in a boxlist.
178 | 
179 |     Args:
180 |       scope: name scope of the function.
181 |     """
182 |     with tf.name_scope(scope, 'transpose_coordinates'):
183 |       y_min, x_min, y_max, x_max = tf.split(
184 |           value=self.get(), num_or_size_splits=4, axis=1)
185 |       self.set(tf.concat([x_min, y_min, x_max, y_max], 1))
186 | 
187 |   def as_tensor_dict(self, fields=None):
188 |     """Retrieves specified fields as a dictionary of tensors.
189 | 
190 |     Args:
191 |       fields: (optional) list of fields to return in the dictionary.
192 |         If None (default), all fields are returned.
193 | 
194 |     Returns:
195 |       tensor_dict: A dictionary of tensors specified by fields.
196 | 
197 |     Raises:
198 |       ValueError: if specified field is not contained in boxlist.
199 |     """
200 |     tensor_dict = {}
201 |     if fields is None:
202 |       fields = self.get_all_fields()
203 |     for field in fields:
204 |       if not self.has_field(field):
205 |         raise ValueError('boxlist must contain all specified fields')
206 |       tensor_dict[field] = self.get_field(field)
207 |     return tensor_dict
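A quick illustration of the `BoxList` API above (a minimal sketch assuming the TensorFlow 1.x session API used throughout this repo; the box values are arbitrary):

```python
import tensorflow as tf  # TensorFlow 1.x
from object_detection.core.box_list import BoxList

# two boxes in [ymin, xmin, ymax, xmax] order, float32 as required
boxes = tf.constant([[0.1, 0.1, 0.5, 0.5],
                     [0.2, 0.0, 0.9, 0.8]], dtype=tf.float32)
box_list = BoxList(boxes)
box_list.add_field('scores', tf.constant([0.9, 0.4], dtype=tf.float32))

with tf.Session() as sess:
    print(sess.run(box_list.num_boxes()))          # 2
    print(sess.run(box_list.get_field('scores')))  # [0.9 0.4]
```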
-------------------------------------------------------------------------------- /src/main/inference/tensorflow_detection.py: --------------------------------------------------------------------------------
  1 | import os
  2 | import jsonschema
  3 | import asyncio
  4 | import json
  5 | import numpy as np
  6 | import tensorflow as tf
  7 | from PIL import Image, ImageDraw, ImageFont
  8 | from object_detection.utils import label_map_util
  9 | from inference.base_inference_engine import AbstractInferenceEngine
 10 | from inference.exceptions import InvalidModelConfiguration, InvalidInputData, ApplicationError
 11 | 
 12 | 
 13 | class InferenceEngine(AbstractInferenceEngine):
 14 | 
 15 |     def __init__(self, model_path):
 16 |         self.label_path = ""
 17 |         self.NUM_CLASSES = None
 18 |         self.sess = None
 19 |         self.label_map = None
 20 |         self.categories = None
 21 |         self.category_index = None
 22 |         self.detection_graph = None
 23 |         self.image_tensor = None
 24 |         self.d_boxes = None
 25 |         self.d_scores = None
 26 |         self.d_classes = None
 27 |         self.num_d = None
 28 |         self.font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
 29 |         super().__init__(model_path)
 30 | 
 31 |     def load(self):
 32 |         with open(os.path.join(self.model_path, 'config.json')) as f:
 33 |             data = json.load(f)
 34 |         try:
 35 |             self.validate_json_configuration(data)
 36 |             self.set_model_configuration(data)
 37 |         except ApplicationError as e:
 38 |             raise e
 39 | 
 40 |         self.label_path = os.path.join(self.model_path, 'object-detection.pbtxt')
 41 |         self.label_map = label_map_util.load_labelmap(self.label_path)
 42 |         self.categories = label_map_util.convert_label_map_to_categories(self.label_map,
 43 |                                                                          max_num_classes=self.NUM_CLASSES,
 44 |                                                                          use_display_name=True)
 45 |         for category in self.categories:  # loop variable kept distinct from the builtin `dict`
 46 |             self.labels.append(category['name'])
 47 | 
 48 |         self.category_index = label_map_util.create_category_index(self.categories)
 49 |         self.detection_graph = tf.Graph()
 50 |         with self.detection_graph.as_default():
 51 |             od_graph_def = tf.GraphDef()
 52 |             with tf.gfile.GFile(os.path.join(self.model_path, 'frozen_inference_graph.pb'), 'rb') as fid:
 53 |                 serialized_graph = fid.read()
 54 |                 od_graph_def.ParseFromString(serialized_graph)
 55 |                 tf.import_graph_def(od_graph_def, name='')
 56 |             self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
 57 |             self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
 58 |             self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
 59 |             self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
 60 |             self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
 61 |         self.sess = tf.Session(graph=self.detection_graph)
 62 |         img = Image.open("object_detection/image1.jpg")
 63 |         img_expanded = np.expand_dims(img, axis=0)
 64 |         self.sess.run(  # warm-up inference so the first real request does not pay the initialization cost
 65 |             [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
 66 |             feed_dict={self.image_tensor: img_expanded})
 67 | 
 68 |     async def infer(self, input_data, draw, predict_batch):
 69 |         await asyncio.sleep(0.00001)  # briefly yield control to the event loop
 70 |         try:
 71 |             pillow_image = Image.open(input_data.file).convert('RGB')
 72 |             np_image = np.array(pillow_image)
 73 |         except Exception:
 74 |             raise InvalidInputData('corrupted image')
 75 |         try:
 76 |             with open(self.model_path + '/config.json') as f:
 77 |                 data = json.load(f)
 78 |         except Exception:
 79 |             raise InvalidModelConfiguration('config.json not found or corrupted')
 80 |         json_confidence = data['confidence']
 81 |         json_predictions = data['predictions']
 82 |         with self.detection_graph.as_default():
 83 |             # Expand dimension since the model expects image to have shape [1, None, None, 3].
 84 |             img_expanded = np.expand_dims(np_image, axis=0)
 85 |             (boxes, scores, classes, num) = self.sess.run(
 86 |                 [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
 87 |                 feed_dict={self.image_tensor: img_expanded})
 88 |         classes_names = [self.category_index.get(i) for i in classes[0]]
 89 |         names_start = []
 90 |         for name in classes_names:
 91 |             if name is not None:
 92 |                 names_start.append(name['name'])
 93 |         height, width, depth = np_image.shape
 94 |         names = []
 95 |         confidence = []
 96 |         ids = []
 97 |         bounding_boxes = []
 98 |         for i in range(min(json_predictions, len(names_start))):  # never index past the detections we actually have
 99 |             if scores[0][i] * 100 >= json_confidence:
100 |                 ymin = max(int(round(boxes[0][i][0] * height)), 0)  # normalized coords -> pixels, clamped at 0
101 |                 xmin = max(int(round(boxes[0][i][1] * width)), 0)
102 |                 ymax = max(int(round(boxes[0][i][2] * height)), 0)
103 |                 xmax = max(int(round(boxes[0][i][3] * width)), 0)
104 |                 tmp = dict([('left', xmin), ('top', ymin), ('right', xmax), ('bottom', ymax)])
105 |                 bounding_boxes.append(tmp)
106 |                 confidence.append(float(scores[0][i] * 100))
107 |                 ids.append(int(classes[0][i]))
108 |                 names.append(names_start[i])
109 | 
110 |         responses_list = zip(names, confidence, bounding_boxes, ids)
111 | 
112 |         output = []
113 |         for response in responses_list:
114 |             tmp = dict([('ObjectClassName', response[0]), ('confidence', response[1]), ('coordinates', response[2]),
115 |                         ('ObjectClassId', response[3])])
116 |             output.append(tmp)
117 |         if predict_batch:
118 |             response = dict([('bounding-boxes', output), ('ImageName', input_data.filename)])
119 |         else:
120 |             response = dict([('bounding-boxes', output)])
121 | 
122 |         if not draw:
123 |             return response
124 |         else:
125 |             try:
126 |                 self.draw_image(pillow_image, response)
127 |             except ApplicationError as e:
128 |                 raise e
129 |             except Exception as e:
130 |                 raise e
131 | 
132 |     async def run_batch(self, input_data, draw, predict_batch):
133 |         result_list = []
134 |         for image in input_data:
135 |             post_process = await self.infer(image, draw, predict_batch)
136 |             if post_process is not None:
137 |                 result_list.append(post_process)
138 |         return result_list
139 | 
140 |     def draw_image(self, image, response):
141 |         """
142 |         Draws the bounding boxes on the image and saves the result.
143 |         :param image: image of type pillow image
144 |         :param response: inference response
145 |         :return:
146 |         """
147 |         draw = ImageDraw.Draw(image)
148 |         for bbox in response['bounding-boxes']:
149 |             draw.rectangle([bbox['coordinates']['left'], bbox['coordinates']['top'], bbox['coordinates']['right'],
150 |                             bbox['coordinates']['bottom']], outline="red")
151 |             left = bbox['coordinates']['left']
152 |             top = bbox['coordinates']['top']
153 |             conf = "{0:.2f}".format(bbox['confidence'])
154 |             draw.text((int(left), int(top) - 20), str(conf) + "% " + str(bbox['ObjectClassName']), 'red', self.font)
155 |         image.save('/main/result.jpg', 'JPEG')  # save as JPEG so the format matches the .jpg name and the image/jpeg response
156 | 
157 |     def free(self):
158 |         pass
159 | 
160 |     def validate_configuration(self):
161 |         # check if weights file exists
162 |         if not os.path.exists(os.path.join(self.model_path, 'frozen_inference_graph.pb')):
163 |             raise InvalidModelConfiguration('frozen_inference_graph.pb not found')
164 |         # check if labels file exists
165 |         if not os.path.exists(os.path.join(self.model_path, 'object-detection.pbtxt')):
166 |             raise InvalidModelConfiguration('object-detection.pbtxt not found')
167 |         return True
168 | 
169 |     def set_model_configuration(self, data):
170 |         self.configuration['framework'] = data['framework']
171 |         self.configuration['type'] = data['type']
172 |         self.configuration['network'] = data['network']
173 |         self.NUM_CLASSES = data['number_of_classes']
174 | 
175 |     def validate_json_configuration(self, data):
176 |         with open(os.path.join('inference', 'ConfigurationSchema.json')) as f:
177 |             schema = json.load(f)
178 |         try:
179 |             jsonschema.validate(data, schema)
180 |         except Exception as e:
181 |             raise InvalidModelConfiguration(e)
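The coordinate handling in `infer` above is worth spelling out: the graph returns boxes as normalized `[ymin, xmin, ymax, xmax]` values, which are scaled to pixel coordinates and clamped at zero. A standalone sketch (the helper name and the sample numbers are illustrative):

```python
def to_pixel_box(box, width, height):
    """Convert one normalized [ymin, xmin, ymax, xmax] box to pixel coordinates."""
    ymin, xmin, ymax, xmax = box
    return {
        'left':   max(int(round(xmin * width)), 0),
        'top':    max(int(round(ymin * height)), 0),
        'right':  max(int(round(xmax * width)), 0),
        'bottom': max(int(round(ymax * height)), 0),
    }

print(to_pixel_box([0.1, 0.2, 0.5, 0.9], width=640, height=480))
# {'left': 128, 'top': 48, 'right': 576, 'bottom': 240}
```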
-------------------------------------------------------------------------------- /src/main/start.py: --------------------------------------------------------------------------------
  1 | import sys
  2 | from typing import List
  3 | from models import ApiResponse
  4 | from inference.errors import Error
  5 | from starlette.responses import FileResponse
  6 | from starlette.staticfiles import StaticFiles
  7 | from starlette.middleware.cors import CORSMiddleware
  8 | from deep_learning_service import DeepLearningService
  9 | from fastapi import FastAPI, Form, File, UploadFile, Header, HTTPException
 10 | from inference.exceptions import ModelNotFound, InvalidModelConfiguration, ApplicationError, ModelNotLoaded, \
 11 |     InferenceEngineNotFound, InvalidInputData
 12 | from ocr import ocr_service, one_shot_ocr_service
 13 | from datetime import datetime
 14 | import pytz
 15 | from PIL import Image
 16 | 
 17 | sys.path.append('./inference')
 18 | 
 19 | tz = pytz.timezone("Europe/Berlin")
 20 | 
 21 | dl_service = DeepLearningService()
 22 | error_logging = Error()
 23 | app = FastAPI(version='1.0', title='BMW InnovationLab tensorflow cpu inference Automation',
 24 |               description="API for performing tensorflow cpu inference<br>"
 25 |                           "Contact the developers:<br>"
 26 |                           "Hadi Koubeissy: hadi.koubeissy@inmind.ai<br>"
 27 |                           "BMW Innovation Lab: innovation-lab@bmw.de")
 28 | 
 29 | 
 30 | # app.mount("/public", StaticFiles(directory="/main/public"), name="public")
 31 | 
 32 | # app.add_middleware(
 33 | #     CORSMiddleware,
 34 | #     allow_origins=["*"],
 35 | #     allow_credentials=True,
 36 | #     allow_methods=["*"],
 37 | #     allow_headers=["*"],
 38 | # )
 39 | 
 40 | 
 41 | @app.get('/load')
 42 | def load_custom():
 43 |     """
 44 |     Loads all the available models.
 45 |     :return: All the available models with their respective hashed values
 46 |     """
 47 |     try:
 48 |         return dl_service.load_all_models()
 49 |     except ApplicationError as e:
 50 |         return ApiResponse(success=False, error=e)
 51 |     except Exception:
 52 |         return ApiResponse(success=False, error='unexpected server error')
 53 | 
 54 | 
 55 | @app.post('/detect')
 56 | async def detect_custom(model: str = Form(...), image: UploadFile = File(...)):
 57 |     """
 58 |     Performs a prediction for a specified image using one of the available models.
 59 |     :param model: Model name or model hash
 60 |     :param image: Image file
 61 |     :return: Model's bounding boxes
 62 |     """
 63 |     draw_boxes = False
 64 |     predict_batch = False
 65 |     try:
 66 |         output = await dl_service.run_model(model, image, draw_boxes, predict_batch)
 67 |         error_logging.info('request successful;' + str(output))
 68 |         return output
 69 |     except ApplicationError as e:
 70 |         error_logging.warning(model + ';' + str(e))
 71 |         return ApiResponse(success=False, error=e)
 72 |     except Exception as e:
 73 |         error_logging.error(model + ' ' + str(e))
 74 |         return ApiResponse(success=False, error='unexpected server error')
 75 | 
 76 | 
 77 | @app.post('/get_labels')
 78 | def get_labels_custom(model: str = Form(...)):
 79 |     """
 80 |     Lists the model's labels with their hashed values.
 81 |     :param model: Model name or model hash
 82 |     :return: A list of the model's labels with their hashed values
 83 |     """
 84 |     return dl_service.get_labels_custom(model)
 85 | 
 86 | 
 87 | @app.get('/models/{model_name}/load')
 88 | async def load(model_name: str, force: bool = False):
 89 |     """
 90 |     Loads the model specified in the path.
 91 |     :param model_name: Model name
 92 |     :param force: Boolean to force reloading the model on each call
 93 |     :return: APIResponse
 94 |     """
 95 |     try:
 96 |         dl_service.load_model(model_name, force)
 97 |         return ApiResponse(success=True)
 98 |     except ApplicationError as e:
 99 |         return ApiResponse(success=False, error=e)
100 | 
101 | 
102 | @app.get('/models')
103 | async def list_models(user_agent: str = Header(None)):
104 |     """
105 |     Lists all available models.
106 |     :param user_agent: Request User-Agent header (unused)
107 |     :return: APIResponse
108 |     """
109 |     return ApiResponse(data={'models': dl_service.list_models()})
110 | 
111 | 
112 | @app.post('/models/{model_name}/predict')
113 | async def run_model(model_name: str, input_data: UploadFile = File(...)):
114 |     """
115 |     Performs a prediction by giving both model name and image file.
116 |     :param model_name: Model name
117 |     :param input_data: An image file
118 |     :return: APIResponse containing the prediction's bounding boxes
119 |     """
120 |     try:
121 |         output = await dl_service.run_model(model_name, input_data, draw=False, predict_batch=False)
122 |         error_logging.info('request successful;' + str(output))
123 |         return ApiResponse(data=output)
124 |     except ApplicationError as e:
125 |         error_logging.warning(model_name + ';' + str(e))
126 |         return ApiResponse(success=False, error=e)
127 |     except Exception as e:
128 |         error_logging.error(model_name + ' ' + str(e))
129 |         return ApiResponse(success=False, error='unexpected server error')
130 | 
131 | 
132 | @app.post('/models/{model_name}/predict_batch', include_in_schema=False)
133 | async def run_model_batch(model_name: str, input_data: List[UploadFile] = File(...)):
134 |     """
135 |     Performs a prediction by giving both model name and image file(s).
136 |     :param model_name: Model name
137 |     :param input_data: A batch of image files or a single image file
138 |     :return: APIResponse containing the predictions' bounding boxes
139 |     """
140 |     try:
141 |         output = await dl_service.run_model(model_name, input_data, draw=False, predict_batch=True)
142 |         error_logging.info('request successful;' + str(output))
143 |         return ApiResponse(data=output)
144 |     except ApplicationError as e:
145 |         error_logging.warning(model_name + ';' + str(e))
146 |         return ApiResponse(success=False, error=e)
147 |     except Exception as e:
148 |         print(e)
149 |         error_logging.error(model_name + ' ' + str(e))
150 |         return ApiResponse(success=False, error='unexpected server error')
151 | 
152 | 
153 | @app.post('/models/{model_name}/predict_image')
154 | async def predict_image(model_name: str, input_data: UploadFile = File(...)):
155 |     """
156 |     Draws the bounding box(es) on the image and returns it.
157 |     :param model_name: Model name
158 |     :param input_data: Image file
159 |     :return: Image file
160 |     """
161 |     try:
162 |         output = await dl_service.run_model(model_name, input_data, draw=True, predict_batch=False)
163 |         error_logging.info('request successful;' + str(output))
164 |         return FileResponse("/main/result.jpg", media_type="image/jpeg")
165 |     except ApplicationError as e:
166 |         error_logging.warning(model_name + ';' + str(e))
167 |         return ApiResponse(success=False, error=e)
168 |     except Exception as e:
169 |         error_logging.error(model_name + ' ' + str(e))
170 |         return ApiResponse(success=False, error='unexpected server error')
171 | 
172 | 
173 | @app.get('/models/{model_name}/labels')
174 | async def list_model_labels(model_name: str):
175 |     """
176 |     Lists all the model's labels.
177 |     :param model_name: Model name
178 |     :return: List of the model's labels
179 |     """
180 |     labels = dl_service.get_labels(model_name)
181 |     return ApiResponse(data=labels)
182 | 
183 | 
184 | @app.get('/models/{model_name}/config')
185 | async def list_model_config(model_name: str):
186 |     """
187 |     Lists the model's configuration.
188 |     :param model_name: Model name
189 |     :return: The model's configuration
190 |     """
191 |     config = dl_service.get_config(model_name)
192 |     return ApiResponse(data=config)
193 | 
194 | 
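# --- Illustrative client usage for the hidden predict_batch endpoint above ---
# (sketch only: host, port, model name and file paths are placeholders, and the
# `requests` package is assumed; several files are sent under the same
# `input_data` form field, matching List[UploadFile]):
#
#   import requests
#   files = [("input_data", open(p, "rb")) for p in ("1.jpg", "2.jpg")]
#   print(requests.post("http://localhost:4343/models/my_model/predict_batch",
#                       files=files).json())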
195 | @app.post('/models/{model_name}/one_shot_ocr')
196 | async def one_shot_ocr(
197 |         model_name: str,
198 |         image: UploadFile = File(
199 |             ..., description="Image to perform optical character recognition based on layout inference:")
200 | ):
201 |     """
202 |     Takes an image and returns extracted text details.
203 | 
204 |     First, a detection model is used to crop the interesting areas of the uploaded image. These areas are then passed to the OCR service for text extraction.
205 | 
206 |     :param model_name: Model name or model hash for layout detection
207 | 
208 |     :param image: Image file
209 | 
210 |     :return: Text fields with the detected text inside
211 | 
212 |     """
213 |     output = None
214 |     # run detection on the image with the chosen model
215 |     try:
216 |         output = await run_model(model_name, image)
217 |     except Exception:
218 |         raise HTTPException(status_code=404, detail='Invalid Model')
219 | 
220 |     # run ocr_service
221 |     response = None
222 |     try:
223 |         image = Image.open(image.file).convert('RGB')
224 |         response = one_shot_ocr_service(image, output.data)
225 |     except Exception:
226 |         raise HTTPException(
227 |             status_code=500, detail='Unexpected Error during Inference (Determination of Texts)')
228 | 
229 |     if not response:
230 |         raise HTTPException(
231 |             status_code=400, detail='Inference (Determination of Texts) is not Possible with the Specified Model')
232 | 
233 |     return response
234 | 
235 | 
236 | @app.post('/models/{model_name}/ocr')
237 | async def optical_character_recognition(
238 |         model_name: str,
239 |         image: UploadFile = File(
240 |             ..., description="Image to perform optical character recognition based on layout inference:"),
241 | ):
242 |     """
243 |     Takes an image and returns extracted text information.
244 | 
245 |     The image is passed to the OCR service for text extraction.
246 | 
247 |     :param model_name: Model name or model hash (not used by the plain OCR service)
248 | 
249 |     :param image: Image file
250 | 
251 |     :return: Text fields with the detected text inside
252 | 
253 |     """
254 |     # run ocr_service
255 |     response = None
256 |     try:
257 |         image = Image.open(image.file).convert('RGB')
258 |         response = ocr_service(image)
259 |     except Exception:
260 |         raise HTTPException(
261 |             status_code=500, detail='Unexpected Error during Inference (Determination of Texts)')
262 | 
263 |     if not response:
264 |         raise HTTPException(
265 |             status_code=400, detail='Inference (Determination of Texts) is not Possible with the Specified Model')
266 | 
267 |     return response
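For reference, a request against the OCR endpoints defined above might look like this (host, port, model name and image file are placeholders):

```sh
curl -X POST "http://localhost:4343/models/my_layout_model/one_shot_ocr" \
     -H "accept: application/json" \
     -F "image=@invoice.jpg"
```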
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
  1 | # Tensorflow CPU Inference API For Windows and Linux
  2 | This is a repository for an object detection inference API using the Tensorflow framework.
  3 | 
  4 | This repo is based on [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection).
  5 | 
  6 | The Tensorflow version used is 1.13.1. The inference REST API works on CPU and doesn't require any GPU usage. It's supported on both Windows and Linux operating systems.
  7 | 
  8 | Models trained using our training tensorflow repository can be deployed in this API. Several object detection models can be loaded and used at the same time.
  9 | This repo also offers optical character recognition services to extract text boxes from images.
 10 | 
 11 | This repo can be deployed using either **docker** or **docker swarm**.
 12 | 
 13 | Please use **docker swarm** only if you need to:
 14 | 
 15 | * Provide redundancy in terms of API containers: if a container goes down, incoming requests are redirected to another running instance.
 16 | 
 17 | * Coordinate between the containers: swarm orchestrates the APIs and chooses one of them to serve each incoming request.
 18 | 
 19 | * Scale up the inference service to get faster predictions, especially when there is heavy traffic on the service.
 20 | 
 21 | If none of the aforementioned requirements are needed, simply use **docker**.
 22 | 
 23 | ![predict image](./docs/4.gif)
 24 | 
 25 | ## Prerequisites
 26 | 
 27 | - OS:
 28 |   - Ubuntu 16.04/18.04
 29 |   - Windows 10 pro/enterprise
 30 | - Docker
 31 | 
 32 | ### Check for prerequisites
 33 | 
 34 | To check if you have docker-ce installed:
 35 | 
 36 | ```sh
 37 | docker --version
 38 | ```
 39 | 
 40 | ### Install prerequisites
 41 | 
 42 | #### Ubuntu
 43 | 
 44 | Use the following command to install docker on Ubuntu:
 45 | 
 46 | ```sh
 47 | chmod +x install_prerequisites.sh && source install_prerequisites.sh
 48 | ```
 49 | 
 50 | #### Windows 10
 51 | 
 52 | To [install Docker on Windows](https://docs.docker.com/docker-for-windows/install/), please follow the link.
 53 | 
 54 | **P.S: For Windows users, open the Docker Desktop menu by clicking the Docker icon in the notifications area. Select Settings, then the Advanced tab, to adjust the resources available to Docker Engine.**
 55 | 
 56 | ## Build The Docker Image
 57 | 
 58 | In order to build the project run the following command from the project's root directory:
 59 | 
 60 | ```sh
 61 | sudo docker build -t tensorflow_inference_api_cpu -f docker/dockerfile .
 62 | ```
 63 | 
 64 | ### Behind a proxy
 65 | 
 66 | ```sh
 67 | sudo docker build --build-arg http_proxy='' --build-arg https_proxy='' -t tensorflow_inference_api_cpu -f ./docker/dockerfile .
 68 | ```
 69 | 
 70 | ## Run the docker container
 71 | 
 72 | As mentioned before, this container can be deployed using either **docker** or **docker swarm**.
 73 | 
 74 | If you wish to deploy this API using **docker**, please issue the following run command.
 75 | 
 76 | If you wish to deploy this API using **docker swarm**, please refer to the [docker swarm documentation](./README-docker_swarm.md). After deploying the API with docker swarm, please consider returning to this documentation for further information about the API endpoints as well as the model structure sections.
 77 | 
 78 | To run the API, go to the API's directory and run the following:
 79 | 
 80 | #### Using Linux based docker:
 81 | 
 82 | ```sh
 83 | sudo docker run -itv $(pwd)/models:/models -v $(pwd)/models_hash:/models_hash -p <docker_host_port>:4343 tensorflow_inference_api_cpu
 84 | ```
 85 | 
 86 | #### Using Windows based docker:
 87 | 
 88 | ```sh
 89 | docker run -itv ${PWD}/models:/models -v ${PWD}/models_hash:/models_hash -p <docker_host_port>:4343 tensorflow_inference_api_cpu
 90 | ```
 91 | 
 92 | The `<docker_host_port>` can be any unique port of your choice.
 93 | 
 94 | The API file will be run automatically, and the service will listen to HTTP requests on the chosen port.
 95 | 
 96 | ## API Endpoints
 97 | 
 98 | To see all available endpoints, open your favorite browser and navigate to:
 99 | 
100 | ```
101 | http://<machine_IP>:<docker_host_port>/docs
102 | ```
103 | The 'predict_batch' endpoint is not shown on Swagger; the list-of-files input is not yet supported there.
104 | 
105 | **P.S: If you are using custom endpoints like /load, /detect, and /get_labels, you should always use the /load endpoint first and then use /detect or /get_labels**
106 | 
107 | ### Endpoints summary
108 | 
109 | #### /load (GET)
110 | 
111 | Loads all available models and returns every model with its hashed value.
Loaded models are stored and are not loaded again.
112 | 
113 | ![load model](./docs/1.gif)
114 | 
115 | #### /detect (POST)
116 | 
117 | Performs inference with the specified model on the given image and returns the bounding boxes
118 | 
119 | ![detect image](./docs/3.gif)
120 | 
121 | #### /get_labels (POST)
122 | 
123 | Returns all of the specified model's labels with their hashed values
124 | 
125 | ![get model labels](./docs/2.gif)
126 | 
127 | #### /models/{model_name}/predict_image (POST)
128 | 
129 | Performs inference with the specified model on an image, draws the bounding boxes on it, and returns the annotated image as the response
130 | 
131 | ![predict image](./docs/4.gif)
132 | 
133 | #### /models (GET)
134 | 
135 | Lists all available models
136 | 
137 | #### /models/{model_name}/load (GET)
138 | 
139 | Loads the specified model. Loaded models are stored and are not loaded again
140 | 
141 | #### /models/{model_name}/predict (POST)
142 | 
143 | Performs inference with the specified model on an image and returns the bounding boxes.
144 | 
145 | #### /models/{model_name}/labels (GET)
146 | 
147 | Returns all of the specified model's labels
148 | 
149 | #### /models/{model_name}/config (GET)
150 | 
151 | Returns the specified model's configuration
152 | 
153 | #### /models/{model_name}/predict_batch (POST)
154 | 
155 | Performs inference with the specified model on a list of images and returns the bounding boxes
156 | 
157 | #### /models/{model_name}/one_shot_ocr (POST)
158 | 
159 | Takes an image and returns extracted text details. First, a detection model is used to crop the interesting areas of the uploaded image. These areas are then passed to the OCR service for text extraction.
160 | 
161 | #### /models/{model_name}/ocr (POST)
162 | 
163 | ![predict image](./docs/5.gif)
164 | 
165 | Takes an image and returns extracted text details without using an object detection model
166 | 
167 | **P.S:** Custom endpoints like /load, /detect, /get_labels and /one_shot_ocr should be used in chronological order: first call /load, then call /detect, /get_labels or /one_shot_ocr, as in the client sketch below.
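A minimal client sketch of that chronological order (assumes the `requests` package; host, port and image path are placeholders):

```python
import requests

base = "http://<machine_IP>:<docker_host_port>"  # placeholders, as above
models = requests.get(base + "/load").json()     # {model_name: hash, ...}
model_hash = next(iter(models.values()))

with open("test.jpg", "rb") as f:
    result = requests.post(base + "/detect",
                           data={"model": model_hash},
                           files={"image": f})
print(result.json())
```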
168 | ## Model structure
169 | 
170 | The folder "models" contains subfolders of all the models to be loaded.
171 | Inside each subfolder there should be a:
172 | 
173 | - pb file (frozen_inference_graph.pb): contains the model weights
174 | 
175 | - pbtxt file (object-detection.pbtxt): contains the model classes
176 | 
177 | - config.json (a JSON file containing information about the model)
178 | 
179 | ```json
180 | {
181 |   "inference_engine_name": "tensorflow_detection",
182 |   "confidence": 60,
183 |   "predictions": 15,
184 |   "number_of_classes": 2,
185 |   "framework": "tensorflow",
186 |   "type": "detection",
187 |   "network": "inception"
188 | }
189 | ```
190 | P.S:
191 | - You can change the "confidence" and "predictions" values while the API is running
192 | - The API only returns bounding boxes whose confidence is higher than the "confidence" value; a higher value restricts the response to the most accurate predictions
193 | - The "predictions" value specifies the maximum number of bounding boxes in the API response
194 | 
195 | 
196 | ## Benchmarking
197 | 
198 | | Network \ Hardware | Intel Xeon CPU 2.3 GHz (Windows) | Intel Xeon CPU 2.3 GHz (Ubuntu) | Intel Xeon CPU 3.60 GHz (Ubuntu) | GeForce GTX 1080 (Ubuntu) |
199 | | ------------------ | -------------------------------- | ------------------------------- | -------------------------------- | ------------------------- |
200 | | ssd_fpn            | 0.867 seconds/image              | 1.016 seconds/image             | 0.434 seconds/image              | 0.0658 seconds/image      |
201 | | frcnn_resnet_50    | 4.029 seconds/image              | 4.219 seconds/image             | 1.994 seconds/image              | 0.148 seconds/image       |
202 | | ssd_mobilenet      | 0.055 seconds/image              | 0.106 seconds/image             | 0.051 seconds/image              | 0.052 seconds/image       |
203 | | frcnn_resnet_101   | 4.469 seconds/image              | 4.985 seconds/image             | 2.254 seconds/image              | 0.364 seconds/image       |
204 | | ssd_resnet_50      | 1.34 seconds/image               | 1.462 seconds/image             | 0.668 seconds/image              | 0.091 seconds/image       |
205 | | ssd_inception      | 0.094 seconds/image              | 0.15 seconds/image              | 0.074 seconds/image              | 0.0513 seconds/image      |
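The figures above can be approximated with a simple timing loop against the standard predict endpoint (a rough sketch; host, port, model name and image are placeholders, and the model should be loaded first via /models/&lt;model_name&gt;/load):

```python
import time
import requests

base = "http://<machine_IP>:<docker_host_port>"  # placeholders
with open("test.jpg", "rb") as f:
    payload = f.read()

n = 20
start = time.time()
for _ in range(n):
    requests.post(base + "/models/my_model/predict",
                  files={"input_data": ("test.jpg", payload, "image/jpeg")})
print((time.time() - start) / n, "seconds/image")
```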
260 | 261 | ## Acknowledgment 262 | 263 | [inmind.ai](https://inmind.ai) 264 | 265 | [robotron.de](https://robotron.de) 266 | 267 | Joe Sleiman, inmind.ai , Beirut, Lebanon 268 | 269 | Antoine Charbel, inmind.ai, Beirut, Lebanon 270 | 271 | [Anis Ismail](https://www.linkedin.com/in/anisdismail), Beirut, Lebanon 272 | -------------------------------------------------------------------------------- /src/main/object_detection/core/standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard field used by BoxList 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 24 | """ 25 | 26 | 27 | class InputDataFields(object): 28 | """Names for the input tensors. 29 | 30 | Holds the standard data field names to use for identifying input tensors. This 31 | should be used by the decoder to identify keys for the returned tensor_dict 32 | containing input tensors. And it should be used by the model to identify the 33 | tensors it needs. 34 | 35 | Attributes: 36 | image: image. 37 | image_additional_channels: additional channels. 38 | original_image: image in the original input size. 39 | key: unique key corresponding to image. 40 | source_id: source of the original image. 41 | filename: original filename of the dataset (without common path). 42 | groundtruth_image_classes: image-level class labels. 43 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 44 | groundtruth_classes: box-level class labels. 45 | groundtruth_label_types: box-level label types (e.g. explicit negative). 46 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 47 | is the groundtruth a single object or a crowd. 48 | groundtruth_area: area of a groundtruth segment. 49 | groundtruth_difficult: is a `difficult` object 50 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 51 | same class, forming a connected group, where instances are heavily 52 | occluding each other. 53 | proposal_boxes: coordinates of object proposal boxes. 54 | proposal_objectness: objectness score of each proposal. 55 | groundtruth_instance_masks: ground truth instance masks. 56 | groundtruth_instance_boundaries: ground truth instance boundaries. 57 | groundtruth_instance_classes: instance mask-level class labels. 58 | groundtruth_keypoints: ground truth keypoints. 59 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 60 | groundtruth_label_scores: groundtruth label scores. 
61 | groundtruth_weights: groundtruth weight factor for bounding boxes. 62 | num_groundtruth_boxes: number of groundtruth boxes. 63 | true_image_shapes: true shapes of images in the resized images, as resized 64 | images can be padded with zeros. 65 | verified_labels: list of human-verified image-level labels (note, that a 66 | label can be verified both as positive and negative). 67 | multiclass_scores: the label score per class for each box. 68 | """ 69 | image = 'image' 70 | image_additional_channels = 'image_additional_channels' 71 | original_image = 'original_image' 72 | key = 'key' 73 | source_id = 'source_id' 74 | filename = 'filename' 75 | groundtruth_image_classes = 'groundtruth_image_classes' 76 | groundtruth_boxes = 'groundtruth_boxes' 77 | groundtruth_classes = 'groundtruth_classes' 78 | groundtruth_label_types = 'groundtruth_label_types' 79 | groundtruth_is_crowd = 'groundtruth_is_crowd' 80 | groundtruth_area = 'groundtruth_area' 81 | groundtruth_difficult = 'groundtruth_difficult' 82 | groundtruth_group_of = 'groundtruth_group_of' 83 | proposal_boxes = 'proposal_boxes' 84 | proposal_objectness = 'proposal_objectness' 85 | groundtruth_instance_masks = 'groundtruth_instance_masks' 86 | groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' 87 | groundtruth_instance_classes = 'groundtruth_instance_classes' 88 | groundtruth_keypoints = 'groundtruth_keypoints' 89 | groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' 90 | groundtruth_label_scores = 'groundtruth_label_scores' 91 | groundtruth_weights = 'groundtruth_weights' 92 | num_groundtruth_boxes = 'num_groundtruth_boxes' 93 | true_image_shape = 'true_image_shape' 94 | verified_labels = 'verified_labels' 95 | multiclass_scores = 'multiclass_scores' 96 | 97 | 98 | class DetectionResultFields(object): 99 | """Naming conventions for storing the output of the detector. 100 | 101 | Attributes: 102 | source_id: source of the original image. 103 | key: unique key corresponding to image. 104 | detection_boxes: coordinates of the detection boxes in the image. 105 | detection_scores: detection scores for the detection boxes in the image. 106 | detection_classes: detection-level class labels. 107 | detection_masks: contains a segmentation mask for each detection box. 108 | detection_boundaries: contains an object boundary for each detection box. 109 | detection_keypoints: contains detection keypoints for each detection box. 110 | num_detections: number of detections in the batch. 111 | """ 112 | 113 | source_id = 'source_id' 114 | key = 'key' 115 | detection_boxes = 'detection_boxes' 116 | detection_scores = 'detection_scores' 117 | detection_classes = 'detection_classes' 118 | detection_masks = 'detection_masks' 119 | detection_boundaries = 'detection_boundaries' 120 | detection_keypoints = 'detection_keypoints' 121 | num_detections = 'num_detections' 122 | 123 | 124 | class BoxListFields(object): 125 | """Naming conventions for BoxLists. 126 | 127 | Attributes: 128 | boxes: bounding box coordinates. 129 | classes: classes per bounding box. 130 | scores: scores per bounding box. 131 | weights: sample weights per bounding box. 132 | objectness: objectness score per bounding box. 133 | masks: masks per bounding box. 134 | boundaries: boundaries per bounding box. 135 | keypoints: keypoints per bounding box. 136 | keypoint_heatmaps: keypoint heatmaps per bounding box. 137 | is_crowd: is_crowd annotation per bounding box. 
138 | """ 139 | boxes = 'boxes' 140 | classes = 'classes' 141 | scores = 'scores' 142 | weights = 'weights' 143 | objectness = 'objectness' 144 | masks = 'masks' 145 | boundaries = 'boundaries' 146 | keypoints = 'keypoints' 147 | keypoint_heatmaps = 'keypoint_heatmaps' 148 | is_crowd = 'is_crowd' 149 | 150 | 151 | class TfExampleFields(object): 152 | """TF-example proto feature names for object detection. 153 | 154 | Holds the standard feature names to load from an Example proto for object 155 | detection. 156 | 157 | Attributes: 158 | image_encoded: JPEG encoded string 159 | image_format: image format, e.g. "JPEG" 160 | filename: filename 161 | channels: number of channels of image 162 | colorspace: colorspace, e.g. "RGB" 163 | height: height of image in pixels, e.g. 462 164 | width: width of image in pixels, e.g. 581 165 | source_id: original source of the image 166 | image_class_text: image-level label in text format 167 | image_class_label: image-level label in numerical format 168 | object_class_text: labels in text format, e.g. ["person", "cat"] 169 | object_class_label: labels in numbers, e.g. [16, 8] 170 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 171 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 172 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 173 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 174 | object_view: viewpoint of object, e.g. ["frontal", "left"] 175 | object_truncated: is object truncated, e.g. [true, false] 176 | object_occluded: is object occluded, e.g. [true, false] 177 | object_difficult: is object difficult, e.g. [true, false] 178 | object_group_of: is object a single object or a group of objects 179 | object_depiction: is object a depiction 180 | object_is_crowd: [DEPRECATED, use object_group_of instead] 181 | is the object a single object or a crowd 182 | object_segment_area: the area of the segment. 183 | object_weight: a weight factor for the object's bounding box. 184 | instance_masks: instance segmentation masks. 185 | instance_boundaries: instance boundaries. 186 | instance_classes: Classes for each instance segmentation mask. 187 | detection_class_label: class label in numbers. 188 | detection_bbox_ymin: ymin coordinates of a detection box. 189 | detection_bbox_xmin: xmin coordinates of a detection box. 190 | detection_bbox_ymax: ymax coordinates of a detection box. 191 | detection_bbox_xmax: xmax coordinates of a detection box. 192 | detection_score: detection score for the class label and box. 
193 | """ 194 | image_encoded = 'image/encoded' 195 | image_format = 'image/format' # format is reserved keyword 196 | filename = 'image/filename' 197 | channels = 'image/channels' 198 | colorspace = 'image/colorspace' 199 | height = 'image/height' 200 | width = 'image/width' 201 | source_id = 'image/source_id' 202 | image_class_text = 'image/class/text' 203 | image_class_label = 'image/class/label' 204 | object_class_text = 'image/object/class/text' 205 | object_class_label = 'image/object/class/label' 206 | object_bbox_ymin = 'image/object/bbox/ymin' 207 | object_bbox_xmin = 'image/object/bbox/xmin' 208 | object_bbox_ymax = 'image/object/bbox/ymax' 209 | object_bbox_xmax = 'image/object/bbox/xmax' 210 | object_view = 'image/object/view' 211 | object_truncated = 'image/object/truncated' 212 | object_occluded = 'image/object/occluded' 213 | object_difficult = 'image/object/difficult' 214 | object_group_of = 'image/object/group_of' 215 | object_depiction = 'image/object/depiction' 216 | object_is_crowd = 'image/object/is_crowd' 217 | object_segment_area = 'image/object/segment/area' 218 | object_weight = 'image/object/weight' 219 | instance_masks = 'image/segmentation/object' 220 | instance_boundaries = 'image/boundaries/object' 221 | instance_classes = 'image/segmentation/object/class' 222 | detection_class_label = 'image/detection/label' 223 | detection_bbox_ymin = 'image/detection/bbox/ymin' 224 | detection_bbox_xmin = 'image/detection/bbox/xmin' 225 | detection_bbox_ymax = 'image/detection/bbox/ymax' 226 | detection_bbox_xmax = 'image/detection/bbox/xmax' 227 | detection_score = 'image/detection/score' 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 BMW Group 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/main/object_detection/utils/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utils used to manipulate tensor shapes.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.utils import static_shape 21 | 22 | 23 | def _is_tensor(t): 24 | """Returns a boolean indicating whether the input is a tensor. 25 | 26 | Args: 27 | t: the input to be tested. 28 | 29 | Returns: 30 | a boolean that indicates whether t is a tensor. 31 | """ 32 | return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable)) 33 | 34 | 35 | def _set_dim_0(t, d0): 36 | """Sets the 0-th dimension of the input tensor. 37 | 38 | Args: 39 | t: the input tensor, assuming the rank is at least 1. 40 | d0: an integer indicating the 0-th dimension of the input tensor. 41 | 42 | Returns: 43 | the tensor t with the 0-th dimension set. 44 | """ 45 | t_shape = t.get_shape().as_list() 46 | t_shape[0] = d0 47 | t.set_shape(t_shape) 48 | return t 49 | 50 | 51 | def pad_tensor(t, length): 52 | """Pads the input tensor with 0s along the first dimension up to the length. 53 | 54 | Args: 55 | t: the input tensor, assuming the rank is at least 1. 56 | length: a tensor of shape [1] or an integer, indicating the first dimension 57 | of the input tensor t after padding, assuming length <= t.shape[0]. 58 | 59 | Returns: 60 | padded_t: the padded tensor, whose first dimension is length. If the length 61 | is an integer, the first dimension of padded_t is set to length 62 | statically. 63 | """ 64 | t_rank = tf.rank(t) 65 | t_shape = tf.shape(t) 66 | t_d0 = t_shape[0] 67 | pad_d0 = tf.expand_dims(length - t_d0, 0) 68 | pad_shape = tf.cond( 69 | tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), 70 | lambda: tf.expand_dims(length - t_d0, 0)) 71 | padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0) 72 | if not _is_tensor(length): 73 | padded_t = _set_dim_0(padded_t, length) 74 | return padded_t 75 | 76 | 77 | def clip_tensor(t, length): 78 | """Clips the input tensor along the first dimension up to the length. 79 | 80 | Args: 81 | t: the input tensor, assuming the rank is at least 1. 82 | length: a tensor of shape [1] or an integer, indicating the first dimension 83 | of the input tensor t after clipping, assuming length <= t.shape[0]. 84 | 85 | Returns: 86 | clipped_t: the clipped tensor, whose first dimension is length. If the 87 | length is an integer, the first dimension of clipped_t is set to length 88 | statically. 89 | """ 90 | clipped_t = tf.gather(t, tf.range(length)) 91 | if not _is_tensor(length): 92 | clipped_t = _set_dim_0(clipped_t, length) 93 | return clipped_t 94 | 95 | 96 | def pad_or_clip_tensor(t, length): 97 | """Pad or clip the input tensor along the first dimension. 98 | 99 | Args: 100 | t: the input tensor, assuming the rank is at least 1. 101 | length: a tensor of shape [1] or an integer, indicating the first dimension 102 | of the input tensor t after processing. 103 | 104 | Returns: 105 | processed_t: the processed tensor, whose first dimension is length. 
106 |       length is an integer, the first dimension of the processed tensor is set
107 |       to length statically.
108 |   """
109 |   processed_t = tf.cond(
110 |       tf.greater(tf.shape(t)[0], length),
111 |       lambda: clip_tensor(t, length),
112 |       lambda: pad_tensor(t, length))
113 |   if not _is_tensor(length):
114 |     processed_t = _set_dim_0(processed_t, length)
115 |   return processed_t
116 | 
117 | 
118 | def combined_static_and_dynamic_shape(tensor):
119 |   """Returns a list containing static and dynamic values for the dimensions.
120 | 
121 |   Returns a list of static and dynamic values for shape dimensions. This is
122 |   useful to preserve static shapes when available in reshape operation.
123 | 
124 |   Args:
125 |     tensor: A tensor of any type.
126 | 
127 |   Returns:
128 |     A list of size tensor.shape.ndims containing integers or a scalar tensor.
129 |   """
130 |   static_tensor_shape = tensor.shape.as_list()
131 |   dynamic_tensor_shape = tf.shape(tensor)
132 |   combined_shape = []
133 |   for index, dim in enumerate(static_tensor_shape):
134 |     if dim is not None:
135 |       combined_shape.append(dim)
136 |     else:
137 |       combined_shape.append(dynamic_tensor_shape[index])
138 |   return combined_shape
139 | 
140 | 
141 | def static_or_dynamic_map_fn(fn, elems, dtype=None,
142 |                              parallel_iterations=32, back_prop=True):
143 |   """Runs map_fn as a (static) for loop when possible.
144 | 
145 |   This function rewrites the map_fn as an explicit unstack input -> for loop
146 |   over function calls -> stack result combination. This allows our graphs to
147 |   be acyclic when the batch size is static.
148 |   For comparison, see https://www.tensorflow.org/api_docs/python/tf/map_fn.
149 | 
150 |   Note that `static_or_dynamic_map_fn` currently is not *fully* interchangeable
151 |   with the default tf.map_fn function as it does not accept nested inputs (only
152 |   Tensors or lists of Tensors). Likewise, the output of `fn` can only be a
153 |   Tensor or list of Tensors.
154 | 
155 |   TODO(jonathanhuang): make this function fully interchangeable with tf.map_fn.
156 | 
157 |   Args:
158 |     fn: The callable to be performed. It accepts one argument, which will have
159 |       the same structure as elems. Its output must have the
160 |       same structure as elems.
161 |     elems: A tensor or list of tensors, each of which will
162 |       be unpacked along their first dimension. The sequence of the
163 |       resulting slices will be applied to fn.
164 |     dtype: (optional) The output type(s) of fn. If fn returns a structure of
165 |       Tensors differing from the structure of elems, then dtype is not optional
166 |       and must have the same structure as the output of fn.
167 |     parallel_iterations: (optional) number of batch items to process in
168 |       parallel. This flag is only used if the native tf.map_fn is used
169 |       and defaults to 32 instead of 10 (unlike the standard tf.map_fn default).
170 |     back_prop: (optional) True enables support for back propagation.
171 |       This flag is only used if the native tf.map_fn is used.
172 | 
173 |   Returns:
174 |     A tensor or sequence of tensors. Each tensor packs the
175 |     results of applying fn to tensors unpacked from elems along the first
176 |     dimension, from first to last.
177 |   Raises:
178 |     ValueError: if `elems` is not a Tensor or a list of Tensors.
179 |     ValueError: if `fn` does not return a Tensor or list of Tensors
180 |   """
181 |   if isinstance(elems, list):
182 |     for elem in elems:
183 |       if not isinstance(elem, tf.Tensor):
184 |         raise ValueError('`elems` must be a Tensor or list of Tensors.')
185 | 
186 |     elem_shapes = [elem.shape.as_list() for elem in elems]
187 |     # Fall back on tf.map_fn if shapes of each entry of `elems` are None or fail
188 |     # to all be the same size along the batch dimension.
189 |     for elem_shape in elem_shapes:
190 |       if (not elem_shape or not elem_shape[0]
191 |           or elem_shape[0] != elem_shapes[0][0]):
192 |         return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
193 |     arg_tuples = zip(*[tf.unstack(elem) for elem in elems])
194 |     outputs = [fn(arg_tuple) for arg_tuple in arg_tuples]
195 |   else:
196 |     if not isinstance(elems, tf.Tensor):
197 |       raise ValueError('`elems` must be a Tensor or list of Tensors.')
198 |     elems_shape = elems.shape.as_list()
199 |     if not elems_shape or not elems_shape[0]:
200 |       return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
201 |     outputs = [fn(arg) for arg in tf.unstack(elems)]
202 |   # Stack `outputs`, which is a list of Tensors or list of lists of Tensors
203 |   if all([isinstance(output, tf.Tensor) for output in outputs]):
204 |     return tf.stack(outputs)
205 |   else:
206 |     if all([isinstance(output, list) for output in outputs]):
207 |       if all([all(
208 |           [isinstance(entry, tf.Tensor) for entry in output_list])
209 |               for output_list in outputs]):
210 |         return [tf.stack(output_tuple) for output_tuple in zip(*outputs)]
211 |   raise ValueError('`fn` should return a Tensor or a list of Tensors.')
212 | 
213 | 
214 | def check_min_image_dim(min_dim, image_tensor):
215 |   """Checks that the image width/height are greater than some number.
216 | 
217 |   This function is used to check that the width and height of an image are above
218 |   a certain value. If the image shape is static, this function will perform the
219 |   check at graph construction time. Otherwise, if the image shape varies, an
220 |   Assertion control dependency will be added to the graph.
221 | 
222 |   Args:
223 |     min_dim: The minimum number of pixels along the width and height of the
224 |       image.
225 |     image_tensor: The image tensor to check size for.
226 | 
227 |   Returns:
228 |     If `image_tensor` has dynamic size, return `image_tensor` with an Assert
229 |     control dependency. Otherwise returns image_tensor.
230 | 
231 |   Raises:
232 |     ValueError: if `image_tensor`'s width or height is smaller than `min_dim`.
233 |   """
234 |   image_shape = image_tensor.get_shape()
235 |   image_height = static_shape.get_height(image_shape)
236 |   image_width = static_shape.get_width(image_shape)
237 |   if image_height is None or image_width is None:
238 |     shape_assert = tf.Assert(
239 |         tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim),
240 |                        tf.greater_equal(tf.shape(image_tensor)[2], min_dim)),
241 |         ['image size must be >= {} in both height and width.'.format(min_dim)])
242 |     with tf.control_dependencies([shape_assert]):
243 |       return tf.identity(image_tensor)
244 | 
245 |   if image_height < min_dim or image_width < min_dim:
246 |     raise ValueError(
247 |         'image size must be >= %d in both height and width; image dim = %d,%d' %
248 |         (min_dim, image_height, image_width))
249 | 
250 |   return image_tensor
251 | 
252 | 
253 | def assert_shape_equal(shape_a, shape_b):
254 |   """Asserts that shape_a and shape_b are equal.
255 | 
256 |   If the shapes are static, raises a ValueError when the shapes
257 |   mismatch.
258 | 
259 |   If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
260 |   mismatch.
261 | 
262 |   Args:
263 |     shape_a: a list containing shape of the first tensor.
264 |     shape_b: a list containing shape of the second tensor.
265 | 
266 |   Returns:
267 |     Either a tf.no_op() when shapes are all static or a tf.assert_equal() op
268 |     when the shapes are dynamic.
269 | 
270 |   Raises:
271 |     ValueError: When shapes are both static and unequal.
272 |   """
273 |   if (all(isinstance(dim, int) for dim in shape_a) and
274 |       all(isinstance(dim, int) for dim in shape_b)):
275 |     if shape_a != shape_b:
276 |       raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
277 |     else: return tf.no_op()
278 |   else:
279 |     return tf.assert_equal(shape_a, shape_b)
280 | 
281 | 
282 | def assert_shape_equal_along_first_dimension(shape_a, shape_b):
283 |   """Asserts that shape_a and shape_b are the same along the 0th-dimension.
284 | 
285 |   If the shapes are static, raises a ValueError when the shapes
286 |   mismatch.
287 | 
288 |   If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
289 |   mismatch.
290 | 
291 |   Args:
292 |     shape_a: a list containing shape of the first tensor.
293 |     shape_b: a list containing shape of the second tensor.
294 | 
295 |   Returns:
296 |     Either a tf.no_op() when shapes are all static or a tf.assert_equal() op
297 |     when the shapes are dynamic.
298 | 
299 |   Raises:
300 |     ValueError: When shapes are both static and unequal.
301 |   """
302 |   if isinstance(shape_a[0], int) and isinstance(shape_b[0], int):
303 |     if shape_a[0] != shape_b[0]:
304 |       raise ValueError('Unequal first dimension {}, {}'.format(
305 |           shape_a[0], shape_b[0]))
306 |     else: return tf.no_op()
307 |   else:
308 |     return tf.assert_equal(shape_a[0], shape_b[0])
309 | 
310 | 
--------------------------------------------------------------------------------
/src/main/object_detection/utils/visualization_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """A set of functions that are used for visualization.
17 | 
18 | These functions often receive an image and perform some visualization on it.
19 | They do not return a value; instead, they modify the image itself.
20 | 
21 | """
22 | import collections
23 | import functools
24 | # Set headless-friendly backend.
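# Note (editor's addition): the 'Agg' backend must be selected before
# matplotlib.pyplot is imported; this keeps the module usable in headless
# environments, such as a Docker container with no display server.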
25 | import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements 26 | import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top 27 | import numpy as np 28 | import PIL.Image as Image 29 | import PIL.ImageColor as ImageColor 30 | import PIL.ImageDraw as ImageDraw 31 | import PIL.ImageFont as ImageFont 32 | import six 33 | import tensorflow as tf 34 | 35 | from object_detection.core import standard_fields as fields 36 | 37 | 38 | _TITLE_LEFT_MARGIN = 10 39 | _TITLE_TOP_MARGIN = 10 40 | STANDARD_COLORS = [ 41 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', 42 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', 43 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', 44 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', 45 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', 46 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', 47 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', 48 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', 49 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', 50 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', 51 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', 52 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', 53 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', 54 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', 55 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', 56 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', 57 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', 58 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', 59 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', 60 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', 61 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', 62 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', 63 | 'WhiteSmoke', 'Yellow', 'YellowGreen' 64 | ] 65 | 66 | 67 | def save_image_array_as_png(image, output_path): 68 | """Saves an image (represented as a numpy array) to PNG. 69 | 70 | Args: 71 | image: a numpy array with shape [height, width, 3]. 72 | output_path: path to which image should be written. 73 | """ 74 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 75 | with tf.gfile.Open(output_path, 'w') as fid: 76 | image_pil.save(fid, 'PNG') 77 | 78 | 79 | def encode_image_array_as_png_str(image): 80 | """Encodes a numpy array into a PNG string. 81 | 82 | Args: 83 | image: a numpy array with shape [height, width, 3]. 84 | 85 | Returns: 86 | PNG encoded image string. 87 | """ 88 | image_pil = Image.fromarray(np.uint8(image)) 89 | output = six.BytesIO() 90 | image_pil.save(output, format='PNG') 91 | png_string = output.getvalue() 92 | output.close() 93 | return png_string 94 | 95 | 96 | def draw_bounding_box_on_image_array(image, 97 | ymin, 98 | xmin, 99 | ymax, 100 | xmax, 101 | color='red', 102 | thickness=4, 103 | display_str_list=(), 104 | use_normalized_coordinates=True): 105 | """Adds a bounding box to an image (numpy array). 
106 | 107 | Bounding box coordinates can be specified in either absolute (pixel) or 108 | normalized coordinates by setting the use_normalized_coordinates argument. 109 | 110 | Args: 111 | image: a numpy array with shape [height, width, 3]. 112 | ymin: ymin of bounding box. 113 | xmin: xmin of bounding box. 114 | ymax: ymax of bounding box. 115 | xmax: xmax of bounding box. 116 | color: color to draw bounding box. Default is red. 117 | thickness: line thickness. Default value is 4. 118 | display_str_list: list of strings to display in box 119 | (each to be shown on its own line). 120 | use_normalized_coordinates: If True (default), treat coordinates 121 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 122 | coordinates as absolute. 123 | """ 124 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 125 | draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, 126 | thickness, display_str_list, 127 | use_normalized_coordinates) 128 | np.copyto(image, np.array(image_pil)) 129 | 130 | 131 | def draw_bounding_box_on_image(image, 132 | ymin, 133 | xmin, 134 | ymax, 135 | xmax, 136 | color='red', 137 | thickness=4, 138 | display_str_list=(), 139 | use_normalized_coordinates=True): 140 | """Adds a bounding box to an image. 141 | 142 | Bounding box coordinates can be specified in either absolute (pixel) or 143 | normalized coordinates by setting the use_normalized_coordinates argument. 144 | 145 | Each string in display_str_list is displayed on a separate line above the 146 | bounding box in black text on a rectangle filled with the input 'color'. 147 | If the top of the bounding box extends to the edge of the image, the strings 148 | are displayed below the bounding box. 149 | 150 | Args: 151 | image: a PIL.Image object. 152 | ymin: ymin of bounding box. 153 | xmin: xmin of bounding box. 154 | ymax: ymax of bounding box. 155 | xmax: xmax of bounding box. 156 | color: color to draw bounding box. Default is red. 157 | thickness: line thickness. Default value is 4. 158 | display_str_list: list of strings to display in box 159 | (each to be shown on its own line). 160 | use_normalized_coordinates: If True (default), treat coordinates 161 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 162 | coordinates as absolute. 163 | """ 164 | draw = ImageDraw.Draw(image) 165 | im_width, im_height = image.size 166 | if use_normalized_coordinates: 167 | (left, right, top, bottom) = (xmin * im_width, xmax * im_width, 168 | ymin * im_height, ymax * im_height) 169 | else: 170 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax) 171 | draw.line([(left, top), (left, bottom), (right, bottom), 172 | (right, top), (left, top)], width=thickness, fill=color) 173 | try: 174 | font = ImageFont.truetype('arial.ttf', 24) 175 | except IOError: 176 | font = ImageFont.load_default() 177 | 178 | # If the total height of the display strings added to the top of the bounding 179 | # box exceeds the top of the image, stack the strings below the bounding box 180 | # instead of above. 181 | display_str_heights = [font.getsize(ds)[1] for ds in display_str_list] 182 | # Each display_str has a top and bottom margin of 0.05x. 183 | total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) 184 | 185 | if top > total_display_str_height: 186 | text_bottom = top 187 | else: 188 | text_bottom = bottom + total_display_str_height 189 | # Reverse list and print from bottom to top. 
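  # (Editor's note) Drawing starts at text_bottom and moves upward, so
  # iterating over the reversed list draws the last string closest to the box
  # edge and leaves the labels reading top-to-bottom in their original order.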
190 | for display_str in display_str_list[::-1]: 191 | text_width, text_height = font.getsize(display_str) 192 | margin = np.ceil(0.05 * text_height) 193 | draw.rectangle( 194 | [(left, text_bottom - text_height - 2 * margin), (left + text_width, 195 | text_bottom)], 196 | fill=color) 197 | draw.text( 198 | (left + margin, text_bottom - text_height - margin), 199 | display_str, 200 | fill='black', 201 | font=font) 202 | text_bottom -= text_height - 2 * margin 203 | 204 | 205 | def draw_bounding_boxes_on_image_array(image, 206 | boxes, 207 | color='red', 208 | thickness=4, 209 | display_str_list_list=()): 210 | """Draws bounding boxes on image (numpy array). 211 | 212 | Args: 213 | image: a numpy array object. 214 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 215 | The coordinates are in normalized format between [0, 1]. 216 | color: color to draw bounding box. Default is red. 217 | thickness: line thickness. Default value is 4. 218 | display_str_list_list: list of list of strings. 219 | a list of strings for each bounding box. 220 | The reason to pass a list of strings for a 221 | bounding box is that it might contain 222 | multiple labels. 223 | 224 | Raises: 225 | ValueError: if boxes is not a [N, 4] array 226 | """ 227 | image_pil = Image.fromarray(image) 228 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, 229 | display_str_list_list) 230 | np.copyto(image, np.array(image_pil)) 231 | 232 | 233 | def draw_bounding_boxes_on_image(image, 234 | boxes, 235 | color='red', 236 | thickness=4, 237 | display_str_list_list=()): 238 | """Draws bounding boxes on image. 239 | 240 | Args: 241 | image: a PIL.Image object. 242 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 243 | The coordinates are in normalized format between [0, 1]. 244 | color: color to draw bounding box. Default is red. 245 | thickness: line thickness. Default value is 4. 246 | display_str_list_list: list of list of strings. 247 | a list of strings for each bounding box. 248 | The reason to pass a list of strings for a 249 | bounding box is that it might contain 250 | multiple labels. 
251 | 
252 |   Raises:
253 |     ValueError: if boxes is not a [N, 4] array
254 |   """
255 |   boxes_shape = boxes.shape
256 |   if not boxes_shape:
257 |     return
258 |   if len(boxes_shape) != 2 or boxes_shape[1] != 4:
259 |     raise ValueError('Input must be of size [N, 4]')
260 |   for i in range(boxes_shape[0]):
261 |     display_str_list = ()
262 |     if display_str_list_list:
263 |       display_str_list = display_str_list_list[i]
264 |     draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2],
265 |                                boxes[i, 3], color, thickness, display_str_list)
266 | 
267 | 
268 | def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs):
269 |   return visualize_boxes_and_labels_on_image_array(
270 |       image, boxes, classes, scores, category_index=category_index, **kwargs)
271 | 
272 | 
273 | def _visualize_boxes_and_masks(image, boxes, classes, scores, masks,
274 |                                category_index, **kwargs):
275 |   return visualize_boxes_and_labels_on_image_array(
276 |       image,
277 |       boxes,
278 |       classes,
279 |       scores,
280 |       category_index=category_index,
281 |       instance_masks=masks,
282 |       **kwargs)
283 | 
284 | 
285 | def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
286 |                                    category_index, **kwargs):
287 |   return visualize_boxes_and_labels_on_image_array(
288 |       image,
289 |       boxes,
290 |       classes,
291 |       scores,
292 |       category_index=category_index,
293 |       keypoints=keypoints,
294 |       **kwargs)
295 | 
296 | 
297 | def _visualize_boxes_and_masks_and_keypoints(
298 |     image, boxes, classes, scores, masks, keypoints, category_index, **kwargs):
299 |   return visualize_boxes_and_labels_on_image_array(
300 |       image,
301 |       boxes,
302 |       classes,
303 |       scores,
304 |       category_index=category_index,
305 |       instance_masks=masks,
306 |       keypoints=keypoints,
307 |       **kwargs)
308 | 
309 | 
310 | def draw_bounding_boxes_on_image_tensors(images,
311 |                                          boxes,
312 |                                          classes,
313 |                                          scores,
314 |                                          category_index,
315 |                                          instance_masks=None,
316 |                                          keypoints=None,
317 |                                          max_boxes_to_draw=20,
318 |                                          min_score_thresh=0.2,
319 |                                          use_normalized_coordinates=True):
320 |   """Draws bounding boxes, masks, and keypoints on batch of image tensors.
321 | 
322 |   Args:
323 |     images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
324 |       channels will be ignored.
325 |     boxes: [N, max_detections, 4] float32 tensor of detection boxes.
326 |     classes: [N, max_detections] int tensor of detection classes. Note that
327 |       classes are 1-indexed.
328 |     scores: [N, max_detections] float32 tensor of detection scores.
329 |     category_index: a dict that maps integer ids to category dicts. e.g.
330 |       {1: {1: 'dog'}, 2: {2: 'cat'}, ...}
331 |     instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with
332 |       instance masks.
333 |     keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2]
334 |       with keypoints.
335 |     max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
336 |     min_score_thresh: Minimum score threshold for visualization. Default 0.2.
337 |     use_normalized_coordinates: Whether to assume boxes and keypoints are in
338 |       normalized coordinates (as opposed to absolute coordinates).
339 |       Default is True.
340 | 
341 |   Returns:
342 |     4D image tensor of type uint8, with boxes drawn on top.
343 |   """
344 |   # Additional channels are being ignored.
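  # (Editor's note) Slicing keeps only the first three channels, assumed RGB;
  # the PIL-based drawing helpers invoked below via tf.py_func expect
  # 3-channel uint8 images.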
345 |   images = images[:, :, :, 0:3]
346 |   visualization_keyword_args = {
347 |       'use_normalized_coordinates': use_normalized_coordinates,
348 |       'max_boxes_to_draw': max_boxes_to_draw,
349 |       'min_score_thresh': min_score_thresh,
350 |       'agnostic_mode': False,
351 |       'line_thickness': 4
352 |   }
353 | 
354 |   if instance_masks is not None and keypoints is None:
355 |     visualize_boxes_fn = functools.partial(
356 |         _visualize_boxes_and_masks,
357 |         category_index=category_index,
358 |         **visualization_keyword_args)
359 |     elems = [images, boxes, classes, scores, instance_masks]
360 |   elif instance_masks is None and keypoints is not None:
361 |     visualize_boxes_fn = functools.partial(
362 |         _visualize_boxes_and_keypoints,
363 |         category_index=category_index,
364 |         **visualization_keyword_args)
365 |     elems = [images, boxes, classes, scores, keypoints]
366 |   elif instance_masks is not None and keypoints is not None:
367 |     visualize_boxes_fn = functools.partial(
368 |         _visualize_boxes_and_masks_and_keypoints,
369 |         category_index=category_index,
370 |         **visualization_keyword_args)
371 |     elems = [images, boxes, classes, scores, instance_masks, keypoints]
372 |   else:
373 |     visualize_boxes_fn = functools.partial(
374 |         _visualize_boxes,
375 |         category_index=category_index,
376 |         **visualization_keyword_args)
377 |     elems = [images, boxes, classes, scores]
378 | 
379 |   def draw_boxes(image_and_detections):
380 |     """Draws boxes on image."""
381 |     image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections,
382 |                                   tf.uint8)
383 |     return image_with_boxes
384 | 
385 |   images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False)
386 |   return images
387 | 
388 | 
389 | def draw_side_by_side_evaluation_image(eval_dict,
390 |                                        category_index,
391 |                                        max_boxes_to_draw=20,
392 |                                        min_score_thresh=0.2,
393 |                                        use_normalized_coordinates=True):
394 |   """Creates a side-by-side image with detections and groundtruth.
395 | 
396 |   Bounding boxes (and instance masks, if available) are visualized on both
397 |   subimages.
398 | 
399 |   Args:
400 |     eval_dict: The evaluation dictionary returned by
401 |       eval_util.result_dict_for_single_example().
402 |     category_index: A category index (dictionary) produced from a labelmap.
403 |     max_boxes_to_draw: The maximum number of boxes to draw for detections.
404 |     min_score_thresh: The minimum score threshold for showing detections.
405 |     use_normalized_coordinates: Whether to assume boxes and keypoints are in
406 |       normalized coordinates (as opposed to absolute coordinates).
407 |       Default is True.
408 | 
409 |   Returns:
410 |     A [1, H, 2 * W, C] uint8 tensor. The subimage on the left corresponds to
411 |     detections, while the subimage on the right corresponds to groundtruth.
412 | """ 413 | detection_fields = fields.DetectionResultFields() 414 | input_data_fields = fields.InputDataFields() 415 | instance_masks = None 416 | if detection_fields.detection_masks in eval_dict: 417 | instance_masks = tf.cast( 418 | tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0), 419 | tf.uint8) 420 | keypoints = None 421 | if detection_fields.detection_keypoints in eval_dict: 422 | keypoints = tf.expand_dims( 423 | eval_dict[detection_fields.detection_keypoints], axis=0) 424 | groundtruth_instance_masks = None 425 | if input_data_fields.groundtruth_instance_masks in eval_dict: 426 | groundtruth_instance_masks = tf.cast( 427 | tf.expand_dims( 428 | eval_dict[input_data_fields.groundtruth_instance_masks], axis=0), 429 | tf.uint8) 430 | images_with_detections = draw_bounding_boxes_on_image_tensors( 431 | eval_dict[input_data_fields.original_image], 432 | tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0), 433 | tf.expand_dims(eval_dict[detection_fields.detection_classes], axis=0), 434 | tf.expand_dims(eval_dict[detection_fields.detection_scores], axis=0), 435 | category_index, 436 | instance_masks=instance_masks, 437 | keypoints=keypoints, 438 | max_boxes_to_draw=max_boxes_to_draw, 439 | min_score_thresh=min_score_thresh, 440 | use_normalized_coordinates=use_normalized_coordinates) 441 | images_with_groundtruth = draw_bounding_boxes_on_image_tensors( 442 | eval_dict[input_data_fields.original_image], 443 | tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0), 444 | tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0), 445 | tf.expand_dims( 446 | tf.ones_like( 447 | eval_dict[input_data_fields.groundtruth_classes], 448 | dtype=tf.float32), 449 | axis=0), 450 | category_index, 451 | instance_masks=groundtruth_instance_masks, 452 | keypoints=None, 453 | max_boxes_to_draw=None, 454 | min_score_thresh=0.0, 455 | use_normalized_coordinates=use_normalized_coordinates) 456 | return tf.concat([images_with_detections, images_with_groundtruth], axis=2) 457 | 458 | 459 | def draw_keypoints_on_image_array(image, 460 | keypoints, 461 | color='red', 462 | radius=2, 463 | use_normalized_coordinates=True): 464 | """Draws keypoints on an image (numpy array). 465 | 466 | Args: 467 | image: a numpy array with shape [height, width, 3]. 468 | keypoints: a numpy array with shape [num_keypoints, 2]. 469 | color: color to draw the keypoints with. Default is red. 470 | radius: keypoint radius. Default value is 2. 471 | use_normalized_coordinates: if True (default), treat keypoint values as 472 | relative to the image. Otherwise treat them as absolute. 473 | """ 474 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 475 | draw_keypoints_on_image(image_pil, keypoints, color, radius, 476 | use_normalized_coordinates) 477 | np.copyto(image, np.array(image_pil)) 478 | 479 | 480 | def draw_keypoints_on_image(image, 481 | keypoints, 482 | color='red', 483 | radius=2, 484 | use_normalized_coordinates=True): 485 | """Draws keypoints on an image. 486 | 487 | Args: 488 | image: a PIL.Image object. 489 | keypoints: a numpy array with shape [num_keypoints, 2]. 490 | color: color to draw the keypoints with. Default is red. 491 | radius: keypoint radius. Default value is 2. 492 | use_normalized_coordinates: if True (default), treat keypoint values as 493 | relative to the image. Otherwise treat them as absolute. 
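
  Example (editor's illustrative sketch, not part of the original file):
    image = Image.new('RGB', (200, 100))
    # Keypoints are (y, x) pairs; here given in normalized coordinates.
    draw_keypoints_on_image(image, [(0.5, 0.25), (0.5, 0.75)])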
494 |   """
495 |   draw = ImageDraw.Draw(image)
496 |   im_width, im_height = image.size
497 |   keypoints_x = [k[1] for k in keypoints]
498 |   keypoints_y = [k[0] for k in keypoints]
499 |   if use_normalized_coordinates:
500 |     keypoints_x = tuple([im_width * x for x in keypoints_x])
501 |     keypoints_y = tuple([im_height * y for y in keypoints_y])
502 |   for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
503 |     draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
504 |                   (keypoint_x + radius, keypoint_y + radius)],
505 |                  outline=color, fill=color)
506 | 
507 | 
508 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
509 |   """Draws mask on an image.
510 | 
511 |   Args:
512 |     image: uint8 numpy array with shape (img_height, img_width, 3)
513 |     mask: a uint8 numpy array of shape (img_height, img_width) with
514 |       values of either 0 or 1.
515 |     color: color to draw the mask with. Default is red.
516 |     alpha: transparency value between 0 and 1. (default: 0.4)
517 | 
518 |   Raises:
519 |     ValueError: On incorrect data type for image or masks.
520 |   """
521 |   if image.dtype != np.uint8:
522 |     raise ValueError('`image` not of type np.uint8')
523 |   if mask.dtype != np.uint8:
524 |     raise ValueError('`mask` not of type np.uint8')
525 |   if np.any(np.logical_and(mask != 1, mask != 0)):
526 |     raise ValueError('`mask` elements should be in [0, 1]')
527 |   if image.shape[:2] != mask.shape:
528 |     raise ValueError('The image has spatial dimensions %s but the mask has '
529 |                      'dimensions %s' % (image.shape[:2], mask.shape))
530 |   rgb = ImageColor.getrgb(color)
531 |   pil_image = Image.fromarray(image)
532 | 
533 |   solid_color = np.expand_dims(
534 |       np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
535 |   pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
536 |   pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
537 |   pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
538 |   np.copyto(image, np.array(pil_image.convert('RGB')))
539 | 
540 | 
541 | def visualize_boxes_and_labels_on_image_array(
542 |     image,
543 |     boxes,
544 |     classes,
545 |     scores,
546 |     category_index,
547 |     instance_masks=None,
548 |     instance_boundaries=None,
549 |     keypoints=None,
550 |     use_normalized_coordinates=False,
551 |     max_boxes_to_draw=20,
552 |     min_score_thresh=.5,
553 |     agnostic_mode=False,
554 |     line_thickness=4,
555 |     groundtruth_box_visualization_color='black',
556 |     skip_scores=False,
557 |     skip_labels=False):
558 |   """Overlay labeled boxes on an image with formatted scores and label names.
559 | 
560 |   This function groups boxes that correspond to the same location
561 |   and creates a display string for each detection and overlays these
562 |   on the image. Note that this function modifies the image in place, and returns
563 |   that same image.
564 | 
565 |   Args:
566 |     image: uint8 numpy array with shape (img_height, img_width, 3)
567 |     boxes: a numpy array of shape [N, 4]
568 |     classes: a numpy array of shape [N]. Note that class indices are 1-based,
569 |       and match the keys in the label map.
570 |     scores: a numpy array of shape [N] or None. If scores=None, then
571 |       this function assumes that the boxes to be plotted are groundtruth
572 |       boxes and plot all boxes as black with no classes or scores.
573 |     category_index: a dict containing category dictionaries (each holding
574 |       category index `id` and category name `name`) keyed by category indices.
575 | instance_masks: a numpy array of shape [N, image_height, image_width] with 576 | values ranging between 0 and 1, can be None. 577 | instance_boundaries: a numpy array of shape [N, image_height, image_width] 578 | with values ranging between 0 and 1, can be None. 579 | keypoints: a numpy array of shape [N, num_keypoints, 2], can 580 | be None 581 | use_normalized_coordinates: whether boxes is to be interpreted as 582 | normalized coordinates or not. 583 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw 584 | all boxes. 585 | min_score_thresh: minimum score threshold for a box to be visualized 586 | agnostic_mode: boolean (default: False) controlling whether to evaluate in 587 | class-agnostic mode or not. This mode will display scores but ignore 588 | classes. 589 | line_thickness: integer (default: 4) controlling line width of the boxes. 590 | groundtruth_box_visualization_color: box color for visualizing groundtruth 591 | boxes 592 | skip_scores: whether to skip score when drawing a single detection 593 | skip_labels: whether to skip label when drawing a single detection 594 | 595 | Returns: 596 | uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. 597 | """ 598 | # Create a display string (and color) for every box location, group any boxes 599 | # that correspond to the same location. 600 | box_to_display_str_map = collections.defaultdict(list) 601 | box_to_color_map = collections.defaultdict(str) 602 | box_to_instance_masks_map = {} 603 | box_to_instance_boundaries_map = {} 604 | box_to_keypoints_map = collections.defaultdict(list) 605 | if not max_boxes_to_draw: 606 | max_boxes_to_draw = boxes.shape[0] 607 | for i in range(min(max_boxes_to_draw, boxes.shape[0])): 608 | if scores is None or scores[i] > min_score_thresh: 609 | box = tuple(boxes[i].tolist()) 610 | if instance_masks is not None: 611 | box_to_instance_masks_map[box] = instance_masks[i] 612 | if instance_boundaries is not None: 613 | box_to_instance_boundaries_map[box] = instance_boundaries[i] 614 | if keypoints is not None: 615 | box_to_keypoints_map[box].extend(keypoints[i]) 616 | if scores is None: 617 | box_to_color_map[box] = groundtruth_box_visualization_color 618 | else: 619 | display_str = '' 620 | if not skip_labels: 621 | if not agnostic_mode: 622 | if classes[i] in category_index.keys(): 623 | class_name = category_index[classes[i]]['name'] 624 | else: 625 | class_name = 'N/A' 626 | display_str = str(class_name) 627 | if not skip_scores: 628 | if not display_str: 629 | display_str = '{}%'.format(int(100*scores[i])) 630 | else: 631 | display_str = '{}: {}%'.format(display_str, int(100*scores[i])) 632 | box_to_display_str_map[box].append(display_str) 633 | if agnostic_mode: 634 | box_to_color_map[box] = 'DarkOrange' 635 | else: 636 | box_to_color_map[box] = STANDARD_COLORS[ 637 | classes[i] % len(STANDARD_COLORS)] 638 | 639 | # Draw all boxes onto image. 
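  # (Editor's note) For each box: any instance mask is drawn first, then the
  # boundary mask, then the box outline with its label strings, and finally
  # the keypoints, so the outline and text stay visible on top of the fill.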
640 | for box, color in box_to_color_map.items(): 641 | ymin, xmin, ymax, xmax = box 642 | if instance_masks is not None: 643 | draw_mask_on_image_array( 644 | image, 645 | box_to_instance_masks_map[box], 646 | color=color 647 | ) 648 | if instance_boundaries is not None: 649 | draw_mask_on_image_array( 650 | image, 651 | box_to_instance_boundaries_map[box], 652 | color='red', 653 | alpha=1.0 654 | ) 655 | draw_bounding_box_on_image_array( 656 | image, 657 | ymin, 658 | xmin, 659 | ymax, 660 | xmax, 661 | color=color, 662 | thickness=line_thickness, 663 | display_str_list=box_to_display_str_map[box], 664 | use_normalized_coordinates=use_normalized_coordinates) 665 | if keypoints is not None: 666 | draw_keypoints_on_image_array( 667 | image, 668 | box_to_keypoints_map[box], 669 | color=color, 670 | radius=line_thickness / 2, 671 | use_normalized_coordinates=use_normalized_coordinates) 672 | 673 | return image 674 | 675 | 676 | def add_cdf_image_summary(values, name): 677 | """Adds a tf.summary.image for a CDF plot of the values. 678 | 679 | Normalizes `values` such that they sum to 1, plots the cumulative distribution 680 | function and creates a tf image summary. 681 | 682 | Args: 683 | values: a 1-D float32 tensor containing the values. 684 | name: name for the image summary. 685 | """ 686 | def cdf_plot(values): 687 | """Numpy function to plot CDF.""" 688 | normalized_values = values / np.sum(values) 689 | sorted_values = np.sort(normalized_values) 690 | cumulative_values = np.cumsum(sorted_values) 691 | fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) 692 | / cumulative_values.size) 693 | fig = plt.figure(frameon=False) 694 | ax = fig.add_subplot('111') 695 | ax.plot(fraction_of_examples, cumulative_values) 696 | ax.set_ylabel('cumulative normalized values') 697 | ax.set_xlabel('fraction of examples') 698 | fig.canvas.draw() 699 | width, height = fig.get_size_inches() * fig.get_dpi() 700 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( 701 | 1, int(height), int(width), 3) 702 | return image 703 | cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) 704 | tf.summary.image(name, cdf_plot) 705 | 706 | 707 | def add_hist_image_summary(values, bins, name): 708 | """Adds a tf.summary.image for a histogram plot of the values. 709 | 710 | Plots the histogram of values and creates a tf image summary. 711 | 712 | Args: 713 | values: a 1-D float32 tensor containing the values. 714 | bins: bin edges which will be directly passed to np.histogram. 715 | name: name for the image summary. 716 | """ 717 | 718 | def hist_plot(values, bins): 719 | """Numpy function to plot hist.""" 720 | fig = plt.figure(frameon=False) 721 | ax = fig.add_subplot('111') 722 | y, x = np.histogram(values, bins=bins) 723 | ax.plot(x[:-1], y) 724 | ax.set_ylabel('count') 725 | ax.set_xlabel('value') 726 | fig.canvas.draw() 727 | width, height = fig.get_size_inches() * fig.get_dpi() 728 | image = np.fromstring( 729 | fig.canvas.tostring_rgb(), dtype='uint8').reshape( 730 | 1, int(height), int(width), 3) 731 | return image 732 | hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) 733 | tf.summary.image(name, hist_plot) 734 | -------------------------------------------------------------------------------- /src/main/object_detection/utils/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """A module for helper tensorflow ops."""
17 | import math
18 | import numpy as np
19 | import six
20 | 
21 | import tensorflow as tf
22 | 
23 | from object_detection.core import box_list
24 | from object_detection.core import box_list_ops
25 | from object_detection.core import standard_fields as fields
26 | from object_detection.utils import shape_utils
27 | from object_detection.utils import static_shape
28 | 
29 | 
30 | def expanded_shape(orig_shape, start_dim, num_dims):
31 |   """Inserts multiple ones into a shape vector.
32 | 
33 |   Inserts an all-1 vector of length num_dims at position start_dim into a shape.
34 |   Can be combined with tf.reshape to generalize tf.expand_dims.
35 | 
36 |   Args:
37 |     orig_shape: the shape into which the all-1 vector is added (int32 vector)
38 |     start_dim: insertion position (int scalar)
39 |     num_dims: length of the inserted all-1 vector (int scalar)
40 |   Returns:
41 |     An int32 vector of length tf.size(orig_shape) + num_dims.
42 |   """
43 |   with tf.name_scope('ExpandedShape'):
44 |     start_dim = tf.expand_dims(start_dim, 0)  # scalar to rank-1
45 |     before = tf.slice(orig_shape, [0], start_dim)
46 |     add_shape = tf.ones(tf.reshape(num_dims, [1]), dtype=tf.int32)
47 |     after = tf.slice(orig_shape, start_dim, [-1])
48 |     new_shape = tf.concat([before, add_shape, after], 0)
49 |     return new_shape
50 | 
51 | 
52 | def normalized_to_image_coordinates(normalized_boxes, image_shape,
53 |                                     parallel_iterations=32):
54 |   """Converts a batch of boxes from normalized to image coordinates.
55 | 
56 |   Args:
57 |     normalized_boxes: a float32 tensor of shape [None, num_boxes, 4] in
58 |       normalized coordinates.
59 |     image_shape: a float32 tensor of shape [4] containing the image shape.
60 |     parallel_iterations: parallelism for the map_fn op.
61 | 
62 |   Returns:
63 |     absolute_boxes: a float32 tensor of shape [None, num_boxes, 4] containing
64 |       the boxes in image coordinates.
65 |   """
66 |   def _to_absolute_coordinates(normalized_boxes):
67 |     return box_list_ops.to_absolute_coordinates(
68 |         box_list.BoxList(normalized_boxes),
69 |         image_shape[1], image_shape[2], check_range=False).get()
70 | 
71 |   absolute_boxes = shape_utils.static_or_dynamic_map_fn(
72 |       _to_absolute_coordinates,
73 |       elems=(normalized_boxes),
74 |       dtype=tf.float32,
75 |       parallel_iterations=parallel_iterations,
76 |       back_prop=True)
77 |   return absolute_boxes
78 | 
79 | 
80 | def meshgrid(x, y):
81 |   """Tiles the contents of x and y into a pair of grids.
82 | 
83 |   Multidimensional analog of numpy.meshgrid, giving the same behavior if x and y
84 |   are vectors.
Generally, this will give: 85 | 86 | xgrid(i1, ..., i_m, j_1, ..., j_n) = x(j_1, ..., j_n) 87 | ygrid(i1, ..., i_m, j_1, ..., j_n) = y(i_1, ..., i_m) 88 | 89 | Keep in mind that the order of the arguments and outputs is reverse relative 90 | to the order of the indices they go into, done for compatibility with numpy. 91 | The output tensors have the same shapes. Specifically: 92 | 93 | xgrid.get_shape() = y.get_shape().concatenate(x.get_shape()) 94 | ygrid.get_shape() = y.get_shape().concatenate(x.get_shape()) 95 | 96 | Args: 97 | x: A tensor of arbitrary shape and rank. xgrid will contain these values 98 | varying in its last dimensions. 99 | y: A tensor of arbitrary shape and rank. ygrid will contain these values 100 | varying in its first dimensions. 101 | Returns: 102 | A tuple of tensors (xgrid, ygrid). 103 | """ 104 | with tf.name_scope('Meshgrid'): 105 | x = tf.convert_to_tensor(x) 106 | y = tf.convert_to_tensor(y) 107 | x_exp_shape = expanded_shape(tf.shape(x), 0, tf.rank(y)) 108 | y_exp_shape = expanded_shape(tf.shape(y), tf.rank(y), tf.rank(x)) 109 | 110 | xgrid = tf.tile(tf.reshape(x, x_exp_shape), y_exp_shape) 111 | ygrid = tf.tile(tf.reshape(y, y_exp_shape), x_exp_shape) 112 | new_shape = y.get_shape().concatenate(x.get_shape()) 113 | xgrid.set_shape(new_shape) 114 | ygrid.set_shape(new_shape) 115 | 116 | return xgrid, ygrid 117 | 118 | 119 | def fixed_padding(inputs, kernel_size, rate=1): 120 | """Pads the input along the spatial dimensions independently of input size. 121 | 122 | Args: 123 | inputs: A tensor of size [batch, height_in, width_in, channels]. 124 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation. 125 | Should be a positive integer. 126 | rate: An integer, rate for atrous convolution. 127 | 128 | Returns: 129 | output: A tensor of size [batch, height_out, width_out, channels] with the 130 | input, either intact (if kernel_size == 1) or padded (if kernel_size > 1). 131 | """ 132 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 133 | pad_total = kernel_size_effective - 1 134 | pad_beg = pad_total // 2 135 | pad_end = pad_total - pad_beg 136 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 137 | [pad_beg, pad_end], [0, 0]]) 138 | return padded_inputs 139 | 140 | 141 | def pad_to_multiple(tensor, multiple): 142 | """Returns the tensor zero padded to the specified multiple. 143 | 144 | Appends 0s to the end of the first and second dimension (height and width) of 145 | the tensor until both dimensions are a multiple of the input argument 146 | 'multiple'. E.g. given an input tensor of shape [1, 3, 5, 1] and an input 147 | multiple of 4, PadToMultiple will append 0s so that the resulting tensor will 148 | be of shape [1, 4, 8, 1]. 149 | 150 | Args: 151 | tensor: rank 4 float32 tensor, where 152 | tensor -> [batch_size, height, width, channels]. 153 | multiple: the multiple to pad to. 154 | 155 | Returns: 156 | padded_tensor: the tensor zero padded to the specified multiple. 
157 |   """
158 |   tensor_shape = tensor.get_shape()
159 |   batch_size = static_shape.get_batch_size(tensor_shape)
160 |   tensor_height = static_shape.get_height(tensor_shape)
161 |   tensor_width = static_shape.get_width(tensor_shape)
162 |   tensor_depth = static_shape.get_depth(tensor_shape)
163 | 
164 |   if batch_size is None:
165 |     batch_size = tf.shape(tensor)[0]
166 | 
167 |   if tensor_height is None:
168 |     tensor_height = tf.shape(tensor)[1]
169 |     padded_tensor_height = tf.to_int32(
170 |         tf.ceil(tf.to_float(tensor_height) / tf.to_float(multiple))) * multiple
171 |   else:
172 |     padded_tensor_height = int(
173 |         math.ceil(float(tensor_height) / multiple) * multiple)
174 | 
175 |   if tensor_width is None:
176 |     tensor_width = tf.shape(tensor)[2]
177 |     padded_tensor_width = tf.to_int32(
178 |         tf.ceil(tf.to_float(tensor_width) / tf.to_float(multiple))) * multiple
179 |   else:
180 |     padded_tensor_width = int(
181 |         math.ceil(float(tensor_width) / multiple) * multiple)
182 | 
183 |   if tensor_depth is None:
184 |     tensor_depth = tf.shape(tensor)[3]
185 | 
186 |   # Use tf.concat instead of tf.pad to preserve static shape
187 |   if padded_tensor_height != tensor_height:
188 |     height_pad = tf.zeros([
189 |         batch_size, padded_tensor_height - tensor_height, tensor_width,
190 |         tensor_depth
191 |     ])
192 |     tensor = tf.concat([tensor, height_pad], 1)
193 |   if padded_tensor_width != tensor_width:
194 |     width_pad = tf.zeros([
195 |         batch_size, padded_tensor_height, padded_tensor_width - tensor_width,
196 |         tensor_depth
197 |     ])
198 |     tensor = tf.concat([tensor, width_pad], 2)
199 | 
200 |   return tensor
201 | 
202 | 
203 | def padded_one_hot_encoding(indices, depth, left_pad):
204 |   """Returns a zero padded one-hot tensor.
205 | 
206 |   This function converts a sparse representation of indices (e.g., [3]) to a
207 |   zero padded one-hot representation (e.g., [0, 0, 0, 0, 1] with depth = 4 and
208 |   left_pad = 1). If `indices` is empty, the result will simply be a tensor of
209 |   shape (0, depth + left_pad). If depth = 0, then this function just returns
210 |   `None`.
211 | 
212 |   Args:
213 |     indices: an integer tensor of shape [num_indices].
214 |     depth: depth for the one-hot tensor (integer).
215 |     left_pad: number of zeros to left pad the one-hot tensor with (integer).
216 | 
217 |   Returns:
218 |     padded_onehot: a tensor with shape (num_indices, depth + left_pad). Returns
219 |       `None` if the depth is zero.
220 | 
221 |   Raises:
222 |     ValueError: if `indices` does not have rank 1 or if `left_pad` or `depth`
223 |       are either negative or non-integers.
224 | 
225 |   TODO(rathodv): add runtime checks for depth and indices.
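
  Example (editor's illustrative sketch, not part of the original file):
    padded_one_hot_encoding(tf.constant([0, 2]), depth=3, left_pad=1)
    # evaluates to [[0., 1., 0., 0.],
    #               [0., 0., 0., 1.]]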
226 | """ 227 | if depth < 0 or not isinstance(depth, six.integer_types): 228 | raise ValueError('`depth` must be a non-negative integer.') 229 | if left_pad < 0 or not isinstance(left_pad, six.integer_types): 230 | raise ValueError('`left_pad` must be a non-negative integer.') 231 | if depth == 0: 232 | return None 233 | 234 | rank = len(indices.get_shape().as_list()) 235 | if rank != 1: 236 | raise ValueError('`indices` must have rank 1, but has rank=%s' % rank) 237 | 238 | def one_hot_and_pad(): 239 | one_hot = tf.cast(tf.one_hot(tf.cast(indices, tf.int64), depth, 240 | on_value=1, off_value=0), tf.float32) 241 | return tf.pad(one_hot, [[0, 0], [left_pad, 0]], mode='CONSTANT') 242 | result = tf.cond(tf.greater(tf.size(indices), 0), one_hot_and_pad, 243 | lambda: tf.zeros((depth + left_pad, 0))) 244 | return tf.reshape(result, [-1, depth + left_pad]) 245 | 246 | 247 | def dense_to_sparse_boxes(dense_locations, dense_num_boxes, num_classes): 248 | """Converts bounding boxes from dense to sparse form. 249 | 250 | Args: 251 | dense_locations: a [max_num_boxes, 4] tensor in which only the first k rows 252 | are valid bounding box location coordinates, where k is the sum of 253 | elements in dense_num_boxes. 254 | dense_num_boxes: a [max_num_classes] tensor indicating the counts of 255 | various bounding box classes e.g. [1, 0, 0, 2] means that the first 256 | bounding box is of class 0 and the second and third bounding boxes are 257 | of class 3. The sum of elements in this tensor is the number of valid 258 | bounding boxes. 259 | num_classes: number of classes 260 | 261 | Returns: 262 | box_locations: a [num_boxes, 4] tensor containing only valid bounding 263 | boxes (i.e. the first num_boxes rows of dense_locations) 264 | box_classes: a [num_boxes] tensor containing the classes of each bounding 265 | box (e.g. dense_num_boxes = [1, 0, 0, 2] => box_classes = [0, 3, 3] 266 | """ 267 | 268 | num_valid_boxes = tf.reduce_sum(dense_num_boxes) 269 | box_locations = tf.slice(dense_locations, 270 | tf.constant([0, 0]), tf.stack([num_valid_boxes, 4])) 271 | tiled_classes = [tf.tile([i], tf.expand_dims(dense_num_boxes[i], 0)) 272 | for i in range(num_classes)] 273 | box_classes = tf.concat(tiled_classes, 0) 274 | box_locations.set_shape([None, 4]) 275 | return box_locations, box_classes 276 | 277 | 278 | def indices_to_dense_vector(indices, 279 | size, 280 | indices_value=1., 281 | default_value=0, 282 | dtype=tf.float32): 283 | """Creates dense vector with indices set to specific value and rest to zeros. 284 | 285 | This function exists because it is unclear if it is safe to use 286 | tf.sparse_to_dense(indices, [size], 1, validate_indices=False) 287 | with indices which are not ordered. 288 | This function accepts a dynamic size (e.g. tf.shape(tensor)[0]) 289 | 290 | Args: 291 | indices: 1d Tensor with integer indices which are to be set to 292 | indices_values. 293 | size: scalar with size (integer) of output Tensor. 294 | indices_value: values of elements specified by indices in the output vector 295 | default_value: values of other elements in the output vector. 296 | dtype: data type. 297 | 298 | Returns: 299 | dense 1D Tensor of shape [size] with indices set to indices_values and the 300 | rest set to default_value. 
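
  Example (editor's illustrative sketch, not part of the original file):
    indices_to_dense_vector(tf.constant([0, 2]), size=4)
    # evaluates to [1., 0., 1., 0.]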
301 | """ 302 | size = tf.to_int32(size) 303 | zeros = tf.ones([size], dtype=dtype) * default_value 304 | values = tf.ones_like(indices, dtype=dtype) * indices_value 305 | 306 | return tf.dynamic_stitch([tf.range(size), tf.to_int32(indices)], 307 | [zeros, values]) 308 | 309 | 310 | def reduce_sum_trailing_dimensions(tensor, ndims): 311 | """Computes sum across all dimensions following first `ndims` dimensions.""" 312 | return tf.reduce_sum(tensor, axis=tuple(range(ndims, tensor.shape.ndims))) 313 | 314 | 315 | def retain_groundtruth(tensor_dict, valid_indices): 316 | """Retains groundtruth by valid indices. 317 | 318 | Args: 319 | tensor_dict: a dictionary of following groundtruth tensors - 320 | fields.InputDataFields.groundtruth_boxes 321 | fields.InputDataFields.groundtruth_classes 322 | fields.InputDataFields.groundtruth_keypoints 323 | fields.InputDataFields.groundtruth_instance_masks 324 | fields.InputDataFields.groundtruth_is_crowd 325 | fields.InputDataFields.groundtruth_area 326 | fields.InputDataFields.groundtruth_label_types 327 | fields.InputDataFields.groundtruth_difficult 328 | valid_indices: a tensor with valid indices for the box-level groundtruth. 329 | 330 | Returns: 331 | a dictionary of tensors containing only the groundtruth for valid_indices. 332 | 333 | Raises: 334 | ValueError: If the shape of valid_indices is invalid. 335 | ValueError: field fields.InputDataFields.groundtruth_boxes is 336 | not present in tensor_dict. 337 | """ 338 | input_shape = valid_indices.get_shape().as_list() 339 | if not (len(input_shape) == 1 or 340 | (len(input_shape) == 2 and input_shape[1] == 1)): 341 | raise ValueError('The shape of valid_indices is invalid.') 342 | valid_indices = tf.reshape(valid_indices, [-1]) 343 | valid_dict = {} 344 | if fields.InputDataFields.groundtruth_boxes in tensor_dict: 345 | # Prevents reshape failure when num_boxes is 0. 346 | num_boxes = tf.maximum(tf.shape( 347 | tensor_dict[fields.InputDataFields.groundtruth_boxes])[0], 1) 348 | for key in tensor_dict: 349 | if key in [fields.InputDataFields.groundtruth_boxes, 350 | fields.InputDataFields.groundtruth_classes, 351 | fields.InputDataFields.groundtruth_keypoints, 352 | fields.InputDataFields.groundtruth_instance_masks]: 353 | valid_dict[key] = tf.gather(tensor_dict[key], valid_indices) 354 | # Input decoder returns empty tensor when these fields are not provided. 355 | # Needs to reshape into [num_boxes, -1] for tf.gather() to work. 356 | elif key in [fields.InputDataFields.groundtruth_is_crowd, 357 | fields.InputDataFields.groundtruth_area, 358 | fields.InputDataFields.groundtruth_difficult, 359 | fields.InputDataFields.groundtruth_label_types]: 360 | valid_dict[key] = tf.reshape( 361 | tf.gather(tf.reshape(tensor_dict[key], [num_boxes, -1]), 362 | valid_indices), [-1]) 363 | # Fields that are not associated with boxes. 364 | else: 365 | valid_dict[key] = tensor_dict[key] 366 | else: 367 | raise ValueError('%s not present in input tensor dict.' % ( 368 | fields.InputDataFields.groundtruth_boxes)) 369 | return valid_dict 370 | 371 | 372 | def retain_groundtruth_with_positive_classes(tensor_dict): 373 | """Retains only groundtruth with positive class ids. 
374 | 
375 |   Args:
376 |     tensor_dict: a dictionary of following groundtruth tensors -
377 |       fields.InputDataFields.groundtruth_boxes
378 |       fields.InputDataFields.groundtruth_classes
379 |       fields.InputDataFields.groundtruth_keypoints
380 |       fields.InputDataFields.groundtruth_instance_masks
381 |       fields.InputDataFields.groundtruth_is_crowd
382 |       fields.InputDataFields.groundtruth_area
383 |       fields.InputDataFields.groundtruth_label_types
384 |       fields.InputDataFields.groundtruth_difficult
385 | 
386 |   Returns:
387 |     a dictionary of tensors containing only the groundtruth with positive
388 |     classes.
389 | 
390 |   Raises:
391 |     ValueError: If groundtruth_classes tensor is not in tensor_dict.
392 |   """
393 |   if fields.InputDataFields.groundtruth_classes not in tensor_dict:
394 |     raise ValueError('`groundtruth classes` not in tensor_dict.')
395 |   keep_indices = tf.where(tf.greater(
396 |       tensor_dict[fields.InputDataFields.groundtruth_classes], 0))
397 |   return retain_groundtruth(tensor_dict, keep_indices)
398 | 
399 | 
400 | def replace_nan_groundtruth_label_scores_with_ones(label_scores):
401 |   """Replaces nan label scores with 1.0.
402 | 
403 |   Args:
404 |     label_scores: a tensor containing object annotation label scores.
405 | 
406 |   Returns:
407 |     a tensor where NaN label scores have been replaced by ones.
408 |   """
409 |   return tf.where(
410 |       tf.is_nan(label_scores), tf.ones(tf.shape(label_scores)), label_scores)
411 | 
412 | 
413 | def filter_groundtruth_with_crowd_boxes(tensor_dict):
414 |   """Filters out groundtruth with boxes corresponding to crowd.
415 | 
416 |   Args:
417 |     tensor_dict: a dictionary of following groundtruth tensors -
418 |       fields.InputDataFields.groundtruth_boxes
419 |       fields.InputDataFields.groundtruth_classes
420 |       fields.InputDataFields.groundtruth_keypoints
421 |       fields.InputDataFields.groundtruth_instance_masks
422 |       fields.InputDataFields.groundtruth_is_crowd
423 |       fields.InputDataFields.groundtruth_area
424 |       fields.InputDataFields.groundtruth_label_types
425 | 
426 |   Returns:
427 |     a dictionary of tensors containing only the groundtruth that is not
428 |     annotated as crowd.
429 |   """
430 |   if fields.InputDataFields.groundtruth_is_crowd in tensor_dict:
431 |     is_crowd = tensor_dict[fields.InputDataFields.groundtruth_is_crowd]
432 |     is_not_crowd = tf.logical_not(is_crowd)
433 |     is_not_crowd_indices = tf.where(is_not_crowd)
434 |     tensor_dict = retain_groundtruth(tensor_dict, is_not_crowd_indices)
435 |   return tensor_dict
436 | 
437 | 
438 | def filter_groundtruth_with_nan_box_coordinates(tensor_dict):
439 |   """Filters out groundtruth whose bounding box coordinates contain NaNs.
440 | 
441 |   Args:
442 |     tensor_dict: a dictionary of following groundtruth tensors -
443 |       fields.InputDataFields.groundtruth_boxes
444 |       fields.InputDataFields.groundtruth_classes
445 |       fields.InputDataFields.groundtruth_keypoints
446 |       fields.InputDataFields.groundtruth_instance_masks
447 |       fields.InputDataFields.groundtruth_is_crowd
448 |       fields.InputDataFields.groundtruth_area
449 |       fields.InputDataFields.groundtruth_label_types
450 | 
451 |   Returns:
452 |     a dictionary of tensors containing only the groundtruth with valid,
453 |     non-NaN bounding boxes.
454 |   """
455 |   groundtruth_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes]
456 |   nan_indicator_vector = tf.greater(tf.reduce_sum(tf.to_int32(
457 |       tf.is_nan(groundtruth_boxes)), reduction_indices=[1]), 0)
458 |   valid_indicator_vector = tf.logical_not(nan_indicator_vector)
459 |   valid_indices = tf.where(valid_indicator_vector)
460 | 
461 |   return retain_groundtruth(tensor_dict, valid_indices)
462 | 
463 | 
464 | def normalize_to_target(inputs,
465 |                         target_norm_value,
466 |                         dim,
467 |                         epsilon=1e-7,
468 |                         trainable=True,
469 |                         scope='NormalizeToTarget',
470 |                         summarize=True):
471 |   """L2 normalizes the inputs across the specified dimension to a target norm.
472 | 
473 |   This op implements the L2 Normalization layer introduced in
474 |   Liu, Wei, et al. "SSD: Single Shot MultiBox Detector."
475 |   and Liu, Wei, Andrew Rabinovich, and Alexander C. Berg.
476 |   "Parsenet: Looking wider to see better." and is useful for bringing
477 |   activations from multiple layers in a convnet to a standard scale.
478 | 
479 |   Note that the rank of `inputs` must be known and the dimension to which
480 |   normalization is to be applied should be statically defined.
481 | 
482 |   TODO(jonathanhuang): Add option to scale by L2 norm of the entire input.
483 | 
484 |   Args:
485 |     inputs: A `Tensor` of arbitrary size.
486 |     target_norm_value: A float value that specifies an initial target norm or
487 |       a list of floats (whose length must be equal to the depth along the
488 |       dimension to be normalized) specifying a per-dimension multiplier
489 |       after normalization.
490 |     dim: The dimension along which the input is normalized.
491 |     epsilon: A small value to add to the inputs to avoid dividing by zero.
492 |     trainable: Whether the norm is trainable or not
493 |     scope: Optional scope for variable_scope.
494 |     summarize: Whether or not to add a tensorflow summary for the op.
495 | 
496 |   Returns:
497 |     The input tensor normalized to the specified target norm.
498 | 
499 |   Raises:
500 |     ValueError: If dim is negative or not smaller than the rank of 'inputs'.
501 |     ValueError: If target_norm_value is not a float or a list of floats with
502 |       length equal to the depth along the dimension to be normalized.
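
  Example (editor's illustrative sketch, not part of the original file):
    # L2-normalize 128-channel activations along the channel axis to an
    # initial norm of 20.0, as done for conv4_3 in the SSD paper.
    features = tf.placeholder(tf.float32, [None, 38, 38, 128])
    normalized = normalize_to_target(features, target_norm_value=20.0, dim=3)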


def position_sensitive_crop_regions(image,
                                    boxes,
                                    box_ind,
                                    crop_size,
                                    num_spatial_bins,
                                    global_pool,
                                    extrapolation_value=None):
  """Position-sensitive crop and pool rectangular regions from a feature grid.

  The output crops are split into `spatial_bins_y` vertical bins
  and `spatial_bins_x` horizontal bins. For each intersection of a vertical
  and a horizontal bin the output values are gathered by performing
  `tf.image.crop_and_resize` (bilinear resampling) on a separate subset of
  channels of the image. This reduces `depth` by a factor of
  `(spatial_bins_y * spatial_bins_x)`.

  When global_pool is True, this function implements a differentiable version
  of position-sensitive RoI pooling used in
  [R-FCN detection system](https://arxiv.org/abs/1605.06409).

  When global_pool is False, this function implements a differentiable version
  of position-sensitive assembling operation used in
  [instance FCN](https://arxiv.org/abs/1603.08678).

  Args:
    image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
      `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
      A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
      Both `image_height` and `image_width` need to be positive.
    boxes: A `Tensor` of type `float32`.
      A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
      specifies the coordinates of a box in the `box_ind[i]` image and is
      specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized
      coordinate value of `y` is mapped to the image coordinate at
      `y * (image_height - 1)`, so the `[0, 1]` interval of normalized image
      height is mapped to `[0, image_height - 1]` in image height coordinates.
      We do allow y1 > y2, in which case the sampled crop is an up-down flipped
      version of the original image. The width dimension is treated similarly.
      Normalized coordinates outside the `[0, 1]` range are allowed, in which
      case we use `extrapolation_value` to extrapolate the input image values.
    box_ind: A `Tensor` of type `int32`.
      A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`.
      The value of `box_ind[i]` specifies the image that the `i`-th box refers
      to.
    crop_size: A list of two integers `[crop_height, crop_width]`. All
      cropped image patches are resized to this size. The aspect ratio of the
      image content is not preserved. Both `crop_height` and `crop_width` need
      to be positive.
    num_spatial_bins: A list of two integers `[spatial_bins_y, spatial_bins_x]`.
      Represents the number of position-sensitive bins in y and x directions.
      Both values should be >= 1. `crop_height` should be divisible by
      `spatial_bins_y`, and similarly for width.
      The number of image channels should be divisible by
      (spatial_bins_y * spatial_bins_x).
      Suggested value from R-FCN paper: [3, 3].
    global_pool: A boolean variable.
      If True, we perform average global pooling on the features assembled from
      the position-sensitive score maps.
      If False, we keep the position-pooled features without global pooling
      over the spatial coordinates.
      Note that using global_pool=True is equivalent to but more efficient than
      running the function with global_pool=False and then performing global
      average pooling.
    extrapolation_value: An optional `float`. Defaults to `0`.
      Value used for extrapolation, when applicable.

  Returns:
    position_sensitive_features: A 4-D tensor of shape
      `[num_boxes, K, K, crop_channels]`,
      where `crop_channels = depth / (spatial_bins_y * spatial_bins_x)`,
      where K = 1 when global_pool is True (Average-pooled cropped regions),
      and K = crop_size when global_pool is False.

  Raises:
    ValueError: Raised in four situations:
      `num_spatial_bins` is not >= 1;
      `num_spatial_bins` does not divide `crop_size`;
      `(spatial_bins_y * spatial_bins_x)` does not divide `depth`;
      `bin_crop_size` is not square when global_pool=False due to the
        constraint in function space_to_depth.
  """
  total_bins = 1
  bin_crop_size = []

  for (num_bins, crop_dim) in zip(num_spatial_bins, crop_size):
    if num_bins < 1:
      raise ValueError('num_spatial_bins should be >= 1')

    if crop_dim % num_bins != 0:
      raise ValueError('crop_size should be divisible by num_spatial_bins')

    total_bins *= num_bins
    bin_crop_size.append(crop_dim // num_bins)

  if not global_pool and bin_crop_size[0] != bin_crop_size[1]:
    raise ValueError('Only support square bin crop size for now.')

  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1)
  spatial_bins_y, spatial_bins_x = num_spatial_bins

  # Split each box into spatial_bins_y * spatial_bins_x bins.
  position_sensitive_boxes = []
  for bin_y in range(spatial_bins_y):
    step_y = (ymax - ymin) / spatial_bins_y
    for bin_x in range(spatial_bins_x):
      step_x = (xmax - xmin) / spatial_bins_x
      box_coordinates = [ymin + bin_y * step_y,
                         xmin + bin_x * step_x,
                         ymin + (bin_y + 1) * step_y,
                         xmin + (bin_x + 1) * step_x,
                        ]
      position_sensitive_boxes.append(tf.stack(box_coordinates, axis=1))

  image_splits = tf.split(value=image, num_or_size_splits=total_bins, axis=3)

  image_crops = []
  for (split, box) in zip(image_splits, position_sensitive_boxes):
    crop = tf.image.crop_and_resize(split, box, box_ind, bin_crop_size,
                                    extrapolation_value=extrapolation_value)
    image_crops.append(crop)

  if global_pool:
    # Average over all bins.
    position_sensitive_features = tf.add_n(image_crops) / len(image_crops)
    # Then average over spatial positions within the bins.
    position_sensitive_features = tf.reduce_mean(
        position_sensitive_features, [1, 2], keep_dims=True)
  else:
    # Reorder height/width to depth channel.
    block_size = bin_crop_size[0]
    if block_size >= 2:
      image_crops = [tf.space_to_depth(
          crop, block_size=block_size) for crop in image_crops]

    # Pack image_crops so that first dimension is for position-sensitive boxes.
    position_sensitive_features = tf.stack(image_crops, axis=0)

    # Unroll the position-sensitive boxes to spatial positions.
    position_sensitive_features = tf.squeeze(
        tf.batch_to_space_nd(position_sensitive_features,
                             block_shape=[1] + num_spatial_bins,
                             crops=tf.zeros((3, 2), dtype=tf.int32)),
        squeeze_dims=[0])

    # Reorder back the depth channel.
    if block_size >= 2:
      position_sensitive_features = tf.depth_to_space(
          position_sensitive_features, block_size=block_size)

  return position_sensitive_features
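
# Usage sketch (illustrative shapes; assumes TF 1.x graph mode). R-FCN style
# position-sensitive RoI pooling over a score map with 3x3 spatial bins and
# 10 channels per bin (so depth = 9 * 10):
#
#   score_maps = tf.placeholder(tf.float32, [1, 32, 32, 9 * 10])
#   boxes = tf.constant([[0.0, 0.0, 0.5, 0.5],
#                        [0.25, 0.25, 1.0, 1.0]], dtype=tf.float32)
#   box_ind = tf.zeros([2], dtype=tf.int32)
#   rfcn_features = position_sensitive_crop_regions(
#       score_maps, boxes, box_ind, crop_size=[6, 6],
#       num_spatial_bins=[3, 3], global_pool=True)
#   # rfcn_features has shape [2, 1, 1, 10].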


def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
                                     image_width):
  """Transforms the box masks back to full image masks.

  Embeds masks in bounding boxes of larger masks whose shapes correspond to
  image shape.

  Args:
    box_masks: A tf.float32 tensor of size [num_masks, mask_height, mask_width].
    boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
      corners. Row i contains [ymin, xmin, ymax, xmax] of the box
      corresponding to mask i. Note that the box corners are in
      normalized coordinates.
    image_height: Image height. The output mask will have the same height as
      the image height.
    image_width: Image width. The output mask will have the same width as the
      image width.

  Returns:
    A tf.float32 tensor of size [num_masks, image_height, image_width].
  """
  # TODO(rathodv): Make this a public function.
  def reframe_box_masks_to_image_masks_default():
    """The default function when there are more than 0 box masks."""
    def transform_boxes_relative_to_boxes(boxes, reference_boxes):
      boxes = tf.reshape(boxes, [-1, 2, 2])
      min_corner = tf.expand_dims(reference_boxes[:, 0:2], 1)
      max_corner = tf.expand_dims(reference_boxes[:, 2:4], 1)
      transformed_boxes = (boxes - min_corner) / (max_corner - min_corner)
      return tf.reshape(transformed_boxes, [-1, 4])

    box_masks_expanded = tf.expand_dims(box_masks, axis=3)
    num_boxes = tf.shape(box_masks_expanded)[0]
    unit_boxes = tf.concat(
        [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1)
    reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes)
    return tf.image.crop_and_resize(
        image=box_masks_expanded,
        boxes=reverse_boxes,
        box_ind=tf.range(num_boxes),
        crop_size=[image_height, image_width],
        extrapolation_value=0.0)

  image_masks = tf.cond(
      tf.shape(box_masks)[0] > 0,
      reframe_box_masks_to_image_masks_default,
      lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32))
  return tf.squeeze(image_masks, axis=3)
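
# Usage sketch (illustrative shapes; assumes TF 1.x graph mode). Pastes
# per-box instance masks (e.g. 33x33 Mask R-CNN outputs) back into full
# image coordinates:
#
#   box_masks = tf.placeholder(tf.float32, [None, 33, 33])
#   boxes = tf.placeholder(tf.float32, [None, 4])  # normalized corners
#   image_masks = reframe_box_masks_to_image_masks(
#       box_masks, boxes, image_height=480, image_width=640)
#   # image_masks has shape [num_masks, 480, 640].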


def merge_boxes_with_multiple_labels(boxes, classes, num_classes):
  """Merges boxes with same coordinates and returns K-hot encoded classes.

  Args:
    boxes: A tf.float32 tensor with shape [N, 4] holding N boxes.
    classes: A tf.int32 tensor with shape [N] holding class indices.
      The class index starts at 0.
    num_classes: total number of classes to use for K-hot encoding.

  Returns:
    merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
      where N' <= N.
    class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
      K-hot encodings for the merged boxes.
    merged_box_indices: A tf.int32 tensor with shape [N'] holding original
      indices of the boxes.
  """
  def merge_numpy_boxes(boxes, classes, num_classes):
    """Python function to merge numpy boxes."""
    if boxes.size < 1:
      return (np.zeros([0, 4], dtype=np.float32),
              np.zeros([0, num_classes], dtype=np.int32),
              np.zeros([0], dtype=np.int32))
    box_to_class_indices = {}
    for box_index in range(boxes.shape[0]):
      box = tuple(boxes[box_index, :].tolist())
      class_index = classes[box_index]
      if box not in box_to_class_indices:
        box_to_class_indices[box] = [box_index, np.zeros([num_classes])]
      box_to_class_indices[box][1][class_index] = 1
    merged_boxes = np.vstack(list(box_to_class_indices.keys())).astype(
        np.float32)
    class_encodings = [item[1] for item in box_to_class_indices.values()]
    class_encodings = np.vstack(class_encodings).astype(np.int32)
    merged_box_indices = [item[0] for item in box_to_class_indices.values()]
    merged_box_indices = np.array(merged_box_indices).astype(np.int32)
    return merged_boxes, class_encodings, merged_box_indices

  merged_boxes, class_encodings, merged_box_indices = tf.py_func(
      merge_numpy_boxes, [boxes, classes, num_classes],
      [tf.float32, tf.int32, tf.int32])
  merged_boxes = tf.reshape(merged_boxes, [-1, 4])
  class_encodings = tf.reshape(class_encodings, [-1, num_classes])
  merged_box_indices = tf.reshape(merged_box_indices, [-1])
  return merged_boxes, class_encodings, merged_box_indices
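
# Usage sketch (illustrative values; assumes TF 1.x graph mode). The first
# two rows share coordinates, so they merge into one box whose K-hot encoding
# has ones at class indices 0 and 2:
#
#   boxes = tf.constant([[0.1, 0.1, 0.4, 0.4],
#                        [0.1, 0.1, 0.4, 0.4],
#                        [0.5, 0.5, 0.9, 0.9]], dtype=tf.float32)
#   classes = tf.constant([0, 2, 1], dtype=tf.int32)
#   merged_boxes, encodings, indices = merge_boxes_with_multiple_labels(
#       boxes, classes, num_classes=3)
#   # merged_boxes -> [[0.1, 0.1, 0.4, 0.4], [0.5, 0.5, 0.9, 0.9]]
#   # encodings    -> [[1, 0, 1], [0, 1, 0]]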


def nearest_neighbor_upsampling(input_tensor, scale):
  """Nearest neighbor upsampling implementation.

  Nearest neighbor upsampling function that maps input tensor with shape
  [batch_size, height, width, channels] to [batch_size, height * scale,
  width * scale, channels]. This implementation only uses reshape and
  broadcasting to make it TPU compatible.

  Args:
    input_tensor: A float32 tensor of size [batch, height_in, width_in,
      channels].
    scale: An integer multiple to scale resolution of input data.

  Returns:
    data_up: A float32 tensor of size
      [batch, height_in*scale, width_in*scale, channels].
  """
  with tf.name_scope('nearest_neighbor_upsampling'):
    (batch_size, height, width,
     channels) = shape_utils.combined_static_and_dynamic_shape(input_tensor)
    output_tensor = tf.reshape(
        input_tensor, [batch_size, height, 1, width, 1, channels]) * tf.ones(
            [1, 1, scale, 1, scale, 1], dtype=input_tensor.dtype)
    return tf.reshape(output_tensor,
                      [batch_size, height * scale, width * scale, channels])


def matmul_gather_on_zeroth_axis(params, indices, scope=None):
  """Matrix multiplication based implementation of tf.gather on zeroth axis.

  TODO(rathodv, jonathanhuang): enable sparse matmul option.

  Args:
    params: A float32 Tensor. The tensor from which to gather values.
      Must be at least rank 1.
    indices: A Tensor. Must be one of the following types: int32, int64.
      Must be in range [0, params.shape[0]).
    scope: A name for the operation (optional).

  Returns:
    A Tensor. Has the same type as params. Values from params gathered
    from indices given by indices, with shape indices.shape + params.shape[1:].
  """
  with tf.name_scope(scope, 'MatMulGather'):
    params_shape = shape_utils.combined_static_and_dynamic_shape(params)
    indices_shape = shape_utils.combined_static_and_dynamic_shape(indices)
    params2d = tf.reshape(params, [params_shape[0], -1])
    indicator_matrix = tf.one_hot(indices, params_shape[0])
    gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
    return tf.reshape(gathered_result_flattened,
                      tf.stack(indices_shape + params_shape[1:]))
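
# Usage sketches (illustrative shapes; assumes TF 1.x graph mode):
#
#   features = tf.placeholder(tf.float32, [8, 16, 16, 256])
#   upsampled = nearest_neighbor_upsampling(features, scale=2)
#   # upsampled has shape [8, 32, 32, 256].
#
#   params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
#   gathered = matmul_gather_on_zeroth_axis(params, tf.constant([2, 0]))
#   # gathered -> [[5., 6.], [1., 2.]], equivalent to tf.gather(params, [2, 0]).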


def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
  """Matrix multiplication based implementation of the crop and resize op.

  Extracts crops from the input image tensor and bilinearly resizes them
  (possibly with aspect ratio change) to a common output size specified by
  crop_size. This is more general than the crop_to_bounding_box op which
  extracts a fixed size slice from the input image and does not allow
  resizing or aspect ratio change.

  Returns a tensor with crops from the input image at positions defined at
  the bounding box locations in boxes. The cropped boxes are all resized
  (with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
  The result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`.

  Running time complexity:
    O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number
    of pixels of the longer edge of the image.

  Note that this operation is meant to replicate the behavior of the standard
  tf.image.crop_and_resize operation but there are a few differences.
  Specifically:
    1) The extrapolation value (the values that are interpolated from outside
       the bounds of the image window) is always zero.
    2) Only XLA supported operations are used (e.g., matrix multiplication).
    3) There is no `box_indices` argument --- to run this op on multiple
       images, one must currently call this op independently on each image.
    4) All shapes and the `crop_size` parameter are assumed to be statically
       defined. Moreover, the number of boxes must be strictly nonzero.

  Args:
    image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
      `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
      A 4-D tensor of shape `[1, image_height, image_width, depth]`.
      Both `image_height` and `image_width` need to be positive.
    boxes: A `Tensor` of type `float32`.
      A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
      specifies the coordinates of a box and is specified in normalized
      coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of `y` is
      mapped to the image coordinate at `y * (image_height - 1)`, so the
      `[0, 1]` interval of normalized image height is mapped to
      `[0, image_height - 1]` in image height coordinates. We do allow
      y1 > y2, in which case the sampled crop is an up-down flipped version
      of the original image. The width dimension is treated similarly.
      Normalized coordinates outside the `[0, 1]` range are allowed, in which
      case zeros are used as the extrapolated input image values (see
      difference 1 above).
    crop_size: A list of two integers `[crop_height, crop_width]`. All
      cropped image patches are resized to this size. The aspect ratio of the
      image content is not preserved. Both `crop_height` and `crop_width` need
      to be positive.
    scope: A name for the operation (optional).

  Returns:
    A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`.

  Raises:
    ValueError: if image tensor does not have shape
      `[1, image_height, image_width, depth]` with all dimensions statically
      defined.
    ValueError: if boxes tensor does not have shape `[num_boxes, 4]` where
      num_boxes > 0.
    ValueError: if crop_size is not a list of two positive integers.
  """
  img_shape = image.shape.as_list()
  boxes_shape = boxes.shape.as_list()
  _, img_height, img_width, _ = img_shape
  if not isinstance(crop_size, list) or len(crop_size) != 2:
    raise ValueError('`crop_size` must be a list of length 2')
  dimensions = img_shape + crop_size + boxes_shape
  if not all([isinstance(dim, int) for dim in dimensions]):
    raise ValueError('all input shapes must be statically defined')
  if len(boxes_shape) != 2 or boxes_shape[1] != 4:
    raise ValueError('`boxes` should have shape `[num_boxes, 4]`')
  if len(img_shape) != 4 or img_shape[0] != 1:
    raise ValueError('image should have shape '
                     '`[1, image_height, image_width, depth]`')
  num_crops = boxes_shape[0]
  if not num_crops > 0:
    raise ValueError('number of boxes must be > 0')
  if not (crop_size[0] > 0 and crop_size[1] > 0):
    raise ValueError('`crop_size` must be a list of two positive integers.')

  def _lin_space_weights(num, img_size):
    if num > 1:
      alpha = (img_size - 1) / float(num - 1)
      indices = np.reshape(np.arange(num), (1, num))
      start_weights = alpha * (num - 1 - indices)
      stop_weights = alpha * indices
    else:
      start_weights = num * [.5 * (img_size - 1)]
      stop_weights = num * [.5 * (img_size - 1)]
    return (tf.constant(start_weights, dtype=tf.float32),
            tf.constant(stop_weights, dtype=tf.float32))

  with tf.name_scope(scope, 'MatMulCropAndResize'):
    y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
    x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
    [y1, x1, y2, x2] = tf.split(value=boxes, num_or_size_splits=4, axis=1)

    # Pixel centers of input image and grid points along height and width.
    image_idx_h = tf.constant(
        np.reshape(np.arange(img_height), (1, 1, img_height)),
        dtype=tf.float32)
    image_idx_w = tf.constant(
        np.reshape(np.arange(img_width), (1, 1, img_width)), dtype=tf.float32)
    grid_pos_h = tf.expand_dims(y1 * y1_weights + y2 * y2_weights, 2)
    grid_pos_w = tf.expand_dims(x1 * x1_weights + x2 * x2_weights, 2)

    # Create kernel matrices of pairwise kernel evaluations between pixel
    # centers of image and grid points.
    kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
    kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))

    # TODO(jonathanhuang): investigate whether all channels can be processed
    # without the explicit unstack --- possibly with a permute and map_fn call.
    result_channels = []
    for channel in tf.unstack(image, axis=3):
      result_channels.append(
          tf.matmul(
              tf.matmul(kernel_h, tf.tile(channel, [num_crops, 1, 1])),
              kernel_w, transpose_b=True))
    return tf.stack(result_channels, axis=3)
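
# Usage sketch (illustrative shapes; assumes TF 1.x graph mode). All shapes,
# including the batch of exactly one image, must be statically defined:
#
#   image = tf.placeholder(tf.float32, [1, 300, 300, 3])
#   boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
#                        [0.1, 0.2, 0.6, 0.8]], dtype=tf.float32)
#   crops = matmul_crop_and_resize(image, boxes, crop_size=[17, 17])
#   # crops has shape [2, 17, 17, 3].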
--------------------------------------------------------------------------------