├── .gitignore
├── models_hash
│   └── model_hash.json
├── src
│   └── main
│       ├── inference
│       │   ├── __init__.py
│       │   ├── ConfigurationSchema.json
│       │   ├── inference_engines_factory.py
│       │   ├── base_error.py
│       │   ├── errors.py
│       │   ├── exceptions.py
│       │   ├── base_inference_engine.py
│       │   └── tensorflow_detection.py
│       ├── .gitignore
│       ├── object_detection
│       │   ├── core
│       │   │   ├── __init__.py
│       │   │   ├── box_list.py
│       │   │   └── standard_fields.py
│       │   ├── image1.jpg
│       │   ├── utils
│       │   │   ├── static_shape.py
│       │   │   ├── label_map_util.py
│       │   │   ├── shape_utils.py
│       │   │   ├── visualization_utils.py
│       │   │   └── ops.py
│       │   └── protos
│       │       └── string_int_label_map_pb2.py
│       ├── fonts
│       │   └── DejaVuSans.ttf
│       ├── models.py
│       ├── ocr.py
│       ├── deep_learning_service.py
│       └── start.py
├── docs
│   ├── 1.gif
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   ├── 5.gif
│   ├── tcpu.png
│   ├── tcpu2.png
│   ├── TCPU20req.png
│   ├── TCPU40req.png
│   ├── nvidia-smi.gif
│   ├── swagger_endpoints.png
│   └── uml
│       ├── InferenceClassDiagram.png
│       ├── InferenceSequenceDiagram.png
│       ├── InferenceSequenceDiagram.xml
│       └── InferenceClassDiagram.drawio
├── models
│   └── .gitignore
├── docker
│   ├── requirements.txt
│   └── dockerfile
├── cpu-inference.yaml
├── install_prerequisites.sh
├── README-docker_swarm.md
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 |
--------------------------------------------------------------------------------
/models_hash/model_hash.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/src/main/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/main/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.log
3 |
--------------------------------------------------------------------------------
/src/main/object_detection/core/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/1.gif
--------------------------------------------------------------------------------
/docs/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/2.gif
--------------------------------------------------------------------------------
/docs/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/3.gif
--------------------------------------------------------------------------------
/docs/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/4.gif
--------------------------------------------------------------------------------
/docs/5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/5.gif
--------------------------------------------------------------------------------
/docs/tcpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/tcpu.png
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
6 |
--------------------------------------------------------------------------------
/docs/tcpu2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/tcpu2.png
--------------------------------------------------------------------------------
/docs/TCPU20req.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/TCPU20req.png
--------------------------------------------------------------------------------
/docs/TCPU40req.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/TCPU40req.png
--------------------------------------------------------------------------------
/docs/nvidia-smi.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/nvidia-smi.gif
--------------------------------------------------------------------------------
/docs/swagger_endpoints.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/swagger_endpoints.png
--------------------------------------------------------------------------------
/src/main/fonts/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/src/main/fonts/DejaVuSans.ttf
--------------------------------------------------------------------------------
/docs/uml/InferenceClassDiagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/uml/InferenceClassDiagram.png
--------------------------------------------------------------------------------
/src/main/object_detection/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/src/main/object_detection/image1.jpg
--------------------------------------------------------------------------------
/docs/uml/InferenceSequenceDiagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BMW-InnovationLab/BMW-TensorFlow-Inference-API-CPU/HEAD/docs/uml/InferenceSequenceDiagram.png
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | aiofiles
2 | celery
3 | fastapi
4 | h5py
5 | matplotlib
6 | numpy
7 | opencv-python
8 | python-multipart
9 | pandas
10 | Pillow
11 | python-socketio
12 | requests
13 | scipy
14 | sklearn
15 | socketIO-client-nexus
16 | tensorflow==1.13.1
17 | uvicorn
18 | jsonschema
19 | pytz
20 | pytesseract
21 |
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/docker/dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.6
2 |
3 | LABEL maintainer="antoine.charbel@inmind.ai"
4 |
5 | COPY docker/requirements.txt .
6 | COPY src/main /main
7 |
8 | RUN apt-get update && apt-get install -y tesseract-ocr
9 |
10 | RUN pip install -r requirements.txt
11 |
12 | WORKDIR /main
13 |
14 | CMD ["uvicorn", "start:app", "--host", "0.0.0.0", "--port", "4343"]
15 |
--------------------------------------------------------------------------------
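
A typical build-and-run sequence for this image, as a sketch (the `/mnt/models` host path mirrors cpu-inference.yaml and is an assumption; adjust to your setup). The build context must be the repository root, since the dockerfile copies `docker/requirements.txt` and `src/main`:

```sh
# Build the image from the repository root
docker build -f docker/dockerfile -t tensorflow_inference_api_cpu .

# Run it, mounting the models directory onto /models (the service also
# persists model hashes under /models_hash) and exposing port 4343
docker run -it \
    -v /mnt/models:/models \
    -v "$(pwd)/models_hash:/models_hash" \
    -p 4343:4343 tensorflow_inference_api_cpu
```
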
/cpu-inference.yaml:
--------------------------------------------------------------------------------
1 | version: "3"
2 |
3 | services:
4 | api:
5 | ports:
6 | - "4343:4343"
7 | image: tensorflow_inference_api_cpu
8 | volumes:
9 | - "/mnt/models:/models"
10 | deploy:
11 | replicas: 1
12 | update_config:
13 | parallelism: 2
14 | delay: 10s
15 | restart_policy:
16 | condition: on-failure
17 |
--------------------------------------------------------------------------------
/src/main/models.py:
--------------------------------------------------------------------------------
1 | class ApiResponse:
2 |
3 | def __init__(self, success=True, data=None, error=None):
4 | """
5 | Defines the response shape
6 |         :param success: A boolean indicating whether the request succeeded
7 | :param data: The model's response
8 | :param error: The error in case an exception was raised
9 | """
10 | self.data = data
11 |         self.error = str(error) if error is not None else ''
12 | self.success = success
13 |
--------------------------------------------------------------------------------
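
A short usage sketch (hypothetical, not part of the repository): because the class only sets plain attributes, an instance can be serialized for a JSON API with `vars()`:

```python
from models import ApiResponse

# Successful response carrying a model prediction
ok = ApiResponse(success=True, data={'bounding-boxes': []})
print(vars(ok))   # {'data': {'bounding-boxes': []}, 'error': '', 'success': True}

# Failure response wrapping an exception message
err = ApiResponse(success=False, error=ValueError('bad input'))
print(vars(err))  # {'data': None, 'error': 'bad input', 'success': False}
```
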
/install_prerequisites.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This will install docker following [https://docs.docker.com/install/linux/docker-ce/ubuntu/]
4 | sudo apt-get remove docker docker-engine docker.io
5 | sudo apt-get update
6 |
7 | sudo apt-get install \
8 | apt-transport-https \
9 | ca-certificates \
10 | curl \
11 | software-properties-common
12 |
13 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
14 | sudo apt-key fingerprint 0EBFCD88
15 |
16 | sudo add-apt-repository \
17 | "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
18 | $(lsb_release -cs) \
19 | stable"
20 |
21 | sudo apt-get update
22 | sudo apt-get install -y docker-ce
23 | sudo groupadd docker
24 | sudo usermod -aG docker ${USER}
25 | docker run hello-world
26 |
27 |
--------------------------------------------------------------------------------
/src/main/inference/ConfigurationSchema.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "object",
3 | "properties": {
4 | "inference_engine_name": {
5 | "type": "string"
6 | },
7 | "confidence": {
8 | "type": "number",
9 | "minimum": 0,
10 | "maximum": 100
11 | },
12 | "predictions": {
13 | "type": "number",
14 | "minimum": 0
15 | },
16 | "number_of_classes": {
17 | "type": "number"
18 | },
19 | "framework": {
20 | "type": "string"
21 | },
22 | "type": {
23 | "type": "string"
24 | },
25 | "network": {
26 | "type": "string"
27 | }
28 | },
29 | "required": [
30 | "inference_engine_name",
31 | "confidence",
32 | "predictions",
33 | "number_of_classes",
34 | "framework",
35 | "type",
36 | "network"
37 | ]
38 | }
--------------------------------------------------------------------------------
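
To illustrate the schema, here is a hypothetical `config.json` (all field values are assumptions) validated with the `jsonschema` package listed in docker/requirements.txt; `tensorflow_detection` matches the engine module shipped in src/main/inference:

```python
import json
from jsonschema import validate

schema = json.load(open('ConfigurationSchema.json'))

# Hypothetical model configuration; values are illustrative only
config = {
    "inference_engine_name": "tensorflow_detection",
    "confidence": 50,
    "predictions": 100,
    "number_of_classes": 90,
    "framework": "tensorflow",
    "type": "detection",
    "network": "ssd_mobilenet",
}

validate(instance=config, schema=schema)  # raises ValidationError on mismatch
```
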
/src/main/inference/inference_engines_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from inference.exceptions import ModelNotFound, ApplicationError, InvalidModelConfiguration, InferenceEngineNotFound, ModelNotLoaded
4 |
5 |
6 | class InferenceEngineFactory:
7 |
8 | @staticmethod
9 | def get_engine(path_to_model):
10 | """
11 | Reads the model's inference engine from the model's configuration and calls the right inference engine class.
12 | :param path_to_model: Model's path
13 | :return: The model's instance
14 | """
15 | if not os.path.exists(path_to_model):
16 | raise ModelNotFound()
17 | try:
18 | configuration = json.loads(open(os.path.join(path_to_model, 'config.json')).read())
19 | except Exception:
20 | raise InvalidModelConfiguration('config.json not found or corrupted')
21 | try:
22 | inference_engine_name = configuration['inference_engine_name']
23 | except Exception:
24 | raise InvalidModelConfiguration('missing inference engine name in config.json')
25 | try:
26 |             # import one of the available inference engine classes (this project ships only one) and return a
27 | # model instance
28 | return getattr(__import__(inference_engine_name), 'InferenceEngine')(path_to_model)
29 | except ApplicationError as e:
30 | print(e)
31 | raise e
32 | except Exception as e:
33 | print(e)
34 | raise InferenceEngineNotFound(inference_engine_name)
35 |
--------------------------------------------------------------------------------
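
A usage sketch (the model folder name is hypothetical): the factory expects a directory containing a `config.json` whose `inference_engine_name` names an importable engine module, such as `tensorflow_detection`:

```python
from inference.inference_engines_factory import InferenceEngineFactory
from inference.exceptions import ApplicationError

try:
    # /models/my_model must hold config.json plus the engine's model files
    engine = InferenceEngineFactory.get_engine('/models/my_model')
except ApplicationError as e:
    # e.g. ModelNotFound, InvalidModelConfiguration, InferenceEngineNotFound
    print(e)
```
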
/src/main/inference/base_error.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from datetime import datetime
3 | from abc import ABC, abstractmethod
4 |
5 |
6 | class AbstractError(ABC):
7 |
8 | def __init__(self):
9 | """
10 | Sets the logger file, level, and format.
11 | The logging file will contain the logging level, request date, request status, and model response.
12 | """
13 | self.logger = logging.getLogger('logger')
14 | date = datetime.now().strftime('%Y-%m-%d')
15 | file_path = 'logs/' + date + '.log'
16 | self.handler = logging.FileHandler(file_path)
17 | self.handler.setLevel(logging.INFO)
18 | self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s"))
19 | self.logger.addHandler(self.handler)
20 |
21 | @abstractmethod
22 | def info(self, message):
23 | """
24 | Logs an info message to the logging file.
25 | :param message: Containing the request status and the model response
26 | :return:
27 | """
28 | pass
29 |
30 | @abstractmethod
31 | def warning(self, message):
32 | """
33 | Logs a warning message to the logging file.
34 | :param message: Containing the request status and the model response
35 | :return:
36 | """
37 | pass
38 |
39 | @abstractmethod
40 | def error(self, message):
41 | """
42 | Logs an Error message to the logging file.
43 | :param message: Containing the request status and the model response
44 | :return:
45 | """
46 | pass
47 |
--------------------------------------------------------------------------------
/src/main/inference/errors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from datetime import datetime
4 | from inference.base_error import AbstractError
5 |
6 |
7 | class Error(AbstractError):
8 |
9 | def __init__(self):
10 | if 'logs' not in os.listdir():
11 | os.mkdir('logs')
12 | self.date = None
13 | super().__init__()
14 |
15 | def info(self, message):
16 | self.check_date()
17 | self.logger.info(message)
18 |
19 | def warning(self, message):
20 | self.check_date()
21 | self.logger.warning(message)
22 |
23 | def error(self, message):
24 | self.check_date()
25 | self.logger.error(message)
26 |
27 | def check_date(self):
28 | """
29 | Divides logging per day. Each logging file corresponds to a specific day.
30 |         It also removes logging files older than one year.
31 | :return:
32 | """
33 | self.date = datetime.now().strftime('%Y-%m-%d')
34 | file_path = self.date + '.log'
35 | if file_path not in os.listdir('logs'):
36 | self.logger.removeHandler(self.handler)
37 | self.handler = logging.FileHandler('logs/' + file_path)
38 | self.handler.setLevel(logging.INFO)
39 | self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s"))
40 | self.logger.addHandler(self.handler)
41 |         oldest_log_file = sorted(os.listdir('logs'))[0]  # ISO date names sort chronologically
42 | oldest_date = oldest_log_file.split('.')[0]
43 | a = datetime.strptime(datetime.now().strftime('%Y-%m-%d'), '%Y-%m-%d')
44 | b = datetime.strptime(oldest_date, '%Y-%m-%d')
45 | delta = a - b
46 | if delta.days > 365:
47 | os.remove('logs/' + oldest_log_file)
48 |
--------------------------------------------------------------------------------
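
A minimal sketch of how this logger is used (the messages are illustrative); each call lands in `logs/<YYYY-MM-DD>.log` for the current day, in the `LEVEL;timestamp;message` format configured above:

```python
from inference.errors import Error

logger = Error()                      # creates logs/ and today's file if needed
logger.info('GET /models succeeded')  # INFO;<timestamp>;GET /models succeeded
logger.error('Model not found')       # ERROR;<timestamp>;Model not found
```
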
/src/main/inference/exceptions.py:
--------------------------------------------------------------------------------
1 | __metaclass__ = type
2 |
3 |
4 | class ApplicationError(Exception):
5 | """Base class for other exceptions"""
6 |
7 | def __init__(self, default_message, additional_message=''):
8 | self.default_message = default_message
9 | self.additional_message = additional_message
10 |
11 | def __str__(self):
12 | return self.get_message()
13 |
14 | def get_message(self):
15 | return self.default_message if self.additional_message == '' else "{}: {}".format(self.default_message,
16 | self.additional_message)
17 |
18 |
19 | class InvalidModelConfiguration(ApplicationError):
20 | """Raised when the model's configuration is corrupted"""
21 |
22 | def __init__(self, additional_message=''):
23 | # super('Invalid model configuration', additional_message)
24 | super().__init__('Invalid model configuration', additional_message)
25 |
26 |
27 | class ModelNotFound(ApplicationError):
28 | """Raised when the model is not found"""
29 |
30 | def __init__(self, additional_message=''):
31 | # super('Model not found', additional_message)
32 | super().__init__('Model not found', additional_message)
33 |
34 |
35 | class ModelNotLoaded(ApplicationError):
36 | """Raised when the model is not loaded"""
37 |
38 | def __init__(self, additional_message=''):
39 | # super('Error loading model', additional_message)
40 | super().__init__('Error loading model', additional_message)
41 |
42 |
43 | class InvalidInputData(ApplicationError):
44 | """Raised when the input data is corrupted"""
45 |
46 | def __init__(self, additional_message=''):
47 | # super('Invalid input data', additional_message)
48 | super().__init__('Invalid input data', additional_message)
49 |
50 |
51 | class InferenceEngineNotFound(ApplicationError):
52 | """Raised when the Inference Engine is not found"""
53 |
54 | def __init__(self, additional_message=''):
55 | # super('Inference engine not found', additional_message)
56 | super().__init__('Inference engine not found', additional_message)
57 |
--------------------------------------------------------------------------------
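
Since every exception derives from `ApplicationError`, callers can catch the base class and still surface a specific message; a sketch with a hypothetical failure:

```python
from inference.exceptions import ApplicationError, ModelNotFound

try:
    raise ModelNotFound('no folder named my_model')
except ApplicationError as e:
    print(e)  # Model not found: no folder named my_model
```
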
/src/main/object_detection/utils/static_shape.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Helper functions to access TensorShape values.
17 |
18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
19 | """
20 |
21 |
22 | def get_batch_size(tensor_shape):
23 | """Returns batch size from the tensor shape.
24 |
25 | Args:
26 | tensor_shape: A rank 4 TensorShape.
27 |
28 | Returns:
29 | An integer representing the batch size of the tensor.
30 | """
31 | tensor_shape.assert_has_rank(rank=4)
32 | return tensor_shape[0].value
33 |
34 |
35 | def get_height(tensor_shape):
36 | """Returns height from the tensor shape.
37 |
38 | Args:
39 | tensor_shape: A rank 4 TensorShape.
40 |
41 | Returns:
42 | An integer representing the height of the tensor.
43 | """
44 | tensor_shape.assert_has_rank(rank=4)
45 | return tensor_shape[1].value
46 |
47 |
48 | def get_width(tensor_shape):
49 | """Returns width from the tensor shape.
50 |
51 | Args:
52 | tensor_shape: A rank 4 TensorShape.
53 |
54 | Returns:
55 | An integer representing the width of the tensor.
56 | """
57 | tensor_shape.assert_has_rank(rank=4)
58 | return tensor_shape[2].value
59 |
60 |
61 | def get_depth(tensor_shape):
62 | """Returns depth from the tensor shape.
63 |
64 | Args:
65 | tensor_shape: A rank 4 TensorShape.
66 |
67 | Returns:
68 | An integer representing the depth of the tensor.
69 | """
70 | tensor_shape.assert_has_rank(rank=4)
71 | return tensor_shape[3].value
72 |
--------------------------------------------------------------------------------
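
A quick sketch (TensorFlow 1.x, matching the `tensorflow==1.13.1` pin; run from src/main so the package imports): the helpers read dimensions off a static rank-4 `TensorShape`:

```python
import tensorflow as tf
from object_detection.utils import static_shape

shape = tf.TensorShape([8, 224, 224, 3])   # [batch, height, width, depth]
print(static_shape.get_batch_size(shape))  # 8
print(static_shape.get_height(shape))      # 224
print(static_shape.get_width(shape))       # 224
print(static_shape.get_depth(shape))       # 3
```
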
/src/main/inference/base_inference_engine.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from inference.exceptions import InvalidModelConfiguration, ModelNotLoaded, ApplicationError
3 |
4 |
5 | class AbstractInferenceEngine(ABC):
6 |
7 | def __init__(self, model_path):
8 | """
9 | Takes a model path and calls the load function.
10 | :param model_path: The model's path
11 | :return:
12 | """
13 | self.labels = []
14 | self.configuration = {}
15 | self.model_path = model_path
16 | try:
17 | self.validate_configuration()
18 | except ApplicationError as e:
19 | raise e
20 | try:
21 | self.load()
22 | except ApplicationError as e:
23 | raise e
24 | except Exception as e:
25 | print(e)
26 | raise ModelNotLoaded()
27 |
28 | @abstractmethod
29 | def load(self):
30 | """
31 | Loads the model based on the underlying implementation.
32 | """
33 | pass
34 |
35 | @abstractmethod
36 | def free(self):
37 | """
38 |         Performs any manual memory cleanup required when unloading a model.
39 | Will be called when the class's destructor is called.
40 | """
41 | pass
42 |
43 | @abstractmethod
44 | async def infer(self, input_data, draw, predict_batch):
45 | """
46 | Performs the required inference based on the underlying implementation of this class.
47 | Could be used to return classification predictions, object detection coordinates...
48 | :param predict_batch: Boolean
49 | :param input_data: A single image
50 | :param draw: Used to draw bounding boxes on image instead of returning them
51 |         :return: The bounding-box predictions
52 | """
53 | pass
54 |
55 | @abstractmethod
56 | async def run_batch(self, input_data, draw, predict_batch):
57 | """
58 | Iterates over images and returns a prediction for each one.
59 | :param predict_batch: Boolean
60 | :param input_data: List of images
61 | :param draw: Used to draw bounding boxes on image instead of returning them
62 | :return: List of bounding-boxes
63 | """
64 | pass
65 |
66 | @abstractmethod
67 | def validate_configuration(self):
68 | """
69 | Validates that the model and its files are valid based on the underlying implementation's requirements.
70 | Can check for configuration values, folder structure...
71 | """
72 | pass
73 |
74 | @abstractmethod
75 | def set_model_configuration(self, data):
76 | """
77 | Takes the configuration from the config.json file
78 | :param data: Json data
79 | :return:
80 | """
81 | pass
82 |
83 | @abstractmethod
84 | def validate_json_configuration(self, data):
85 | """
86 | Validates the configuration of the config.json file.
87 | :param data: Json data
88 | :return:
89 | """
90 | pass
91 |
92 | def __del__(self):
93 | self.free()
94 |
--------------------------------------------------------------------------------
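
For orientation, a minimal (hypothetical) concrete engine satisfying this contract; the repository's real implementation is `tensorflow_detection.py`:

```python
from inference.base_inference_engine import AbstractInferenceEngine


class InferenceEngine(AbstractInferenceEngine):
    """Dummy engine sketch: loads nothing and returns empty predictions."""

    def load(self):
        self.labels = ['object']  # a real engine would parse the label map here

    def free(self):
        pass  # release any resources held by the model

    async def infer(self, input_data, draw, predict_batch):
        return {'bounding-boxes': []}

    async def run_batch(self, input_data, draw, predict_batch):
        return [await self.infer(img, draw, predict_batch) for img in input_data]

    def validate_configuration(self):
        return True  # a real engine would check config.json and the model files

    def set_model_configuration(self, data):
        self.configuration = data

    def validate_json_configuration(self, data):
        return True
```
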
/docs/uml/InferenceSequenceDiagram.xml:
--------------------------------------------------------------------------------
1 | 7VxZd6M2FP41fkyOxM5jlknb02nPdDKnnT5ikG06GLkCZ+mvrwRik4SNEwROTvIwAxchgfTdT3fDC/Nm+/QTCXab33CEkoUBoqeFebswDAgNn/7HJM+lxPZgKViTOOKNGsF9/B/iQsCl+zhCWadhjnGSx7uuMMRpisK8IwsIwY/dZiucdEfdBWskCe7DIJGlf8VRvqneC4Dmws8oXm/40J7NLyyD8Mea4H3Kx1sY5qr4Ky9vg6ov3j7bBBF+bInMTwvzhmCcl0fbpxuUsLmtpq28767nav3cBKX5kBvcZbAMYGgtQ8NdWUtwYZQ9PATJns/Fwry6RWj3GQUkjdP1PSIPcYj44+fP1ZTRN9mxw/02+RyvUBKn9Ox6h0i8RTki9ErCxV8a2fXjJs7R/S4I2a2PFExUtsm3CT2D9JCubx7QW0h9niTBLouXxaiASggK9ySLH9BXlJUwYlK8z9lINzU8iqZsWVDEu6pnHhT9buOQHyfBEiXX9Tre4ASz4VNcvFCWE/wDVUK6vKD4q69UcGFDrOIkabW8K/6YnL7VXbCNE6YdfyISBWnAxVwVoMHPVQMFSbxOqSyka1xMorzoHAcPiOToqSXiIPgJYboA5Jk24VcthwOSK6zHTx9b6K8wvmkB37G4MOAat667blBHDzjwBoKwooE2CkW8tVCyw3GaF+Pb1wv7VoAdJvkGr3EaJG3gNWAA7x0MvTo+GB1uFxw1DbbRoQCHMQY4vv/x5Sr+zc5urn7/+rP/a/Trt1+eL0xbO0AmmEfLswbNIwSOBiUzVUz/S7pC9I1D9CldUwK9C8Ic83E+yP5M9ftksnfBcbJ3FDg0XS1krwDiB9mPRvbmyeiwu2QP7WFkX7UbneyhJeEBRdRc56fNkn9qpMKSN20+Y7zji/QPyvNnPu3BPsddNho+tRnekxAdQXcekDXq66hqxF7r4BIRlAQ55b+O4qsmvLj1ipDgudWAq07T8xcmaFYeAmHpXcF5ENr7XRoRfA16UD5As/L1m7wCDKqd30nozF5H8QM9XLPDryiIKjEdpnVF0ZjNUHpS83CDcYbYqm7Yv1WjJRFvO9pZiAlhmxZd+2LD7btRwD9V5bwLV8J2xKDZKmWalMiDUUJMfd4rfmEbR1GhOz082PiyKu04qLoS+9SuO3/kRdv9VbESuORdnaYCEmatLsQvTLPbA16tMqqnImedhlxJwT2VrUUNFrpxJAldjA/76oy31JPtK8P3jtpXYDJn2v8wrzSaV96p4BAIyBtoXdm6rCvT0o6PCaaxjl8dc6X1xKugNId3mDwGJKptBIL+3aMsX9SBX+XsilvyEuc53tILKI2uWFybyRIc/jiqTgNVoW1Io2SJH9s2dCGgF6qnKp6DztV3rtDFyd/s5BLQfZQLbp/al2+f22ctVBTCXrI9bC4DLeayZCuYttm1hz330vF8x7aha1jQrMI3VY+lI8A7OWA324YQT/O8bkfl20sdjWGJQFndb4LilhKnlRkKqD/EJpDZ0eCfDKelybaK12PhtwdKcyG7F4qH3TswBK963DsJVy4UHDIABgF0NN9M3jlmcdQVC6d8XnciGrFEdReWpUfdT3XfTYlWDrvvR9pr8t/laN89Yt43SHDBNfVG+VZd3r6swMku7wW4hIZrdNbognf1Skgal4IbDIUMwyhu8CGlO2Ru9ni/AVu7eBcU00+vXbHcxHxusYilnPHVdUbHidP1t4K8TGcYtJQedi+vvTjf5Mo2sqlyQ+0RbGTl2ivDIELKaSgaPmIhbyIW4nldDKpiIYauWIgShB/xkBHxcFDNB4PE7u5Gdd3AEWdetKPGi4koiEp7TETLVFr+zFMJfWnmJvAJat8SXgKjjkyUoQrXMsYJVRxasKO+B58WwffoOpL9N47vOnZZ2hQ9x5FcFE9Q9CMJRrG57wt41OCgWHBmwEIPioD1zwCwFSeejtjyTt2QtUWrVQ9kbfMkyFZPpdendmaFLPVTgWWL8WB31Hjw2fCksL6m4Xa7GBoD9oXgi2kIieiR4Cs/sDkBIGUnnweZeZBntU/DPGZB5bca5XF6DLMXRXl8o5touBin1qF6yMr16t6vL8RTQbu1/J/LdS9zDNvye5Exlp6yEvfOYQsJCVrlwyI0dKZHcXN9wc01lflIVWJXl+FtzG3HAMfr7gm+PZLhPTS+78tbxbmW3nliaTiYwG4wVG6uopYu35N0ISUI4zTLgzREr6h+K/SDqscJ7cuHYNsIihoqeVdFc73s84KiOWjCbi1S9dHCK7cW6te7Qr/TbS/zmruU2aDgofljeWiDUs6DKor1lEicSmuO3UWfBbwJaM2dGx+2K+DDnxIfQ0oS1OGnqeFhCR+iWIcqyOX8taO+W0MZjRpnciah3ivv92GIsux9GbkWgELe2FSGl6c0c6sE5hnWnPQp3dHQhZ66lJO/BrEPKuep7fVwvSlXXtaqB4pz8I3s5aTym9ZDKEbrVVqorCLWpoXGrFsui6GYVmfPnTn6WFZVjx5+9CyBgG0xTT40/igWpUkdjRR/NDz1OH3PdaS9JhbpLUrbERTFxfdZb70urVLRj7q0nvmZxascZEgc/G5Dd/VqlfCqFVL8CHwkougbR6/iy4mK3g833pMFYXpO30ZyrDBLnw0xb6VIwWpO129vCaaJWFe/a3I2SXZf1EpNteuOaI5MUYtejSmnKZttv8lUyo7mm9n6q0DFGFs/cJxu5HecXKUQpa6q6PXv+5YhgeBLufpjrvq5EL+Yp3SgTPvulKxvzfLjHu1yK1sI5nvQmpj0vaEcr94dpg4M+QJXO4dqnGQ7zxXvFp5Pc9TWkn9ApI7a7tqKT7V6h9PsncWOHHH6TYXdp/rxJ30MMHe6BroiA9gvStegpzjv1sNRwd/FEACY/LzpkJ08t07GtyIHRqlmJxTR9XOcwzEiq9qz1e2PpY18cbSJCUiuuqgJaFzOmcHUrJR5jNoFAzhijcEopiYUY0x1VmsCa3MWJ/cVVGINpJLzyFmJ0WPnyAcGR9rrcTtt2e2Ug07v0faAPYZjx/ZQfZH6AtuDnja/ol2uXPNT5ean/wE=
--------------------------------------------------------------------------------
/docs/uml/InferenceClassDiagram.drawio:
--------------------------------------------------------------------------------
1 | 7V1bk6O2Ev41rjp58BZ348fxXHZzMptsZU7tyebFJYNsk8WICDyeya+PBBIGJF8mHiObaGtrFzWSkLo/Wt2tFh7Yt6uXjxiky88ohPHAMsKXgX03sMgfwyX/UcprSTF9xyspCxyFjLYlPEV/QUY0GHUdhTBrVMwRivMobRIDlCQwyBs0gDHaNKvNUdx8agoWUCA8BSAWqf+PwnxZUn3X2NI/wWix5E82DXZnBoLvC4zWCXvewLLnxZ/y9grwvlj9bAlCtKmR7PuBfYsRysur1cstjClzOdvKdg877lbjxjDJj2owAsEMuCPHh6499mZD1sMziNeMFz8mc0i6CyAbcf7KuUQGn9LL9Sp+wGBFLiebZZTDpxQElL4h6CC0Zb6KSckklwVnIH22QUrV5GkhQKsoYNcxmMF4UrHyFsUIk1sJSugzshyj75VcaLdzlOQPYBXFFG9fIQ5BAhiZYcuk3YI4WiSkEBDmQNLhROQWnz7EOXypkRj3PkK0gjl+JVXY3aHJMcuwPnTHjLDZIsesBL6swWbMmwIG10XV/VZk5IJJ7UgJ8ifVRHhLGIFRHJNJCzLcRKsYFHytyangHKtUCGcZxeEjeEVrOvQsJ6LhpckS4egvUh/wxuQ25ny3PNpbFMctGW4bPdHO2GMwzEizL1weZkV6BFle4SSOQZpFs2JwtMoK4EWUTFCeoxVHFpvVQ+3J21exjzh0Wyi0JCi0DAkILecsIBT1SBhPM4ifI6oabsitOwjTRwhwEiWLJ3ajjU4y+byJzJLpLUlIAMZ5HMM57YEyMiL6/YaRc0Q1U0YUFXn6Y1HnztlSfmX8cXaoNET6m8cFZpZRGMKkwFQOcjCr3pkURUle8NSdkL+Ey7fGB3fgknndkrK5LZO/tDrOyXtK5geiAgWQYH4DKe5l+Nj51h9GDIOI5R0JEO8c+HAEfECybuNpjBbkZV6UELmnJA2KDkHhWt2B4jl5fUx+h/Dr0Ln5/BO+T3/xbiX2x8CaxAiE//mhxMRNGv0Ks5QwRauLLpEx8hUjw5IiI8ry6Yo6IJkGiFqAmIajGCG2DCF4nZQA0fhQjA/7WIv0XPgQbY46PqYzkAdLjRLFKHE7NEulKHH3rzPTwinUq41qnPiqLVXvAE4ClMyjhcaJYi/XUG23jgScHBP50HG5fsTlbIvpqbcH5t4jOixFpC8gcli6UNMwopspRXSOXmlFdaSi2vviX2RETjrisQiMGcggW9DCCJfYeMoxDdFpdHSKDvWhOXFjicXmmg72BKEYgkTjo1t8KA/QmfLYrRiheyQkjY5u0aE+OmdKA7jMpyZKBFYhfr2+KEGI8vicKQ3gLmDeirlo/aEAHcrjcqY0fEvR0Yy0aM9FATqUR+NMMRwnzx7QgZV+BFbcMdMHbw6sVAvd+4NQjPUNQ5BDppjolVZMpyim6i2/npiKKYm20SQnyIIpjyzhSQOjW2CoD6dIom1LkISxRoZaZCgPpEgyuImlGyxh8H1KlxNu6X5FpBuNjm7RoT6QYssCKV5MZ502sOD9uabnSJi9WGLGSF/IvwWDjJI+LARL7zm1exRGQ4YIeo+ZblWf5GrB/i+eHNUIYEUFHoulH2kfcwoLTm42afRIeBO1nzLDAoUTJiCDzOKv2s/atQktbdOWmDKNH0Hi8zN3T/UfMbl8qSout/sekhtSq2DXvGUzOcvA5EtSt4Mg2o++FMkcNVWfyBxab1PurIsVux1ukc8uG29zGC31Lde6LdVc06VUj7c8LqvtmH2C8TOkvQ7OfxjrdMfO9444ymKOZY6ddS7HzvcFQcFwAXnEgKxDS7RACYjvt9QWf7d1HhGVaSGCP2CevzK5gXWOmgKCL1H+G21O1ryy9K125+6F9VwUXnkhIfP9rV74tu2BFrfNitK2XXhDD05uJUsoDwW47gzpesoEm6E1DuAe3tksoS4n7w7ctxZ7zHumnN0LFAxjkEfPsDEOmcyLpmRa4LVWgRkZ256/UMIWf2brKBULbT3sqD50Tqo+dEZuC57lgLdgrWZ+gskgJjVWxzrvE6LS4AMIcsTeFB0u62G4rAm7I51f71wK1ZZmT9JgPizgyNfNFkq1t3OSt1OpgesJnznirs8hTGiV1Q+VZR+rpfxzaSlXuuWYYhTALCNaQG85vptqqt7z61FNrvSkSIjBZjqj7xaRBbl4od8x0dE6RSBRHuX3XSWeI/cCq4JqL9C1r8wLNPymxTxufUjnUP1uHDtXljKlY8Egg4KVeK1R4SL1rZ74pjwkXOZbrTF5/VBSt4A6HkeZO5oW39Sq542qiPbWP1GxLziN10nTaux6sGXfxAxIhIGVJ6BrQ2rW6pytRKlFdL9z2gRc9wPJqiRDdWOomPFHhpJjRqM3FoSMsbFgDMr3Ec7nUIphr30Hg3Ukox+RjFbwtdqnOpipeLYjoK48UxHoeMY7uaru9WUqupJMxWITW5/KUQgL5REMV8xTpCbROqBBUH3iUyU2lGcqemII9CZNYyIqapfqAxh9NmqEL1u4jnKzxhPt62EI52Ad59MVUVfFZ7/1YvZOCsu7vu1kT2L3gjCMqLoCscbIBWBEucHjiXYwy01h8NDn1JUCRLnVM7IEmXeYDcqvazt615QNOjhuG5B/LeL9tgFPE7mjUuTGVYvcOlLkPP6gfOe3adaaI7sOn0P1Tb+Dbd+RmqyCXmggfu7yIB7tC1NBo+uT+QfDbGWgeIcFT0pfII4I16ireu68FOdY7WRfFhrGekH6xxrgWCOEZ6krXpGqA1DVksQesWtJEhrY+5OX7JOqO43BnGfB4zuy16T8LgXuHMWHVZx7USqOf1NIy/x8Njc3hi5Owx0wukUNd0Bnue1TpW9uYLL0kfPqOVtj/gjM7/w9sWtJL7ZG7UPOlrsXjkKDTtKLfTVhjzNlr4vbZyfYkMZ1Ac429quzww1s/4DCdKwTG3QDaa9PkN4J4GvBpWM1QWAdWJaFcxZvbrAfZgcHyHcf+ARLTcEatSzet8H1/v5/P8P16M+R735bfJpN/vtsBZLfQ+Qaat+Gk7Df395KWkVhWCKbZjvs3ACiezy0O4JndhZir9Y8PnlA+L1ak+OxsatTkhqbOu+QOiBltHiQpheMtg3zwhgtnqntBaNFRPMPlKpitJib1QtGu8w13JrO3IlVxWgx6cj88OGYbMiL57XZVB6mJ0La6ZLTYvZOLyDdjr+aY1stosUMmF7wWdDRlqNYR4u51UZPVIfltwwPZ+yqVR6S3y/qC7Ndy24zW8y07JbZPXVcbKMVtLBMWTpal5y2+srp9qqonNO99REbjJaeTuySzf8SD1E5n3vqIJotPHuWYj6L/mFf7A6LfxiXp4Io9g8lP8DSF1a7fG67tUe3rO6pj+i2v/skO7LUJZ9FF7EXfB754wafR4pXQ8mPbvSCzw5LLGh9n1odn3tqdYyE3yNTHZa2JHHpPnDab39CvTKs35/TpIgR/RRTde8jBunyM/3yGCH+DQ==
--------------------------------------------------------------------------------
/src/main/ocr.py:
--------------------------------------------------------------------------------
1 | import pytesseract
2 | import unicodedata
3 | import re
4 | import numpy as np
5 |
6 |
7 |
8 | # Module-level constants
9 |
10 | bounding_box_order = ["left", "top", "right", "bottom"]
11 |
12 | # This function takes the model's bounding box predictions and returns the text extracted from each box
13 | def one_shot_ocr_service(image, output):
14 | # iterate over detections
15 | response = []
16 | detections = output['bounding-boxes']
17 |
18 | for i in range(0, len(detections)):
19 |
20 | # crop image for every detection:
21 | coordinates = (detections[i]["coordinates"])
22 | cropped = image.crop((float(coordinates["left"]), float(
23 | coordinates["top"]), float(coordinates["right"]), float(coordinates["bottom"])))
24 |
25 | # convert image to grayscale for better accuracy
26 |         processed_img = cropped.convert('L')
27 |
28 | # extract text with positive confidence from cropped image
29 | df = pytesseract.image_to_data(processed_img, output_type='data.frame')
30 | valid_df = df[df["conf"] > 0]
31 | extracted_text = " ".join(valid_df["text"].values)
32 |
33 | # process text
34 | extracted_text = str(unicodedata.normalize('NFKD', extracted_text).encode('ascii', 'ignore').decode()).strip().replace("\n", " ").replace(
35 | "...", ".").replace("..", ".").replace('”', ' ').replace('“', ' ').replace("'", ' ').replace('\"', '').replace("alt/1m", "").strip()
36 | extracted_text = re.sub(
37 | '[^A-Za-z0-9.!?,;%:=()\[\]$€&/\- ]+', '', extracted_text)
38 | extracted_text = " ".join(extracted_text.split())
39 |
40 | # wrap each prediction inside a dictionary
41 |         if len(extracted_text) != 0:
42 | prediction = dict()
43 | prediction["text"] = extracted_text
44 | bounding_box = [coordinates[el] for el in bounding_box_order]
45 | prediction["box"] = bounding_box
46 | prediction["score"] = valid_df["conf"].mean()/100.0
47 |
48 | response.append(prediction)
49 |
50 | return response
51 |
52 | # This function takes an image and returns the text extracted from it
53 | def ocr_service(image):
54 | # convert image to grayscale for better accuracy
55 |     processed_img = image.convert('L')
56 |
57 | # Get data including boxes, confidences, line and page numbers
58 | df = pytesseract.image_to_data(processed_img, output_type='data.frame')
59 | valid_df = df[df["conf"] > 0]
60 |
61 | # process text
62 | extracted_text = " ".join(valid_df["text"].values)
63 | extracted_text = str(unicodedata.normalize('NFKD', extracted_text).encode('ascii', 'ignore').decode()).strip().replace("\n", " ").replace(
64 | "...", ".").replace("..", ".").replace('”', ' ').replace('“', ' ').replace("'", ' ').replace('\"', '').replace("alt/1m", "").strip()
65 | extracted_text = re.sub(
66 | '[^A-Za-z0-9.!?,;%:=()\[\]$€&/\- ]+', '', extracted_text)
67 | extracted_text = " ".join(extracted_text.split())
68 |
69 | # calculate the bounding box data based on pytesseract results
70 | coordinates = {}
71 | index = valid_df.index.values
72 | coordinates["left"] = valid_df.loc[index[0], "left"]
73 | coordinates["top"] = valid_df.loc[index[0], "top"]
74 | coordinates["bottom"] = valid_df.loc[index[-1],
75 | "top"] + valid_df.loc[index[-1], "height"]
76 | coordinates["right"] = valid_df.loc[index[-1],
77 | "left"] + valid_df.loc[index[-1], "width"]
78 | bounding_box = [coordinates[el].item() for el in bounding_box_order]
79 |
80 | # wrap each prediction inside a dictionary
81 | response = {}
82 | response["text"] = extracted_text
83 | response["box"] = bounding_box
84 | response["score"] = valid_df["conf"].mean()/100.0
85 |
86 | return [response]
87 |
--------------------------------------------------------------------------------
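
A usage sketch (the image path and coordinates are hypothetical; requires the tesseract-ocr binary, as installed in docker/dockerfile). Both functions take a PIL image; `one_shot_ocr_service` additionally takes a detector response of the shape shown:

```python
from PIL import Image
import ocr

image = Image.open('sample.jpg')  # hypothetical input image

# Plain OCR over the whole image
print(ocr.ocr_service(image))     # [{'text': ..., 'box': [l, t, r, b], 'score': ...}]

# OCR restricted to detected boxes
detections = {'bounding-boxes': [
    {'coordinates': {'left': 10, 'top': 20, 'right': 200, 'bottom': 80}}
]}
print(ocr.one_shot_ocr_service(image, detections))
```
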
/README-docker_swarm.md:
--------------------------------------------------------------------------------
1 | # Tensorflow CPU Inference API For Windows and Linux with docker swarm
2 | Please use **docker swarm** only if you need to:
3 |
4 | * Provide redundancy in terms of API containers: if a container goes down, incoming requests are redirected to another running instance.
5 |
6 | * Coordinate between the containers: Swarm orchestrates the API replicas and chooses one of them to handle each incoming request.
7 |
8 | * Scale up the inference service to get faster predictions, especially when there is heavy traffic on the service.
9 |
10 | ## Run the docker container
11 |
12 | Docker swarm can scale the API up to multiple replicas and can be used on one or multiple hosts (Linux users only). In both cases, a docker swarm setup is required on all hosts.
13 |
14 | #### Docker swarm setup
15 |
16 | 1- Initialize Swarm:
17 |
18 | ```sh
19 | docker swarm init
20 | ```
21 |
22 | 2- On the manager host, open the cpu-inference.yaml file and specify the number of replicas needed. If you are using multiple hosts (see the "With multiple hosts" section), the replicas will be distributed across all hosts.
23 |
24 | ```yaml
25 | version: "3"
26 |
27 | services:
28 | api:
29 | ports:
30 | - "4343:4343"
31 | image: tensorflow_inference_api_cpu
32 | volumes:
33 | - "/mnt/models:/models"
34 | deploy:
35 | replicas: 1
36 | update_config:
37 | parallelism: 2
38 | delay: 10s
39 | restart_policy:
40 | condition: on-failure
41 | ```
42 |
43 | **Notes about cpu-inference.yaml:**
44 |
45 | * the part of the volumes field to the left of ":" must be an absolute path; it is user-configurable and represents the models directory on your operating system
46 | * the part to the right of ":" (/models) must never be changed
47 |
48 | #### With one host
49 |
50 | Deploy the API:
51 |
52 | ```sh
53 | docker stack deploy -c cpu-inference.yaml tensorflow-cpu
54 | ```
55 |
56 | 
57 |
58 | #### With multiple hosts (Linux users only)
59 |
60 | 1- **Make sure hosts are reachable on the same network**.
61 |
62 | 2- Choose a host to be the manager and run the following command on the chosen host to generate a token so the other hosts can join:
63 |
64 | ```sh
65 | docker swarm join-token worker
66 | ```
67 |
68 | A command of the form `docker swarm join --token <token> <manager-ip>:2377` will appear in your terminal; copy and paste it on the other hosts, as shown in the image below.
69 |
70 | 3- Deploy your application using:
71 |
72 | ```sh
73 | docker stack deploy -c cpu-inference.yaml tensorflow-cpu
74 | ```
75 |
76 | 
77 |
78 | #### Useful Commands
79 |
80 | 1- To scale the service up to 4 replicas, for example, use this command:
81 |
82 | ```sh
83 | docker service scale tensorflow-cpu_api=4
84 | ```
85 |
86 | 2- To check the available workers:
87 |
88 | ```sh
89 | docker node ls
90 | ```
91 |
92 | 3- To check on which node the container is running:
93 |
94 | ```sh
95 | docker service ps tensorflow-cpu_api
96 | ```
97 |
98 | 4- To check the number of replicas:
99 |
100 | ```sh
101 | docker service ls
102 | ```
103 |
104 | ## Benchmarking
105 |
106 | Here are two graphs showing the prediction time for different numbers of simultaneous requests.
107 |
108 |
109 | 
110 |
111 |
112 | 
113 |
114 |
115 | Both graphs show the same pattern regardless of how many requests arrive at the same time: increasing the number of workers (hosts) speeds up inference by at least 2x. For example, the last column shows that 40 requests were processed in:
116 |
117 | - 17.5 seconds with 20 replicas in 1 machine
118 | - 8.8 seconds with 20 replicas in each of the 2 machines
119 |
120 | Moreover, if one of the machines goes down, the others remain ready to receive requests.
121 |
122 | Finally, since we are predicting on CPU, scaling to more replicas does not necessarily mean faster predictions: 4 containers were faster than 20.
123 |
--------------------------------------------------------------------------------
/src/main/object_detection/protos/string_int_label_map_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/string_int_label_map.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='object_detection/protos/string_int_label_map.proto',
20 | package='object_detection.protos',
21 | syntax='proto2',
22 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem')
23 | )
24 |
25 |
26 |
27 |
28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor(
29 | name='StringIntLabelMapItem',
30 | full_name='object_detection.protos.StringIntLabelMapItem',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0,
37 | number=1, type=9, cpp_type=9, label=1,
38 | has_default_value=False, default_value=_b("").decode('utf-8'),
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | _descriptor.FieldDescriptor(
50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2,
51 | number=3, type=9, cpp_type=9, label=1,
52 | has_default_value=False, default_value=_b("").decode('utf-8'),
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None, file=DESCRIPTOR),
56 | ],
57 | extensions=[
58 | ],
59 | nested_types=[],
60 | enum_types=[
61 | ],
62 | options=None,
63 | is_extendable=False,
64 | syntax='proto2',
65 | extension_ranges=[],
66 | oneofs=[
67 | ],
68 | serialized_start=79,
69 | serialized_end=150,
70 | )
71 |
72 |
73 | _STRINGINTLABELMAP = _descriptor.Descriptor(
74 | name='StringIntLabelMap',
75 | full_name='object_detection.protos.StringIntLabelMap',
76 | filename=None,
77 | file=DESCRIPTOR,
78 | containing_type=None,
79 | fields=[
80 | _descriptor.FieldDescriptor(
81 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0,
82 | number=1, type=11, cpp_type=10, label=3,
83 | has_default_value=False, default_value=[],
84 | message_type=None, enum_type=None, containing_type=None,
85 | is_extension=False, extension_scope=None,
86 | options=None, file=DESCRIPTOR),
87 | ],
88 | extensions=[
89 | ],
90 | nested_types=[],
91 | enum_types=[
92 | ],
93 | options=None,
94 | is_extendable=False,
95 | syntax='proto2',
96 | extension_ranges=[],
97 | oneofs=[
98 | ],
99 | serialized_start=152,
100 | serialized_end=233,
101 | )
102 |
103 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM
104 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM
105 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP
106 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
107 |
108 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict(
109 | DESCRIPTOR = _STRINGINTLABELMAPITEM,
110 | __module__ = 'object_detection.protos.string_int_label_map_pb2'
111 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem)
112 | ))
113 | _sym_db.RegisterMessage(StringIntLabelMapItem)
114 |
115 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict(
116 | DESCRIPTOR = _STRINGINTLABELMAP,
117 | __module__ = 'object_detection.protos.string_int_label_map_pb2'
118 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap)
119 | ))
120 | _sym_db.RegisterMessage(StringIntLabelMap)
121 |
122 |
123 | # @@protoc_insertion_point(module_scope)
124 |
--------------------------------------------------------------------------------
/src/main/object_detection/utils/label_map_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Label map utility functions."""
17 |
18 | import logging
19 |
20 | import tensorflow as tf
21 | from google.protobuf import text_format
22 | from object_detection.protos import string_int_label_map_pb2
23 |
24 |
25 | def _validate_label_map(label_map):
26 | """Checks if a label map is valid.
27 |
28 | Args:
29 | label_map: StringIntLabelMap to validate.
30 |
31 | Raises:
32 | ValueError: if label map is invalid.
33 | """
34 | for item in label_map.item:
35 | if item.id < 0:
36 | raise ValueError('Label map ids should be >= 0.')
37 | if (item.id == 0 and item.name != 'background' and
38 | item.display_name != 'background'):
39 | raise ValueError('Label map id 0 is reserved for the background label')
40 |
41 |
42 | def create_category_index(categories):
43 | """Creates dictionary of COCO compatible categories keyed by category id.
44 |
45 | Args:
46 | categories: a list of dicts, each of which has the following keys:
47 | 'id': (required) an integer id uniquely identifying this category.
48 | 'name': (required) string representing category name
49 | e.g., 'cat', 'dog', 'pizza'.
50 |
51 | Returns:
52 | category_index: a dict containing the same entries as categories, but keyed
53 | by the 'id' field of each category.
54 | """
55 | category_index = {}
56 | for cat in categories:
57 | category_index[cat['id']] = cat
58 | return category_index
59 |
60 |
61 | def get_max_label_map_index(label_map):
62 | """Get maximum index in label map.
63 |
64 | Args:
65 | label_map: a StringIntLabelMapProto
66 |
67 | Returns:
68 | an integer
69 | """
70 | return max([item.id for item in label_map.item])
71 |
72 |
73 | def convert_label_map_to_categories(label_map,
74 | max_num_classes,
75 | use_display_name=True):
76 | """Loads label map proto and returns categories list compatible with eval.
77 |
78 | This function loads a label map and returns a list of dicts, each of which
79 | has the following keys:
80 | 'id': (required) an integer id uniquely identifying this category.
81 | 'name': (required) string representing category name
82 | e.g., 'cat', 'dog', 'pizza'.
83 | We only allow class into the list if its id-label_id_offset is
84 | between 0 (inclusive) and max_num_classes (exclusive).
85 | If there are several items mapping to the same id in the label map,
86 | we will only keep the first one in the categories list.
87 |
88 | Args:
89 | label_map: a StringIntLabelMapProto or None. If None, a default categories
90 | list is created with max_num_classes categories.
91 | max_num_classes: maximum number of (consecutive) label indices to include.
92 | use_display_name: (boolean) choose whether to load 'display_name' field
93 | as category name. If False or if the display_name field does not exist,
94 | uses 'name' field as category names instead.
95 | Returns:
96 | categories: a list of dictionaries representing all possible categories.
97 | """
98 | categories = []
99 | list_of_ids_already_added = []
100 | if not label_map:
101 | label_id_offset = 1
102 | for class_id in range(max_num_classes):
103 | categories.append({
104 | 'id': class_id + label_id_offset,
105 | 'name': 'category_{}'.format(class_id + label_id_offset)
106 | })
107 | return categories
108 | for item in label_map.item:
109 | if not 0 < item.id <= max_num_classes:
110 | logging.info('Ignore item %d since it falls outside of requested '
111 | 'label range.', item.id)
112 | continue
113 | if use_display_name and item.HasField('display_name'):
114 | name = item.display_name
115 | else:
116 | name = item.name
117 | if item.id not in list_of_ids_already_added:
118 | list_of_ids_already_added.append(item.id)
119 | categories.append({'id': item.id, 'name': name})
120 | return categories
121 |
122 |
123 | def load_labelmap(path):
124 | """Loads label map proto.
125 |
126 | Args:
127 | path: path to StringIntLabelMap proto text file.
128 | Returns:
129 | a StringIntLabelMapProto
130 | """
131 | with tf.gfile.GFile(path, 'r') as fid:
132 | label_map_string = fid.read()
133 | label_map = string_int_label_map_pb2.StringIntLabelMap()
134 | try:
135 | text_format.Merge(label_map_string, label_map)
136 | except text_format.ParseError:
137 | label_map.ParseFromString(label_map_string)
138 | _validate_label_map(label_map)
139 | return label_map
140 |
141 |
142 | def get_label_map_dict(label_map_path, use_display_name=False):
143 | """Reads a label map and returns a dictionary of label names to id.
144 |
145 | Args:
146 | label_map_path: path to label_map.
147 | use_display_name: whether to use the label map items' display names as keys.
148 |
149 | Returns:
150 | A dictionary mapping label names to id.
151 | """
152 | label_map = load_labelmap(label_map_path)
153 | label_map_dict = {}
154 | for item in label_map.item:
155 | if use_display_name:
156 | label_map_dict[item.display_name] = item.id
157 | else:
158 | label_map_dict[item.name] = item.id
159 | return label_map_dict
160 |
161 |
162 | def create_category_index_from_labelmap(label_map_path):
163 | """Reads a label map and returns a category index.
164 |
165 | Args:
166 | label_map_path: Path to `StringIntLabelMap` proto text file.
167 |
168 | Returns:
169 | A category index, which is a dictionary that maps integer ids to dicts
170 | containing categories, e.g.
171 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
172 | """
173 | label_map = load_labelmap(label_map_path)
174 | max_num_classes = max(item.id for item in label_map.item)
175 | categories = convert_label_map_to_categories(label_map, max_num_classes)
176 | return create_category_index(categories)
177 |
178 |
179 | def create_class_agnostic_category_index():
180 | """Creates a category index with a single `object` class."""
181 | return {1: {'id': 1, 'name': 'object'}}
182 |
--------------------------------------------------------------------------------
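
For context, a sketch using a tiny (hypothetical) `label_map.pbtxt` in the standard `StringIntLabelMap` text format:

```python
from object_detection.utils import label_map_util

# Contents of the hypothetical label_map.pbtxt:
#   item { id: 1 name: 'dog' }
#   item { id: 2 name: 'cat' }
label_map = label_map_util.load_labelmap('label_map.pbtxt')
categories = label_map_util.convert_label_map_to_categories(label_map,
                                                            max_num_classes=2)
print(label_map_util.create_category_index(categories))
# {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}}
```
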
/src/main/deep_learning_service.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import uuid
5 | from inference.inference_engines_factory import InferenceEngineFactory
6 | from inference.exceptions import ModelNotFound, InvalidModelConfiguration, ModelNotLoaded, InferenceEngineNotFound, \
7 | InvalidInputData, ApplicationError
8 |
9 |
10 | class DeepLearningService:
11 |
12 | def __init__(self):
13 | """
14 | Sets the models base directory, and initializes some dictionaries.
15 |         Saves the loaded models' hashes to a json file, so the values persist even if the API goes down.
16 | """
17 | # dictionary to hold the model instances (model_name: string -> model_instance: AbstractInferenceEngine)
18 | self.models_dict = {}
19 | # read from json file and append to dict
20 | file_name = '/models_hash/model_hash.json'
21 | file_exists = os.path.exists(file_name)
22 | if file_exists:
23 | try:
24 | with open(file_name) as json_file:
25 | self.models_hash_dict = json.load(json_file)
26 |             except Exception:
27 | self.models_hash_dict = {}
28 | else:
29 | with open('/models_hash/model_hash.json', 'w'):
30 | self.models_hash_dict = {}
31 | self.labels_hash_dict = {}
32 | self.base_models_dir = '/models'
33 |
34 | def load_model(self, model_name, force_reload=False):
35 | """
36 | Loads a model by passing the model path to the factory.
37 | The factory will return a loaded model instance that will be stored in a dictionary.
38 | :param model_name: Model name
39 | :param force_reload: Boolean to specify if we need to reload a model on each call
40 | :return: Boolean
41 | """
42 | if not force_reload and self.model_loaded(model_name):
43 | return True
44 | model_path = os.path.join(self.base_models_dir, model_name)
45 | try:
46 | self.models_dict[model_name] = InferenceEngineFactory.get_engine(model_path)
47 | return True
48 | except ApplicationError as e:
49 | raise e
50 |
51 | def load_all_models(self):
52 | """
53 | Loads all the available models.
54 |         :return: A dict mapping each available model to its hashed value
55 | """
56 | self.load_models(self.list_models())
57 | models = self.list_models()
58 | for model in models:
59 | if model not in self.models_hash_dict:
60 | self.models_hash_dict[model] = str(uuid.uuid4())
61 | for key in list(self.models_hash_dict):
62 | if key not in models:
63 | del self.models_hash_dict[key]
64 | # append to json file
65 | with open('/models_hash/model_hash.json', "w") as fp:
66 | json.dump(self.models_hash_dict, fp)
67 | return self.models_hash_dict
68 |
69 | def load_models(self, model_names):
70 | """
71 | Loads a set of available models.
72 | :param model_names: List of available models
73 | :return:
74 | """
75 | for model in model_names:
76 | self.load_model(model)
77 |
78 | async def run_model(self, model_name, input_data, draw, predict_batch):
79 | """
80 | Loads the model in case it was never loaded and calls the inference engine class to get a prediction.
81 | :param model_name: Model name
82 | :param input_data: Batch of images or a single image
83 | :param draw: Boolean to specify if we need to draw the response on the input image
84 | :param predict_batch: Boolean to specify if there is a batch of images in a request or not
85 | :return: Model response in case draw was set to False, else an actual image
86 | """
87 | if re.match(r'[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12}', model_name,
88 | flags=0):
89 | for key, value in self.models_hash_dict.items():
90 | if value == model_name:
91 | model_name = key
92 | if self.model_loaded(model_name):
93 | try:
94 | if predict_batch:
95 | return await self.models_dict[model_name].run_batch(input_data, draw, predict_batch)
96 | else:
97 | if not draw:
98 | return await self.models_dict[model_name].infer(input_data, draw, predict_batch)
99 | else:
100 | await self.models_dict[model_name].infer(input_data, draw, predict_batch)
101 | except ApplicationError as e:
102 | raise e
103 | else:
104 | try:
105 | self.load_model(model_name)
106 | return await self.run_model(model_name, input_data, draw, predict_batch)
107 | except ApplicationError as e:
108 | raise e
109 |
110 | def list_models(self):
111 | """
112 | Lists all the available models.
113 | :return: List of models
114 | """
115 | return [folder for folder in os.listdir(self.base_models_dir) if
116 | os.path.isdir(os.path.join(self.base_models_dir, folder))]
117 |
118 | def model_loaded(self, model_name):
119 | """
120 |         Checks whether the model is currently loaded.
121 |         :param model_name: Model name
122 |         :return: Boolean
123 | """
124 | return model_name in self.models_dict.keys()
125 |
126 | def get_labels(self, model_name):
127 | """
128 | Loads the model in case it's not loaded.
129 | Returns the model's labels.
130 | :param model_name: Model name
131 | :return: List of model labels
132 | """
133 | if not self.model_loaded(model_name):
134 | self.load_model(model_name)
135 | return self.models_dict[model_name].labels
136 |
137 | def get_labels_custom(self, model_name):
138 | """
139 | Hashes every label of a specific model.
140 | :param model_name: Model name
141 |         :return: A dict of the model's labels with their hashed values
142 | """
143 | if re.match(r'[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12}', model_name,
144 | flags=0):
145 | for key, value in self.models_hash_dict.items():
146 | if value == model_name:
147 | model_name = key
148 | models = self.list_models()
149 | if model_name not in self.labels_hash_dict:
150 | model_dict = {}
151 | for label in self.models_dict[model_name].labels:
152 | model_dict[label] = str(uuid.uuid4())
153 | self.labels_hash_dict[model_name] = model_dict
154 | for key in list(self.labels_hash_dict):
155 | if key not in models:
156 | del self.labels_hash_dict[key]
157 | return self.labels_hash_dict[model_name]
158 |
159 | def get_config(self, model_name):
160 | """
161 | Returns the model's configuration.
162 | :param model_name: Model name
163 |         :return: The model's configuration dictionary
164 | """
165 | if not self.model_loaded(model_name):
166 | self.load_model(model_name)
167 | return self.models_dict[model_name].configuration
168 |
--------------------------------------------------------------------------------
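
A usage sketch (the model name is hypothetical; runs inside the container, where /models and /models_hash exist, and assumes the engine accepts a PIL image, as ocr.py does). `run_model` is a coroutine, so it must be awaited:

```python
import asyncio
from PIL import Image
from deep_learning_service import DeepLearningService

dl_service = DeepLearningService()
print(dl_service.list_models())  # model folders found under /models

async def predict():
    image = Image.open('sample.jpg')  # hypothetical input image
    # draw=False returns the JSON-style response instead of an annotated image
    return await dl_service.run_model('my_model', image, draw=False,
                                      predict_batch=False)

response = asyncio.get_event_loop().run_until_complete(predict())
```
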
/src/main/object_detection/core/box_list.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Bounding Box List definition.
17 |
18 | BoxList represents a list of bounding boxes as tensorflow
19 | tensors, where each bounding box is represented as a row of 4 numbers,
20 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes
21 | within a given list correspond to a single image. See also
22 | box_list_ops.py for common box related operations (such as area, iou, etc).
23 |
24 | Optionally, users can add additional related fields (such as weights).
25 | We assume the following things to be true about fields:
26 | * they correspond to boxes in the box_list along the 0th dimension
27 | * they have inferrable rank at graph construction time
28 | * all dimensions except for possibly the 0th can be inferred
29 | (i.e., not None) at graph construction time.
30 |
31 | Some other notes:
32 | * Following tensorflow conventions, we use height, width ordering,
33 | and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering
34 | * Tensors are always provided as (flat) [N, 4] tensors.
35 | """
36 |
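# Usage sketch (illustrative only, assuming the TF 1.x API used throughout this
# repo; `boxes` below is a made-up example tensor):
#
#   boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]], dtype=tf.float32)
#   boxlist = BoxList(boxes)
#   boxlist.add_field('scores', tf.constant([0.9], dtype=tf.float32))
#   ycenter, xcenter, h, w = boxlist.get_center_coordinates_and_sizes()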
37 | import tensorflow as tf
38 |
39 |
40 | class BoxList(object):
41 | """Box collection."""
42 |
43 | def __init__(self, boxes):
44 | """Constructs box collection.
45 |
46 | Args:
47 | boxes: a tensor of shape [N, 4] representing box corners
48 |
49 | Raises:
50 | ValueError: if invalid dimensions for bbox data or if bbox data is not in
51 | float32 format.
52 | """
53 | if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
54 | raise ValueError('Invalid dimensions for box data.')
55 | if boxes.dtype != tf.float32:
56 | raise ValueError('Invalid tensor type: should be tf.float32')
57 | self.data = {'boxes': boxes}
58 |
59 | def num_boxes(self):
60 | """Returns number of boxes held in collection.
61 |
62 | Returns:
63 | a tensor representing the number of boxes held in the collection.
64 | """
65 | return tf.shape(self.data['boxes'])[0]
66 |
67 | def num_boxes_static(self):
68 | """Returns number of boxes held in collection.
69 |
70 | This number is inferred at graph construction time rather than run-time.
71 |
72 | Returns:
73 | Number of boxes held in collection (integer) or None if this is not
74 | inferrable at graph construction time.
75 | """
76 | return self.data['boxes'].get_shape()[0].value
77 |
78 | def get_all_fields(self):
79 | """Returns all fields."""
80 | return self.data.keys()
81 |
82 | def get_extra_fields(self):
83 | """Returns all non-box fields (i.e., everything not named 'boxes')."""
84 | return [k for k in self.data.keys() if k != 'boxes']
85 |
86 | def add_field(self, field, field_data):
87 | """Add field to box list.
88 |
89 | This method can be used to add related box data such as
90 | weights/labels, etc.
91 |
92 | Args:
93 | field: a string key to access the data via `get`
94 | field_data: a tensor containing the data to store in the BoxList
95 | """
96 | self.data[field] = field_data
97 |
98 | def has_field(self, field):
99 | return field in self.data
100 |
101 | def get(self):
102 | """Convenience function for accessing box coordinates.
103 |
104 | Returns:
105 | a tensor with shape [N, 4] representing box coordinates.
106 | """
107 | return self.get_field('boxes')
108 |
109 | def set(self, boxes):
110 | """Convenience function for setting box coordinates.
111 |
112 | Args:
113 | boxes: a tensor of shape [N, 4] representing box corners
114 |
115 | Raises:
116 | ValueError: if invalid dimensions for bbox data
117 | """
118 | if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
119 | raise ValueError('Invalid dimensions for box data.')
120 | self.data['boxes'] = boxes
121 |
122 | def get_field(self, field):
123 | """Accesses a box collection and associated fields.
124 |
125 | This function returns the data stored under the specified field
126 | (use 'boxes' to access the box coordinates).
127 | 
128 | Args:
129 | field: a string specifying the field to be accessed; the field must
130 | have been set via the constructor or `add_field`.
131 |
132 | Returns:
133 | a tensor representing the box collection or an associated field.
134 |
135 | Raises:
136 | ValueError: if invalid field
137 | """
138 | if not self.has_field(field):
139 | raise ValueError('field ' + str(field) + ' does not exist')
140 | return self.data[field]
141 |
142 | def set_field(self, field, value):
143 | """Sets the value of a field.
144 |
145 | Updates the field of a box_list with a given value.
146 |
147 | Args:
148 | field: (string) name of the field to set value.
149 | value: the value to assign to the field.
150 |
151 | Raises:
152 | ValueError: if the box_list does not have specified field.
153 | """
154 | if not self.has_field(field):
155 | raise ValueError('field %s does not exist' % field)
156 | self.data[field] = value
157 |
158 | def get_center_coordinates_and_sizes(self, scope=None):
159 | """Computes the center coordinates, height and width of the boxes.
160 |
161 | Args:
162 | scope: name scope of the function.
163 |
164 | Returns:
165 | a list of 4 1-D tensors [ycenter, xcenter, height, width].
166 | """
167 | with tf.name_scope(scope, 'get_center_coordinates_and_sizes'):
168 | box_corners = self.get()
169 | ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners))
170 | width = xmax - xmin
171 | height = ymax - ymin
172 | ycenter = ymin + height / 2.
173 | xcenter = xmin + width / 2.
174 | return [ycenter, xcenter, height, width]
175 |
176 | def transpose_coordinates(self, scope=None):
177 | """Transpose the coordinate representation in a boxlist.
178 |
179 | Args:
180 | scope: name scope of the function.
181 | """
182 | with tf.name_scope(scope, 'transpose_coordinates'):
183 | y_min, x_min, y_max, x_max = tf.split(
184 | value=self.get(), num_or_size_splits=4, axis=1)
185 | self.set(tf.concat([x_min, y_min, x_max, y_max], 1))
186 |
187 | def as_tensor_dict(self, fields=None):
188 | """Retrieves specified fields as a dictionary of tensors.
189 |
190 | Args:
191 | fields: (optional) list of fields to return in the dictionary.
192 | If None (default), all fields are returned.
193 |
194 | Returns:
195 | tensor_dict: A dictionary of tensors specified by fields.
196 |
197 | Raises:
198 | ValueError: if specified field is not contained in boxlist.
199 | """
200 | tensor_dict = {}
201 | if fields is None:
202 | fields = self.get_all_fields()
203 | for field in fields:
204 | if not self.has_field(field):
205 | raise ValueError('boxlist must contain all specified fields')
206 | tensor_dict[field] = self.get_field(field)
207 | return tensor_dict
208 |
--------------------------------------------------------------------------------
/src/main/inference/tensorflow_detection.py:
--------------------------------------------------------------------------------
1 | import os
2 | import jsonschema
3 | import asyncio
4 | import json
5 | import numpy as np
6 | import tensorflow as tf
7 | from PIL import Image, ImageDraw, ImageFont
8 | from object_detection.utils import label_map_util
9 | from inference.base_inference_engine import AbstractInferenceEngine
10 | from inference.exceptions import InvalidModelConfiguration, InvalidInputData, ApplicationError
11 |
12 |
13 | class InferenceEngine(AbstractInferenceEngine):
14 |
15 | def __init__(self, model_path):
16 | self.label_path = ""
17 | self.NUM_CLASSES = None
18 | self.sess = None
19 | self.label_map = None
20 | self.categories = None
21 | self.category_index = None
22 | self.detection_graph = None
23 | self.image_tensor = None
24 | self.d_boxes = None
25 | self.d_scores = None
26 | self.d_classes = None
27 | self.num_d = None
28 | self.font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
29 | super().__init__(model_path)
30 |
31 | def load(self):
32 | with open(os.path.join(self.model_path, 'config.json')) as f:
33 | data = json.load(f)
34 | try:
35 | self.validate_json_configuration(data)
36 | self.set_model_configuration(data)
37 | except ApplicationError as e:
38 | raise e
39 |
40 | self.label_path = os.path.join(self.model_path, 'object-detection.pbtxt')
41 | self.label_map = label_map_util.load_labelmap(self.label_path)
42 | self.categories = label_map_util.convert_label_map_to_categories(self.label_map,
43 | max_num_classes=self.NUM_CLASSES,
44 | use_display_name=True)
45 | for category in self.categories:
46 | self.labels.append(category['name'])
47 |
48 | self.category_index = label_map_util.create_category_index(self.categories)
49 | self.detection_graph = tf.Graph()
50 | with self.detection_graph.as_default():
51 | od_graph_def = tf.GraphDef()
52 | with tf.gfile.GFile(os.path.join(self.model_path, 'frozen_inference_graph.pb'), 'rb') as fid:
53 | serialized_graph = fid.read()
54 | od_graph_def.ParseFromString(serialized_graph)
55 | tf.import_graph_def(od_graph_def, name='')
56 | self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
57 | self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
58 | self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
59 | self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
60 | self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
61 | self.sess = tf.Session(graph=self.detection_graph)
62 | img = Image.open("object_detection/image1.jpg")
63 | img_expanded = np.expand_dims(img, axis=0)
64 | self.sess.run(
65 | [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
66 | feed_dict={self.image_tensor: img_expanded})
67 |
68 | async def infer(self, input_data, draw, predict_batch):
69 | await asyncio.sleep(0.00001)  # yield control to the event loop so concurrent requests are not starved
70 | try:
71 | pillow_image = Image.open(input_data.file).convert('RGB')
72 | np_image = np.array(pillow_image)
73 | except Exception as e:
74 | raise InvalidInputData('corrupted image')
75 | try:
76 | with open(self.model_path + '/config.json') as f:
77 | data = json.load(f)
78 | except Exception as e:
79 | raise InvalidModelConfiguration('config.json not found or corrupted')
80 | json_confidence = data['confidence']
81 | json_predictions = data['predictions']
82 | with self.detection_graph.as_default():
83 | # Expand dimension since the model expects image to have shape [1, None, None, 3].
84 | img_expanded = np.expand_dims(np_image, axis=0)
85 | (boxes, scores, classes, num) = self.sess.run(
86 | [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
87 | feed_dict={self.image_tensor: img_expanded})
88 | classes_names = ([self.category_index.get(i) for i in classes[0]])
89 | names_start = []
90 | for name in classes_names:
91 | if name is not None:
92 | names_start.append(name['name'])
93 | height, width, depth = np_image.shape
94 | names = []
95 | confidence = []
96 | ids = []
97 | bounding_boxes = []
98 | for i in range(json_predictions):
99 | if scores[0][i] * 100 >= json_confidence:
100 | ymin = max(0, int(round(boxes[0][i][0] * height)))  # clamp scaled coordinates to the image bounds
101 | xmin = max(0, int(round(boxes[0][i][1] * width)))
102 | ymax = max(0, int(round(boxes[0][i][2] * height)))
103 | xmax = max(0, int(round(boxes[0][i][3] * width)))
104 | tmp = dict([('left', xmin), ('top', ymin), ('right', xmax), ('bottom', ymax)])
105 | bounding_boxes.append(tmp)
106 | confidence.append(float(scores[0][i] * 100))
107 | ids.append(int(classes[0][i]))
108 | names.append(names_start[i])
109 |
110 | responses_list = zip(names, confidence, bounding_boxes, ids)
111 |
112 | output = []
113 | for response in responses_list:
114 | tmp = dict([('ObjectClassName', response[0]), ('confidence', response[1]), ('coordinates', response[2]),
115 | ('ObjectClassId', response[3])])
116 | output.append(tmp)
117 | if predict_batch:
118 | response = dict([('bounding-boxes', output), ('ImageName', input_data.filename)])
119 | else:
120 | response = dict([('bounding-boxes', output)])
121 |
122 | if not draw:
123 | return response
124 | else:
125 | try:
126 | self.draw_image(pillow_image, response)
127 | except ApplicationError as e:
128 | raise e
129 | except Exception as e:
130 | raise e
131 |
132 | async def run_batch(self, input_data, draw, predict_batch):
133 | result_list = []
134 | for image in input_data:
135 | post_process = await self.infer(image, draw, predict_batch)
136 | if post_process is not None:
137 | result_list.append(post_process)
138 | return result_list
139 |
140 | def draw_image(self, image, response):
141 | """
142 | Draws on image and saves it.
143 | :param image: image of type pillow image
144 | :param response: inference response
145 | :return:
146 | """
147 | draw = ImageDraw.Draw(image)
148 | for bbox in response['bounding-boxes']:
149 | draw.rectangle([bbox['coordinates']['left'], bbox['coordinates']['top'], bbox['coordinates']['right'],
150 | bbox['coordinates']['bottom']], outline="red")
151 | left = bbox['coordinates']['left']
152 | top = bbox['coordinates']['top']
153 | conf = "{0:.2f}".format(bbox['confidence'])
154 | draw.text((int(left), int(top) - 20), str(conf) + "% " + str(bbox['ObjectClassName']), 'red', self.font)
155 | image.save('/main/result.jpg', 'JPEG')  # JPEG matches the .jpg file name and the image/jpg response type
156 |
157 | def free(self):
158 | pass
159 |
160 | def validate_configuration(self):
161 | # check if weights file exists
162 | if not os.path.exists(os.path.join(self.model_path, 'frozen_inference_graph.pb')):
163 | raise InvalidModelConfiguration('frozen_inference_graph.pb not found')
164 | # check if labels file exists
165 | if not os.path.exists(os.path.join(self.model_path, 'object-detection.pbtxt')):
166 | raise InvalidModelConfiguration('object-detection.pbtxt not found')
167 | return True
168 |
169 | def set_model_configuration(self, data):
170 | self.configuration['framework'] = data['framework']
171 | self.configuration['type'] = data['type']
172 | self.configuration['network'] = data['network']
173 | self.NUM_CLASSES = data['number_of_classes']
174 |
175 | def validate_json_configuration(self, data):
176 | with open(os.path.join('inference', 'ConfigurationSchema.json')) as f:
177 | schema = json.load(f)
178 | try:
179 | jsonschema.validate(data, schema)
180 | except Exception as e:
181 | raise InvalidModelConfiguration(e)
182 |
--------------------------------------------------------------------------------
/src/main/start.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import List
3 | from models import ApiResponse
4 | from inference.errors import Error
5 | from starlette.responses import FileResponse
6 | from starlette.staticfiles import StaticFiles
7 | from starlette.middleware.cors import CORSMiddleware
8 | from deep_learning_service import DeepLearningService
9 | from fastapi import FastAPI, Form, File, UploadFile, Header, HTTPException
10 | from inference.exceptions import ModelNotFound, InvalidModelConfiguration, ApplicationError, ModelNotLoaded, \
11 | InferenceEngineNotFound, InvalidInputData
12 | from ocr import ocr_service, one_shot_ocr_service
13 | from datetime import datetime
14 | import pytz
15 | from PIL import Image
16 |
17 | sys.path.append('./inference')
18 |
19 | tz = pytz.timezone("Europe/Berlin")
20 |
21 | dl_service = DeepLearningService()
22 | error_logging = Error()
23 | app = FastAPI(version='1.0', title='BMW InnovationLab tensorflow cpu inference Automation',
24 | description="API for performing tensorflow cpu inference.\n"
25 | "Contact the developers:\n"
26 | "Hadi Koubeissy: hadi.koubeissy@inmind.ai\n"
27 | "BMW Innovation Lab: innovation-lab@bmw.de")
28 |
29 |
30 | # app.mount("/public", StaticFiles(directory="/main/public"), name="public")
31 |
32 | # app.add_middleware(
33 | # CORSMiddleware,
34 | # allow_origins=["*"],
35 | # allow_credentials=True,
36 | # allow_methods=["*"],
37 | # allow_headers=["*"],
38 | # )
39 |
40 |
41 | @app.get('/load')
42 | def load_custom():
43 | """
44 | Loads all the available models.
45 | :return: All the available models with their respective hashed values
46 | """
47 | try:
48 | return dl_service.load_all_models()
49 | except ApplicationError as e:
50 | return ApiResponse(success=False, error=e)
51 | except Exception:
52 | return ApiResponse(success=False, error='unexpected server error')
53 |
54 |
55 | @app.post('/detect')
56 | async def detect_custom(model: str = Form(...), image: UploadFile = File(...)):
57 | """
58 | Performs a prediction for a specified image using one of the available models.
59 | :param model: Model name or model hash
60 | :param image: Image file
61 | :return: Model's Bounding boxes
62 | """
63 | draw_boxes = False
64 | predict_batch = False
65 | try:
66 | output = await dl_service.run_model(model, image, draw_boxes, predict_batch)
67 | error_logging.info('request successful;' + str(output))
68 | return output
69 | except ApplicationError as e:
70 | error_logging.warning(model + ';' + str(e))
71 | return ApiResponse(success=False, error=e)
72 | except Exception as e:
73 | error_logging.error(model + ' ' + str(e))
74 | return ApiResponse(success=False, error='unexpected server error')
75 |
76 |
77 | @app.post('/get_labels')
78 | def get_labels_custom(model: str = Form(...)):
79 | """
80 | Lists the model's labels with their hashed values.
81 | :param model: Model name or model hash
82 | :return: A list of the model's labels with their hashed values
83 | """
84 | return dl_service.get_labels_custom(model)
85 |
86 |
87 | @app.get('/models/{model_name}/load')
88 | async def load(model_name: str, force: bool = False):
89 | """
90 | Loads a model specified as a query parameter.
91 | :param model_name: Model name
92 | :param force: Boolean for model force reload on each call
93 | :return: APIResponse
94 | """
95 | try:
96 | dl_service.load_model(model_name, force)
97 | return ApiResponse(success=True)
98 | except ApplicationError as e:
99 | return ApiResponse(success=False, error=e)
100 |
101 |
102 | @app.get('/models')
103 | async def list_models(user_agent: str = Header(None)):
104 | """
105 | Lists all available models.
106 | :param user_agent:
107 | :return: APIResponse
108 | """
109 | return ApiResponse(data={'models': dl_service.list_models()})
110 |
111 |
112 | @app.post('/models/{model_name}/predict')
113 | async def run_model(model_name: str, input_data: UploadFile = File(...)):
114 | """
115 | Performs a prediction by giving both model name and image file.
116 | :param model_name: Model name
117 | :param input_data: An image file
118 | :return: APIResponse containing the prediction's bounding boxes
119 | """
120 | try:
121 | output = await dl_service.run_model(model_name, input_data, draw=False, predict_batch=False)
122 | error_logging.info('request successful;' + str(output))
123 | return ApiResponse(data=output)
124 | except ApplicationError as e:
125 | error_logging.warning(model_name + ';' + str(e))
126 | return ApiResponse(success=False, error=e)
127 | except Exception as e:
128 | error_logging.error(model_name + ' ' + str(e))
129 | return ApiResponse(success=False, error='unexpected server error')
130 |
131 |
132 | @app.post('/models/{model_name}/predict_batch', include_in_schema=False)
133 | async def run_model_batch(model_name: str, input_data: List[UploadFile] = File(...)):
134 | """
135 | Performs a prediction by giving both model name and image file(s).
136 | :param model_name: Model name
137 | :param input_data: A batch of image files or a single image file
138 | :return: APIResponse containing prediction(s) bounding boxes
139 | """
140 | try:
141 | output = await dl_service.run_model(model_name, input_data, draw=False, predict_batch=True)
142 | error_logging.info('request successful;' + str(output))
143 | return ApiResponse(data=output)
144 | except ApplicationError as e:
145 | error_logging.warning(model_name + ';' + str(e))
146 | return ApiResponse(success=False, error=e)
147 | except Exception as e:
148 | print(e)
149 | error_logging.error(model_name + ' ' + str(e))
150 | return ApiResponse(success=False, error='unexpected server error')
151 |
152 |
153 | @app.post('/models/{model_name}/predict_image')
154 | async def predict_image(model_name: str, input_data: UploadFile = File(...)):
155 | """
156 | Draws bounding box(es) on image and returns it.
157 | :param model_name: Model name
158 | :param input_data: Image file
159 | :return: Image file
160 | """
161 | try:
162 | output = await dl_service.run_model(model_name, input_data, draw=True, predict_batch=False)
163 | error_logging.info('request successful;' + str(output))
164 | return FileResponse("/main/result.jpg", media_type="image/jpg")
165 | except ApplicationError as e:
166 | error_logging.warning(model_name + ';' + str(e))
167 | return ApiResponse(success=False, error=e)
168 | except Exception as e:
169 | error_logging.error(model_name + ' ' + str(e))
170 | return ApiResponse(success=False, error='unexpected server error')
171 |
172 |
173 | @app.get('/models/{model_name}/labels')
174 | async def list_model_labels(model_name: str):
175 | """
176 | Lists all the model's labels.
177 | :param model_name: Model name
178 | :return: List of model's labels
179 | """
180 | labels = dl_service.get_labels(model_name)
181 | return ApiResponse(data=labels)
182 |
183 |
184 | @app.get('/models/{model_name}/config')
185 | async def list_model_config(model_name: str):
186 | """
187 | Lists all the model's configuration.
188 | :param model_name: Model name
189 | :return: List of model's configuration
190 | """
191 | config = dl_service.get_config(model_name)
192 | return ApiResponse(data=config)
193 |
194 |
195 | @app.post('/models/{model_name}/one_shot_ocr')
196 | async def one_shot_ocr(
197 | model_name: str,
198 | image: UploadFile = File(
199 | ..., description="Image to perform optical character recognition based on layout inference:")
200 | ):
201 | """
202 | Takes an image and returns extracted text details.
203 |
204 | First, a detection model is used to crop regions of interest from the uploaded image. These regions are then passed to the OCR service for text extraction.
205 |
206 | :param model_name: Model name or model hash for layout detection
207 |
208 | :param image: Image file
209 |
210 | :return: Text fields with the detected text inside
211 |
212 | """
213 | output = None
214 | # call detection on the image with the chosen model
215 | try:
216 | output = await run_model(model_name, image)
217 | except Exception:
218 | raise HTTPException(status_code=404, detail='Invalid Model')
219 |
220 | # run ocr_service
221 | response = None
222 | try:
223 | image = Image.open(image.file).convert('RGB')
224 | response = one_shot_ocr_service(image, output.data)
225 | except Exception:
226 | raise HTTPException(
227 | status_code=500, detail='Unexpected Error during Inference (Determination of Texts)')
228 |
229 | if not response:
230 | raise HTTPException(
231 | status_code=400, detail='Inference (Determination of Texts) is not Possible with the Specified Model')
232 |
233 | return response
234 |
235 |
236 | @app.post('/models/{model_name}/ocr')
237 | async def optical_character_recognition(
238 | model_name: str,
239 | image: UploadFile = File(
240 | ..., description="Image to perform optical character recognition based on layout inference:"),
241 | ):
242 | """
243 | Takes an image and returns the extracted text information.
244 |
245 | The image is passed to the OCR-Service for text extraction
246 |
247 | :param model_name: Model name or model hash
248 |
249 | :param image: Image file
250 |
251 | :return: Text fields with the detected text inside
252 |
253 | """
254 | # run ocr_service
255 | response = None
256 | try:
257 | image = Image.open(image.file).convert('RGB')
258 | response = ocr_service(image)
259 | except Exception:
260 | raise HTTPException(
261 | status_code=500, detail='Unexpected Error during Inference (Determination of Texts)')
262 |
263 | if not response:
264 | raise HTTPException(
265 | status_code=400, detail='Inference (Determination of Texts) is not Possible with the Specified Model')
266 |
267 | return response
268 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensorflow CPU Inference API For Windows and Linux
2 | This is a repository for an object detection inference API using the Tensorflow framework.
3 |
4 | This repo is based on [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection).
5 |
6 | The Tensorflow version used is 1.13.1. The inference REST API runs on CPU and doesn't require any GPU. It's supported on both Windows and Linux operating systems.
7 |
8 | Models trained using our training tensorflow repository can be deployed in this API. Several object detection models can be loaded and used at the same time.
9 | This repo also offers optical character recognition services to extract textboxes from images.
10 |
11 | This repo can be deployed using either **docker** or **docker swarm**.
12 |
13 | Please use **docker swarm** only if you need to:
14 |
15 | * Provide redundancy in terms of API containers: if a container goes down, incoming requests are redirected to another running instance.
16 | 
17 | * Coordinate between the containers: Swarm orchestrates the API replicas and routes each incoming request to one of them.
18 | 
19 | * Scale up the inference service to get faster predictions, especially when there is heavy traffic on the service.
20 |
21 | If none of the aforementioned requirements are needed, simply use **docker**.
22 |
23 | 
24 |
25 | ## Prerequisites
26 |
27 | - OS:
28 | - Ubuntu 16.04/18.04
29 | - Windows 10 pro/enterprise
30 | - Docker
31 |
32 | ### Check for prerequisites
33 |
34 | To check if you have docker-ce installed:
35 |
36 | ```sh
37 | docker --version
38 | ```
39 |
40 | ### Install prerequisites
41 |
42 | #### Ubuntu
43 |
44 | Use the following command to install docker on Ubuntu:
45 |
46 | ```sh
47 | chmod +x install_prerequisites.sh && source install_prerequisites.sh
48 | ```
49 |
50 | #### Windows 10
51 |
52 | To [install Docker on Windows](https://docs.docker.com/docker-for-windows/install/), please follow the link.
53 |
54 | **P.S: For Windows users, open the Docker Desktop menu by clicking the Docker icon in the notifications area. Select Settings, and then the Advanced tab to adjust the resources available to Docker Engine.**
55 |
56 | ## Build The Docker Image
57 |
58 | In order to build the project run the following command from the project's root directory:
59 |
60 | ```sh
61 | sudo docker build -t tensorflow_inference_api_cpu -f docker/dockerfile .
62 | ```
63 |
64 | ### Behind a proxy
65 |
66 | ```sh
67 | sudo docker build --build-arg http_proxy='' --build-arg https_proxy='' -t tensorflow_inference_api_cpu -f ./docker/dockerfile .
68 | ```
69 |
70 | ## Run the docker container
71 |
72 | As mentioned before, this container can be deployed using either **docker** or **docker swarm**.
73 |
74 | If you wish to deploy this API using **docker**, please issue the following run command.
75 |
76 | If you wish to deploy this API using **docker swarm**, please refer to following link [docker swarm documentation](./README-docker_swarm.md). After deploying the API with docker swarm, please consider returning to this documentation for further information about the API endpoints as well as the model structure sections.
77 |
78 | To run the API, go to the API's directory and run the following:
79 |
80 | #### Using Linux based docker:
81 |
82 | ```sh
83 | sudo docker run -itv $(pwd)/models:/models -v $(pwd)/models_hash:/models_hash -p <docker_host_port>:4343 tensorflow_inference_api_cpu
84 | ```
85 |
86 | #### Using Windows based docker:
87 |
88 | ```sh
89 | docker run -itv ${PWD}/models:/models -v ${PWD}/models_hash:/models_hash -p <docker_host_port>:4343 tensorflow_inference_api_cpu
90 | ```
91 |
92 | The <docker_host_port> can be any unique port of your choice.
93 | 
94 | The API will start automatically, and the service will listen for HTTP requests on the chosen port.
95 |
96 | ## API Endpoints
97 |
98 | To see all available endpoints, open your favorite browser and navigate to:
99 |
100 | ```
101 | http://<machine_IP>:<docker_host_port>/docs
102 | ```
103 | The 'predict_batch' endpoint is not shown in Swagger, since the list-of-files input is not yet supported there.
104 |
105 | **P.S: If you are using custom endpoints like /load, /detect, and /get_labels, you should always use the /load endpoint first and then use /detect or /get_labels**
106 |
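For example, a minimal sketch of this ordering with `curl` (the host, port, and model name below are placeholders to adapt; the form field names match the endpoint signatures in `src/main/start.py`):

```sh
# 1) Load all available models first (also returns each model's hash)
curl -X GET "http://<machine_IP>:<docker_host_port>/load"

# 2) Then run detection with a chosen model on a local image
curl -X POST "http://<machine_IP>:<docker_host_port>/detect" \
     -F "model=<model_name>" \
     -F "image=@/path/to/image.jpg"
```
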
107 | ### Endpoints summary
108 |
109 | #### /load (GET)
110 |
111 | Loads all available models and returns every model with its hashed value. Loaded models are cached and aren't loaded again.
112 |
113 | 
114 |
115 | #### /detect (POST)
116 |
117 | Performs inference with the specified model on an image and returns the bounding boxes
118 |
119 | 
120 |
121 | #### /get_labels (POST)
122 |
123 | Returns all of the specified model's labels with their hashed values
124 |
125 | 
126 |
127 | #### /models/{model_name}/predict_image (POST)
128 |
129 | Performs inference with the specified model, draws the bounding boxes on the image, and returns the resulting image as the response
130 |
131 | 
132 |
133 | #### /models (GET)
134 |
135 | Lists all available models
136 |
137 | #### /models/{model_name}/load (GET)
138 |
139 | Loads the specified model. Loaded models are cached and aren't loaded again.
140 |
141 | #### /models/{model_name}/predict (POST)
142 |
143 | Performs inference with the specified model on an image and returns the bounding boxes.
144 |
145 | #### /models/{model_name}/labels (GET)
146 |
147 | Returns all of the specified model's labels
148 |
149 | #### /models/{model_name}/config (GET)
150 |
151 | Returns the specified model's configuration
152 |
153 | #### /models/{model_name}/predict_batch (POST)
154 |
155 | Performs inference with the specified model on a list of images and returns the bounding boxes
156 |
157 | #### /models/{model_name}/one_shot_ocr (POST)
158 |
159 | Takes an image and returns extracted text details. First, a detection model is used to crop regions of interest from the uploaded image; these regions are then passed to the OCR service for text extraction.
160 |
161 | #### /models/{model_name}/ocr (POST)
162 |
163 | 
164 |
165 | Takes an image and returns extracted text details without using an object detection model
166 |
167 | **P.S:** Custom endpoints like /load, /detect, /get_labels and /one_shot_ocr should be used in chronological order: first call /load, then call /detect, /get_labels or /one_shot_ocr, as in the sketch below.
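
A minimal sketch of the OCR flow (same placeholder conventions as above):

```sh
curl -X GET "http://<machine_IP>:<docker_host_port>/load"
curl -X POST "http://<machine_IP>:<docker_host_port>/models/<model_name>/one_shot_ocr" \
     -F "image=@/path/to/image.jpg"
```
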
168 | ## Model structure
169 |
170 | The folder "models" contains subfolders of all the models to be loaded.
171 | Inside each subfolder there should be a:
172 |
173 | - pb file (frozen_inference_graph.pb): contains the model weights
174 |
175 | - pbtxt file (object-detection.pbtxt): contains model classes
176 |
177 | - config.json: a JSON file containing information about the model, as shown below
178 |
179 | ```json
180 | {
181 | "inference_engine_name": "tensorflow_detection",
182 | "confidence": 60,
183 | "predictions": 15,
184 | "number_of_classes": 2,
185 | "framework": "tensorflow",
186 | "type": "detection",
187 | "network": "inception"
188 | }
189 | ```
190 | P.S:
191 | - You can change the "confidence" and "predictions" values while the API is running
192 | - The API only returns bounding boxes whose confidence is higher than the "confidence" value; a higher "confidence" limits the response to more accurate predictions
193 | - The "predictions" value specifies the maximum number of bounding boxes in the API response
194 |
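Putting it together, a model subfolder would look like this (the folder name `my_model` is only a placeholder):

```
models/
└── my_model/
    ├── frozen_inference_graph.pb
    ├── object-detection.pbtxt
    └── config.json
```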
195 |
196 | ## Benchmarking
197 |
198 | | Network\Hardware | Intel Xeon CPU 2.3 GHz (Windows) | Intel Xeon CPU 2.3 GHz (Ubuntu) | Intel Xeon CPU 3.60 GHz (Ubuntu) | GeForce GTX 1080 (Ubuntu) |
199 | |------------------|----------------------------------|---------------------------------|----------------------------------|---------------------------|
200 | | ssd_fpn | 0.867 seconds/image | 1.016 seconds/image | 0.434 seconds/image | 0.0658 seconds/image |
201 | | frcnn_resnet_50 | 4.029 seconds/image | 4.219 seconds/image | 1.994 seconds/image | 0.148 seconds/image |
202 | | ssd_mobilenet | 0.055 seconds/image | 0.106 seconds/image | 0.051 seconds/image | 0.052 seconds/image |
203 | | frcnn_resnet_101 | 4.469 seconds/image | 4.985 seconds/image | 2.254 seconds/image | 0.364 seconds/image |
204 | | ssd_resnet_50 | 1.34 seconds/image | 1.462 seconds/image | 0.668 seconds/image | 0.091 seconds/image |
205 | | ssd_inception | 0.094 seconds/image | 0.15 seconds/image | 0.074 seconds/image | 0.0513 seconds/image |
259 |
260 |
261 | ## Acknowledgment
262 |
263 | [inmind.ai](https://inmind.ai)
264 |
265 | [robotron.de](https://robotron.de)
266 |
267 | Joe Sleiman, inmind.ai, Beirut, Lebanon
268 |
269 | Antoine Charbel, inmind.ai, Beirut, Lebanon
270 |
271 | [Anis Ismail](https://www.linkedin.com/in/anisdismail), Beirut, Lebanon
272 |
--------------------------------------------------------------------------------
/src/main/object_detection/core/standard_fields.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Contains classes specifying naming conventions used for object detection.
17 |
18 |
19 | Specifies:
20 | InputDataFields: standard fields used by reader/preprocessor/batcher.
21 | DetectionResultFields: standard fields returned by object detector.
22 | BoxListFields: standard field used by BoxList
23 | TfExampleFields: standard fields for tf-example data format (go/tf-example).
24 | """
25 |
26 |
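# Usage sketch (comment only; the attribute below is defined later in this file):
#
#   from object_detection.core import standard_fields as fields
#   box_key = fields.DetectionResultFields.detection_boxes   # == 'detection_boxes'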
27 | class InputDataFields(object):
28 | """Names for the input tensors.
29 |
30 | Holds the standard data field names to use for identifying input tensors. This
31 | should be used by the decoder to identify keys for the returned tensor_dict
32 | containing input tensors. And it should be used by the model to identify the
33 | tensors it needs.
34 |
35 | Attributes:
36 | image: image.
37 | image_additional_channels: additional channels.
38 | original_image: image in the original input size.
39 | key: unique key corresponding to image.
40 | source_id: source of the original image.
41 | filename: original filename of the dataset (without common path).
42 | groundtruth_image_classes: image-level class labels.
43 | groundtruth_boxes: coordinates of the ground truth boxes in the image.
44 | groundtruth_classes: box-level class labels.
45 | groundtruth_label_types: box-level label types (e.g. explicit negative).
46 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead]
47 | is the groundtruth a single object or a crowd.
48 | groundtruth_area: area of a groundtruth segment.
49 | groundtruth_difficult: is a `difficult` object
50 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the
51 | same class, forming a connected group, where instances are heavily
52 | occluding each other.
53 | proposal_boxes: coordinates of object proposal boxes.
54 | proposal_objectness: objectness score of each proposal.
55 | groundtruth_instance_masks: ground truth instance masks.
56 | groundtruth_instance_boundaries: ground truth instance boundaries.
57 | groundtruth_instance_classes: instance mask-level class labels.
58 | groundtruth_keypoints: ground truth keypoints.
59 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
60 | groundtruth_label_scores: groundtruth label scores.
61 | groundtruth_weights: groundtruth weight factor for bounding boxes.
62 | num_groundtruth_boxes: number of groundtruth boxes.
63 | true_image_shape: true shape of each image within the resized images, as
64 | resized images can be padded with zeros.
65 | verified_labels: list of human-verified image-level labels (note, that a
66 | label can be verified both as positive and negative).
67 | multiclass_scores: the label score per class for each box.
68 | """
69 | image = 'image'
70 | image_additional_channels = 'image_additional_channels'
71 | original_image = 'original_image'
72 | key = 'key'
73 | source_id = 'source_id'
74 | filename = 'filename'
75 | groundtruth_image_classes = 'groundtruth_image_classes'
76 | groundtruth_boxes = 'groundtruth_boxes'
77 | groundtruth_classes = 'groundtruth_classes'
78 | groundtruth_label_types = 'groundtruth_label_types'
79 | groundtruth_is_crowd = 'groundtruth_is_crowd'
80 | groundtruth_area = 'groundtruth_area'
81 | groundtruth_difficult = 'groundtruth_difficult'
82 | groundtruth_group_of = 'groundtruth_group_of'
83 | proposal_boxes = 'proposal_boxes'
84 | proposal_objectness = 'proposal_objectness'
85 | groundtruth_instance_masks = 'groundtruth_instance_masks'
86 | groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
87 | groundtruth_instance_classes = 'groundtruth_instance_classes'
88 | groundtruth_keypoints = 'groundtruth_keypoints'
89 | groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
90 | groundtruth_label_scores = 'groundtruth_label_scores'
91 | groundtruth_weights = 'groundtruth_weights'
92 | num_groundtruth_boxes = 'num_groundtruth_boxes'
93 | true_image_shape = 'true_image_shape'
94 | verified_labels = 'verified_labels'
95 | multiclass_scores = 'multiclass_scores'
96 |
97 |
98 | class DetectionResultFields(object):
99 | """Naming conventions for storing the output of the detector.
100 |
101 | Attributes:
102 | source_id: source of the original image.
103 | key: unique key corresponding to image.
104 | detection_boxes: coordinates of the detection boxes in the image.
105 | detection_scores: detection scores for the detection boxes in the image.
106 | detection_classes: detection-level class labels.
107 | detection_masks: contains a segmentation mask for each detection box.
108 | detection_boundaries: contains an object boundary for each detection box.
109 | detection_keypoints: contains detection keypoints for each detection box.
110 | num_detections: number of detections in the batch.
111 | """
112 |
113 | source_id = 'source_id'
114 | key = 'key'
115 | detection_boxes = 'detection_boxes'
116 | detection_scores = 'detection_scores'
117 | detection_classes = 'detection_classes'
118 | detection_masks = 'detection_masks'
119 | detection_boundaries = 'detection_boundaries'
120 | detection_keypoints = 'detection_keypoints'
121 | num_detections = 'num_detections'
122 |
123 |
124 | class BoxListFields(object):
125 | """Naming conventions for BoxLists.
126 |
127 | Attributes:
128 | boxes: bounding box coordinates.
129 | classes: classes per bounding box.
130 | scores: scores per bounding box.
131 | weights: sample weights per bounding box.
132 | objectness: objectness score per bounding box.
133 | masks: masks per bounding box.
134 | boundaries: boundaries per bounding box.
135 | keypoints: keypoints per bounding box.
136 | keypoint_heatmaps: keypoint heatmaps per bounding box.
137 | is_crowd: is_crowd annotation per bounding box.
138 | """
139 | boxes = 'boxes'
140 | classes = 'classes'
141 | scores = 'scores'
142 | weights = 'weights'
143 | objectness = 'objectness'
144 | masks = 'masks'
145 | boundaries = 'boundaries'
146 | keypoints = 'keypoints'
147 | keypoint_heatmaps = 'keypoint_heatmaps'
148 | is_crowd = 'is_crowd'
149 |
150 |
151 | class TfExampleFields(object):
152 | """TF-example proto feature names for object detection.
153 |
154 | Holds the standard feature names to load from an Example proto for object
155 | detection.
156 |
157 | Attributes:
158 | image_encoded: JPEG encoded string
159 | image_format: image format, e.g. "JPEG"
160 | filename: filename
161 | channels: number of channels of image
162 | colorspace: colorspace, e.g. "RGB"
163 | height: height of image in pixels, e.g. 462
164 | width: width of image in pixels, e.g. 581
165 | source_id: original source of the image
166 | image_class_text: image-level label in text format
167 | image_class_label: image-level label in numerical format
168 | object_class_text: labels in text format, e.g. ["person", "cat"]
169 | object_class_label: labels in numbers, e.g. [16, 8]
170 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30
171 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40
172 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50
173 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70
174 | object_view: viewpoint of object, e.g. ["frontal", "left"]
175 | object_truncated: is object truncated, e.g. [true, false]
176 | object_occluded: is object occluded, e.g. [true, false]
177 | object_difficult: is object difficult, e.g. [true, false]
178 | object_group_of: is object a single object or a group of objects
179 | object_depiction: is object a depiction
180 | object_is_crowd: [DEPRECATED, use object_group_of instead]
181 | is the object a single object or a crowd
182 | object_segment_area: the area of the segment.
183 | object_weight: a weight factor for the object's bounding box.
184 | instance_masks: instance segmentation masks.
185 | instance_boundaries: instance boundaries.
186 | instance_classes: Classes for each instance segmentation mask.
187 | detection_class_label: class label in numbers.
188 | detection_bbox_ymin: ymin coordinates of a detection box.
189 | detection_bbox_xmin: xmin coordinates of a detection box.
190 | detection_bbox_ymax: ymax coordinates of a detection box.
191 | detection_bbox_xmax: xmax coordinates of a detection box.
192 | detection_score: detection score for the class label and box.
193 | """
194 | image_encoded = 'image/encoded'
195 | image_format = 'image/format' # format is reserved keyword
196 | filename = 'image/filename'
197 | channels = 'image/channels'
198 | colorspace = 'image/colorspace'
199 | height = 'image/height'
200 | width = 'image/width'
201 | source_id = 'image/source_id'
202 | image_class_text = 'image/class/text'
203 | image_class_label = 'image/class/label'
204 | object_class_text = 'image/object/class/text'
205 | object_class_label = 'image/object/class/label'
206 | object_bbox_ymin = 'image/object/bbox/ymin'
207 | object_bbox_xmin = 'image/object/bbox/xmin'
208 | object_bbox_ymax = 'image/object/bbox/ymax'
209 | object_bbox_xmax = 'image/object/bbox/xmax'
210 | object_view = 'image/object/view'
211 | object_truncated = 'image/object/truncated'
212 | object_occluded = 'image/object/occluded'
213 | object_difficult = 'image/object/difficult'
214 | object_group_of = 'image/object/group_of'
215 | object_depiction = 'image/object/depiction'
216 | object_is_crowd = 'image/object/is_crowd'
217 | object_segment_area = 'image/object/segment/area'
218 | object_weight = 'image/object/weight'
219 | instance_masks = 'image/segmentation/object'
220 | instance_boundaries = 'image/boundaries/object'
221 | instance_classes = 'image/segmentation/object/class'
222 | detection_class_label = 'image/detection/label'
223 | detection_bbox_ymin = 'image/detection/bbox/ymin'
224 | detection_bbox_xmin = 'image/detection/bbox/xmin'
225 | detection_bbox_ymax = 'image/detection/bbox/ymax'
226 | detection_bbox_xmax = 'image/detection/bbox/xmax'
227 | detection_score = 'image/detection/score'
228 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 BMW Group
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/main/object_detection/utils/shape_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils used to manipulate tensor shapes."""
17 |
18 | import tensorflow as tf
19 |
20 | from object_detection.utils import static_shape
21 |
22 |
23 | def _is_tensor(t):
24 | """Returns a boolean indicating whether the input is a tensor.
25 |
26 | Args:
27 | t: the input to be tested.
28 |
29 | Returns:
30 | a boolean that indicates whether t is a tensor.
31 | """
32 | return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable))
33 |
34 |
35 | def _set_dim_0(t, d0):
36 | """Sets the 0-th dimension of the input tensor.
37 |
38 | Args:
39 | t: the input tensor, assuming the rank is at least 1.
40 | d0: an integer indicating the 0-th dimension of the input tensor.
41 |
42 | Returns:
43 | the tensor t with the 0-th dimension set.
44 | """
45 | t_shape = t.get_shape().as_list()
46 | t_shape[0] = d0
47 | t.set_shape(t_shape)
48 | return t
49 |
50 |
51 | def pad_tensor(t, length):
52 | """Pads the input tensor with 0s along the first dimension up to the length.
53 |
54 | Args:
55 | t: the input tensor, assuming the rank is at least 1.
56 | length: a tensor of shape [1] or an integer, indicating the first dimension
57 | of the input tensor t after padding, assuming length >= t.shape[0].
58 |
59 | Returns:
60 | padded_t: the padded tensor, whose first dimension is length. If the length
61 | is an integer, the first dimension of padded_t is set to length
62 | statically.
63 | """
64 | t_rank = tf.rank(t)
65 | t_shape = tf.shape(t)
66 | t_d0 = t_shape[0]
67 | pad_d0 = tf.expand_dims(length - t_d0, 0)
68 | pad_shape = tf.cond(
69 | tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
70 | lambda: tf.expand_dims(length - t_d0, 0))
71 | padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
72 | if not _is_tensor(length):
73 | padded_t = _set_dim_0(padded_t, length)
74 | return padded_t
75 |
76 |
77 | def clip_tensor(t, length):
78 | """Clips the input tensor along the first dimension up to the length.
79 |
80 | Args:
81 | t: the input tensor, assuming the rank is at least 1.
82 | length: a tensor of shape [1] or an integer, indicating the first dimension
83 | of the input tensor t after clipping, assuming length <= t.shape[0].
84 |
85 | Returns:
86 | clipped_t: the clipped tensor, whose first dimension is length. If the
87 | length is an integer, the first dimension of clipped_t is set to length
88 | statically.
89 | """
90 | clipped_t = tf.gather(t, tf.range(length))
91 | if not _is_tensor(length):
92 | clipped_t = _set_dim_0(clipped_t, length)
93 | return clipped_t
94 |
95 |
96 | def pad_or_clip_tensor(t, length):
97 | """Pad or clip the input tensor along the first dimension.
98 |
99 | Args:
100 | t: the input tensor, assuming the rank is at least 1.
101 | length: a tensor of shape [1] or an integer, indicating the first dimension
102 | of the input tensor t after processing.
103 |
104 | Returns:
105 | processed_t: the processed tensor, whose first dimension is length. If the
106 | length is an integer, the first dimension of the processed tensor is set
107 | to length statically.
108 | """
109 | processed_t = tf.cond(
110 | tf.greater(tf.shape(t)[0], length),
111 | lambda: clip_tensor(t, length),
112 | lambda: pad_tensor(t, length))
113 | if not _is_tensor(length):
114 | processed_t = _set_dim_0(processed_t, length)
115 | return processed_t
116 |
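# Editor's note: a minimal, commented-out usage sketch for the three padding/
# clipping helpers above (not part of the upstream file); values illustrative.
#
#   t = tf.constant([[1., 2.], [3., 4.], [5., 6.]])   # leading dim = 3
#   pad_tensor(t, 5)          # zero-padded to static shape (5, 2)
#   clip_tensor(t, 2)         # truncated to static shape (2, 2)
#   pad_or_clip_tensor(t, 4)  # 3 < 4, so this pads -> shape (4, 2)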
117 |
118 | def combined_static_and_dynamic_shape(tensor):
119 | """Returns a list containing static and dynamic values for the dimensions.
120 |
121 | Returns a list of static and dynamic values for shape dimensions. This is
122 | useful to preserve static shapes, when available, in reshape operations.
123 |
124 | Args:
125 | tensor: A tensor of any type.
126 |
127 | Returns:
128 | A list of size tensor.shape.ndims containing integers or a scalar tensor.
129 | """
130 | static_tensor_shape = tensor.shape.as_list()
131 | dynamic_tensor_shape = tf.shape(tensor)
132 | combined_shape = []
133 | for index, dim in enumerate(static_tensor_shape):
134 | if dim is not None:
135 | combined_shape.append(dim)
136 | else:
137 | combined_shape.append(dynamic_tensor_shape[index])
138 | return combined_shape
139 |
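# Editor's note: a commented-out sketch of the intended use (not part of the
# upstream file): mixing static ints with dynamic scalars in a reshape.
#
#   x = tf.placeholder(tf.float32, [None, 4, 8])
#   n, h, w = combined_static_and_dynamic_shape(x)  # [<Tensor>, 4, 8]
#   flat = tf.reshape(x, [n, h * w])  # last dim stays statically 32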
140 |
141 | def static_or_dynamic_map_fn(fn, elems, dtype=None,
142 | parallel_iterations=32, back_prop=True):
143 | """Runs map_fn as a (static) for loop when possible.
144 |
145 | This function rewrites the map_fn as an explicit unstack input -> for loop
146 | over function calls -> stack result combination. This allows our graphs to
147 | be acyclic when the batch size is static.
148 | For comparison, see https://www.tensorflow.org/api_docs/python/tf/map_fn.
149 |
150 | Note that `static_or_dynamic_map_fn` currently is not *fully* interchangeable
151 | with the default tf.map_fn function as it does not accept nested inputs (only
152 | Tensors or lists of Tensors). Likewise, the output of `fn` can only be a
153 | Tensor or list of Tensors.
154 |
155 | TODO(jonathanhuang): make this function fully interchangeable with tf.map_fn.
156 |
157 | Args:
158 | fn: The callable to be performed. It accepts one argument, which will have
159 | the same structure as elems. Its output must have the
160 | same structure as elems.
161 | elems: A tensor or list of tensors, each of which will
162 | be unpacked along their first dimension. The sequence of the
163 | resulting slices will be applied to fn.
164 | dtype: (optional) The output type(s) of fn. If fn returns a structure of
165 | Tensors differing from the structure of elems, then dtype is not optional
166 | and must have the same structure as the output of fn.
167 | parallel_iterations: (optional) number of batch items to process in
168 | parallel. This flag is only used if the native tf.map_fn is used;
169 | it defaults to 32 here, rather than tf.map_fn's standard default of 10.
170 | back_prop: (optional) True enables support for back propagation.
171 | This flag is only used if the native tf.map_fn is used.
172 |
173 | Returns:
174 | A tensor or sequence of tensors. Each tensor packs the
175 | results of applying fn to tensors unpacked from elems along the first
176 | dimension, from first to last.
177 | Raises:
178 | ValueError: if `elems` is not a Tensor or a list of Tensors.
179 | ValueError: if `fn` does not return a Tensor or list of Tensors
180 | """
181 | if isinstance(elems, list):
182 | for elem in elems:
183 | if not isinstance(elem, tf.Tensor):
184 | raise ValueError('`elems` must be a Tensor or list of Tensors.')
185 |
186 | elem_shapes = [elem.shape.as_list() for elem in elems]
187 | # Fall back on tf.map_fn if shapes of each entry of `elems` are None or fail
188 | # to all be the same size along the batch dimension.
189 | for elem_shape in elem_shapes:
190 | if (not elem_shape or not elem_shape[0]
191 | or elem_shape[0] != elem_shapes[0][0]):
192 | return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
193 | arg_tuples = zip(*[tf.unstack(elem) for elem in elems])
194 | outputs = [fn(arg_tuple) for arg_tuple in arg_tuples]
195 | else:
196 | if not isinstance(elems, tf.Tensor):
197 | raise ValueError('`elems` must be a Tensor or list of Tensors.')
198 | elems_shape = elems.shape.as_list()
199 | if not elems_shape or not elems_shape[0]:
200 | return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
201 | outputs = [fn(arg) for arg in tf.unstack(elems)]
202 | # Stack `outputs`, which is a list of Tensors or list of lists of Tensors
203 | if all([isinstance(output, tf.Tensor) for output in outputs]):
204 | return tf.stack(outputs)
205 | else:
206 | if all([isinstance(output, list) for output in outputs]):
207 | if all([all(
208 | [isinstance(entry, tf.Tensor) for entry in output_list])
209 | for output_list in outputs]):
210 | return [tf.stack(output_tuple) for output_tuple in zip(*outputs)]
211 | raise ValueError('`fn` should return a Tensor or a list of Tensors.')
212 |
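# Editor's note: an illustrative, commented-out call (not part of the upstream
# file), assuming a statically known batch size.
#
#   boxes = tf.ones([8, 4])
#   scaled = static_or_dynamic_map_fn(lambda b: 2.0 * b, boxes)
#   # With the static batch of 8 this unrolls into eight fn calls plus a
#   # tf.stack; with an unknown leading dimension it falls back to tf.map_fn.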
213 |
214 | def check_min_image_dim(min_dim, image_tensor):
215 | """Checks that the image width/height are greater than some number.
216 |
217 | This function is used to check that the width and height of an image are above
218 | a certain value. If the image shape is static, this function will perform the
219 | check at graph construction time. Otherwise, if the image shape varies, an
220 | Assertion control dependency will be added to the graph.
221 |
222 | Args:
223 | min_dim: The minimum number of pixels along the width and height of the
224 | image.
225 | image_tensor: The image tensor to check size for.
226 |
227 | Returns:
228 | If `image_tensor` has dynamic size, return `image_tensor` with an Assert
229 | control dependency. Otherwise returns image_tensor.
230 |
231 | Raises:
232 | ValueError: if `image_tensor`'s width or height is smaller than `min_dim`.
233 | """
234 | image_shape = image_tensor.get_shape()
235 | image_height = static_shape.get_height(image_shape)
236 | image_width = static_shape.get_width(image_shape)
237 | if image_height is None or image_width is None:
238 | shape_assert = tf.Assert(
239 | tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim),
240 | tf.greater_equal(tf.shape(image_tensor)[2], min_dim)),
241 | ['image size must be >= {} in both height and width.'.format(min_dim)])
242 | with tf.control_dependencies([shape_assert]):
243 | return tf.identity(image_tensor)
244 |
245 | if image_height < min_dim or image_width < min_dim:
246 | raise ValueError(
247 | 'image size must be >= %d in both height and width; image dim = %d,%d' %
248 | (min_dim, image_height, image_width))
249 |
250 | return image_tensor
251 |
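# Editor's note: a brief commented-out sketch (not part of the upstream file)
# of the two behaviors described in the docstring.
#
#   images = tf.placeholder(tf.float32, [None, None, None, 3])
#   checked = check_min_image_dim(33, images)  # dynamic -> runtime Assert
#   check_min_image_dim(33, tf.zeros([1, 16, 16, 3]))  # static -> ValueError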
252 |
253 | def assert_shape_equal(shape_a, shape_b):
254 | """Asserts that shape_a and shape_b are equal.
255 |
256 | If the shapes are static, raises a ValueError when the shapes
257 | mismatch.
258 |
259 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
260 | mismatch.
261 |
262 | Args:
263 | shape_a: a list containing shape of the first tensor.
264 | shape_b: a list containing shape of the second tensor.
265 |
266 | Returns:
267 | Either a tf.no_op() when shapes are all static, or a tf.assert_equal() op
268 | when the shapes are dynamic.
269 |
270 | Raises:
271 | ValueError: When shapes are both static and unequal.
272 | """
273 | if (all(isinstance(dim, int) for dim in shape_a) and
274 | all(isinstance(dim, int) for dim in shape_b)):
275 | if shape_a != shape_b:
276 | raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
277 | else: return tf.no_op()
278 | else:
279 | return tf.assert_equal(shape_a, shape_b)
280 |
281 |
282 | def assert_shape_equal_along_first_dimension(shape_a, shape_b):
283 | """Asserts that shape_a and shape_b are the same along the 0th-dimension.
284 |
285 | If the shapes are static, raises a ValueError when the shapes
286 | mismatch.
287 |
288 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
289 | mismatch.
290 |
291 | Args:
292 | shape_a: a list containing shape of the first tensor.
293 | shape_b: a list containing shape of the second tensor.
294 |
295 | Returns:
296 | Either a tf.no_op() when shapes are all static, or a tf.assert_equal() op
297 | when the shapes are dynamic.
298 |
299 | Raises:
300 | ValueError: When shapes are both static and unequal.
301 | """
302 | if isinstance(shape_a[0], int) and isinstance(shape_b[0], int):
303 | if shape_a[0] != shape_b[0]:
304 | raise ValueError('Unequal first dimension {}, {}'.format(
305 | shape_a[0], shape_b[0]))
306 | else: return tf.no_op()
307 | else:
308 | return tf.assert_equal(shape_a[0], shape_b[0])
309 |
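# Editor's note: a commented-out sketch (not part of the upstream file) of how
# these assertion helpers are typically wired into a graph.
#
#   a = tf.placeholder(tf.float32, [None, 4])
#   b = tf.placeholder(tf.float32, [None, 4])
#   check = assert_shape_equal(combined_static_and_dynamic_shape(a),
#                              combined_static_and_dynamic_shape(b))
#   with tf.control_dependencies([check]):
#       total = a + b  # evaluated only if the runtime shapes match
#   # With fully static, equal shapes the helper reduces to tf.no_op().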
310 |
--------------------------------------------------------------------------------
/src/main/object_detection/utils/visualization_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A set of functions that are used for visualization.
17 |
18 | These functions typically receive an image and draw some visualization on it.
19 | Most of them do not return a value; instead, they modify the image in place.
20 |
21 | """
22 | import collections
23 | import functools
24 | # Set headless-friendly backend.
25 | import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements
26 | import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
27 | import numpy as np
28 | import PIL.Image as Image
29 | import PIL.ImageColor as ImageColor
30 | import PIL.ImageDraw as ImageDraw
31 | import PIL.ImageFont as ImageFont
32 | import six
33 | import tensorflow as tf
34 |
35 | from object_detection.core import standard_fields as fields
36 |
37 |
38 | _TITLE_LEFT_MARGIN = 10
39 | _TITLE_TOP_MARGIN = 10
40 | STANDARD_COLORS = [
41 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
42 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
43 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
44 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
45 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
46 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
47 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
48 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
49 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
50 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
51 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
52 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
53 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
54 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
55 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
56 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
57 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
58 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
59 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
60 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
61 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
62 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
63 | 'WhiteSmoke', 'Yellow', 'YellowGreen'
64 | ]
65 |
66 |
67 | def save_image_array_as_png(image, output_path):
68 | """Saves an image (represented as a numpy array) to PNG.
69 |
70 | Args:
71 | image: a numpy array with shape [height, width, 3].
72 | output_path: path to which image should be written.
73 | """
74 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
75 | with tf.gfile.Open(output_path, 'w') as fid:
76 | image_pil.save(fid, 'PNG')
77 |
78 |
79 | def encode_image_array_as_png_str(image):
80 | """Encodes a numpy array into a PNG string.
81 |
82 | Args:
83 | image: a numpy array with shape [height, width, 3].
84 |
85 | Returns:
86 | PNG encoded image string.
87 | """
88 | image_pil = Image.fromarray(np.uint8(image))
89 | output = six.BytesIO()
90 | image_pil.save(output, format='PNG')
91 | png_string = output.getvalue()
92 | output.close()
93 | return png_string
94 |
95 |
96 | def draw_bounding_box_on_image_array(image,
97 | ymin,
98 | xmin,
99 | ymax,
100 | xmax,
101 | color='red',
102 | thickness=4,
103 | display_str_list=(),
104 | use_normalized_coordinates=True):
105 | """Adds a bounding box to an image (numpy array).
106 |
107 | Bounding box coordinates can be specified in either absolute (pixel) or
108 | normalized coordinates by setting the use_normalized_coordinates argument.
109 |
110 | Args:
111 | image: a numpy array with shape [height, width, 3].
112 | ymin: ymin of bounding box.
113 | xmin: xmin of bounding box.
114 | ymax: ymax of bounding box.
115 | xmax: xmax of bounding box.
116 | color: color to draw bounding box. Default is red.
117 | thickness: line thickness. Default value is 4.
118 | display_str_list: list of strings to display in box
119 | (each to be shown on its own line).
120 | use_normalized_coordinates: If True (default), treat coordinates
121 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
122 | coordinates as absolute.
123 | """
124 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
125 | draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
126 | thickness, display_str_list,
127 | use_normalized_coordinates)
128 | np.copyto(image, np.array(image_pil))
129 |
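# Editor's note: a minimal, commented-out example call (not part of the
# upstream file); coordinates and label are illustrative.
#
#   img = np.zeros((480, 640, 3), dtype=np.uint8)
#   draw_bounding_box_on_image_array(
#       img, 0.25, 0.25, 0.75, 0.75, color='LimeGreen',
#       display_str_list=['person: 87%'])  # modifies `img` in place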
130 |
131 | def draw_bounding_box_on_image(image,
132 | ymin,
133 | xmin,
134 | ymax,
135 | xmax,
136 | color='red',
137 | thickness=4,
138 | display_str_list=(),
139 | use_normalized_coordinates=True):
140 | """Adds a bounding box to an image.
141 |
142 | Bounding box coordinates can be specified in either absolute (pixel) or
143 | normalized coordinates by setting the use_normalized_coordinates argument.
144 |
145 | Each string in display_str_list is displayed on a separate line above the
146 | bounding box in black text on a rectangle filled with the input 'color'.
147 | If the top of the bounding box extends to the edge of the image, the strings
148 | are displayed below the bounding box.
149 |
150 | Args:
151 | image: a PIL.Image object.
152 | ymin: ymin of bounding box.
153 | xmin: xmin of bounding box.
154 | ymax: ymax of bounding box.
155 | xmax: xmax of bounding box.
156 | color: color to draw bounding box. Default is red.
157 | thickness: line thickness. Default value is 4.
158 | display_str_list: list of strings to display in box
159 | (each to be shown on its own line).
160 | use_normalized_coordinates: If True (default), treat coordinates
161 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
162 | coordinates as absolute.
163 | """
164 | draw = ImageDraw.Draw(image)
165 | im_width, im_height = image.size
166 | if use_normalized_coordinates:
167 | (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
168 | ymin * im_height, ymax * im_height)
169 | else:
170 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
171 | draw.line([(left, top), (left, bottom), (right, bottom),
172 | (right, top), (left, top)], width=thickness, fill=color)
173 | try:
174 | font = ImageFont.truetype('arial.ttf', 24)
175 | except IOError:
176 | font = ImageFont.load_default()
177 |
178 | # If the total height of the display strings added to the top of the bounding
179 | # box exceeds the top of the image, stack the strings below the bounding box
180 | # instead of above.
181 | display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]
182 | # Each display_str has a top and bottom margin of 0.05x.
183 | total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
184 |
185 | if top > total_display_str_height:
186 | text_bottom = top
187 | else:
188 | text_bottom = bottom + total_display_str_height
189 | # Reverse list and print from bottom to top.
190 | for display_str in display_str_list[::-1]:
191 | text_width, text_height = font.getsize(display_str)
192 | margin = np.ceil(0.05 * text_height)
193 | draw.rectangle(
194 | [(left, text_bottom - text_height - 2 * margin), (left + text_width,
195 | text_bottom)],
196 | fill=color)
197 | draw.text(
198 | (left + margin, text_bottom - text_height - margin),
199 | display_str,
200 | fill='black',
201 | font=font)
202 | text_bottom -= text_height + 2 * margin  # step up by the full label-box height
203 |
204 |
205 | def draw_bounding_boxes_on_image_array(image,
206 | boxes,
207 | color='red',
208 | thickness=4,
209 | display_str_list_list=()):
210 | """Draws bounding boxes on image (numpy array).
211 |
212 | Args:
213 | image: a numpy array object.
214 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
215 | The coordinates are in normalized format between [0, 1].
216 | color: color to draw bounding box. Default is red.
217 | thickness: line thickness. Default value is 4.
218 | display_str_list_list: list of list of strings.
219 | a list of strings for each bounding box.
220 | The reason to pass a list of strings for a
221 | bounding box is that it might contain
222 | multiple labels.
223 |
224 | Raises:
225 | ValueError: if boxes is not a [N, 4] array
226 | """
227 | image_pil = Image.fromarray(image)
228 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness,
229 | display_str_list_list)
230 | np.copyto(image, np.array(image_pil))
231 |
232 |
233 | def draw_bounding_boxes_on_image(image,
234 | boxes,
235 | color='red',
236 | thickness=4,
237 | display_str_list_list=()):
238 | """Draws bounding boxes on image.
239 |
240 | Args:
241 | image: a PIL.Image object.
242 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
243 | The coordinates are in normalized format between [0, 1].
244 | color: color to draw bounding box. Default is red.
245 | thickness: line thickness. Default value is 4.
246 | display_str_list_list: list of list of strings.
247 | a list of strings for each bounding box.
248 | The reason to pass a list of strings for a
249 | bounding box is that it might contain
250 | multiple labels.
251 |
252 | Raises:
253 | ValueError: if boxes is not a [N, 4] array
254 | """
255 | boxes_shape = boxes.shape
256 | if not boxes_shape:
257 | return
258 | if len(boxes_shape) != 2 or boxes_shape[1] != 4:
259 | raise ValueError('Input must be of size [N, 4]')
260 | for i in range(boxes_shape[0]):
261 | display_str_list = ()
262 | if display_str_list_list:
263 | display_str_list = display_str_list_list[i]
264 | draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2],
265 | boxes[i, 3], color, thickness, display_str_list)
266 |
267 |
268 | def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs):
269 | return visualize_boxes_and_labels_on_image_array(
270 | image, boxes, classes, scores, category_index=category_index, **kwargs)
271 |
272 |
273 | def _visualize_boxes_and_masks(image, boxes, classes, scores, masks,
274 | category_index, **kwargs):
275 | return visualize_boxes_and_labels_on_image_array(
276 | image,
277 | boxes,
278 | classes,
279 | scores,
280 | category_index=category_index,
281 | instance_masks=masks,
282 | **kwargs)
283 |
284 |
285 | def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
286 | category_index, **kwargs):
287 | return visualize_boxes_and_labels_on_image_array(
288 | image,
289 | boxes,
290 | classes,
291 | scores,
292 | category_index=category_index,
293 | keypoints=keypoints,
294 | **kwargs)
295 |
296 |
297 | def _visualize_boxes_and_masks_and_keypoints(
298 | image, boxes, classes, scores, masks, keypoints, category_index, **kwargs):
299 | return visualize_boxes_and_labels_on_image_array(
300 | image,
301 | boxes,
302 | classes,
303 | scores,
304 | category_index=category_index,
305 | instance_masks=masks,
306 | keypoints=keypoints,
307 | **kwargs)
308 |
309 |
310 | def draw_bounding_boxes_on_image_tensors(images,
311 | boxes,
312 | classes,
313 | scores,
314 | category_index,
315 | instance_masks=None,
316 | keypoints=None,
317 | max_boxes_to_draw=20,
318 | min_score_thresh=0.2,
319 | use_normalized_coordinates=True):
320 | """Draws bounding boxes, masks, and keypoints on batch of image tensors.
321 |
322 | Args:
323 | images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
324 | channels will be ignored.
325 | boxes: [N, max_detections, 4] float32 tensor of detection boxes.
326 | classes: [N, max_detections] int tensor of detection classes. Note that
327 | classes are 1-indexed.
328 | scores: [N, max_detections] float32 tensor of detection scores.
329 | category_index: a dict that maps integer ids to category dicts. e.g.
330 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
331 | instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with
332 | instance masks.
333 | keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2]
334 | with keypoints.
335 | max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
336 | min_score_thresh: Minimum score threshold for visualization. Default 0.2.
337 | use_normalized_coordinates: Whether to assume boxes and keypoints are in
338 | normalized coordinates (as opposed to absolute coordinates).
339 | Default is True.
340 |
341 | Returns:
342 | 4D image tensor of type uint8, with boxes drawn on top.
343 | """
344 | # Additional channels are being ignored.
345 | images = images[:, :, :, 0:3]
346 | visualization_keyword_args = {
347 | 'use_normalized_coordinates': use_normalized_coordinates,
348 | 'max_boxes_to_draw': max_boxes_to_draw,
349 | 'min_score_thresh': min_score_thresh,
350 | 'agnostic_mode': False,
351 | 'line_thickness': 4
352 | }
353 |
354 | if instance_masks is not None and keypoints is None:
355 | visualize_boxes_fn = functools.partial(
356 | _visualize_boxes_and_masks,
357 | category_index=category_index,
358 | **visualization_keyword_args)
359 | elems = [images, boxes, classes, scores, instance_masks]
360 | elif instance_masks is None and keypoints is not None:
361 | visualize_boxes_fn = functools.partial(
362 | _visualize_boxes_and_keypoints,
363 | category_index=category_index,
364 | **visualization_keyword_args)
365 | elems = [images, boxes, classes, scores, keypoints]
366 | elif instance_masks is not None and keypoints is not None:
367 | visualize_boxes_fn = functools.partial(
368 | _visualize_boxes_and_masks_and_keypoints,
369 | category_index=category_index,
370 | **visualization_keyword_args)
371 | elems = [images, boxes, classes, scores, instance_masks, keypoints]
372 | else:
373 | visualize_boxes_fn = functools.partial(
374 | _visualize_boxes,
375 | category_index=category_index,
376 | **visualization_keyword_args)
377 | elems = [images, boxes, classes, scores]
378 |
379 | def draw_boxes(image_and_detections):
380 | """Draws boxes on image."""
381 | image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections,
382 | tf.uint8)
383 | return image_with_boxes
384 |
385 | images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False)
386 | return images
387 |
388 |
389 | def draw_side_by_side_evaluation_image(eval_dict,
390 | category_index,
391 | max_boxes_to_draw=20,
392 | min_score_thresh=0.2,
393 | use_normalized_coordinates=True):
394 | """Creates a side-by-side image with detections and groundtruth.
395 |
396 | Bounding boxes (and instance masks, if available) are visualized on both
397 | subimages.
398 |
399 | Args:
400 | eval_dict: The evaluation dictionary returned by
401 | eval_util.result_dict_for_single_example().
402 | category_index: A category index (dictionary) produced from a labelmap.
403 | max_boxes_to_draw: The maximum number of boxes to draw for detections.
404 | min_score_thresh: The minimum score threshold for showing detections.
405 | use_normalized_coordinates: Whether to assume boxes and keypoints are in
406 | normalized coordinates (as opposed to absolute coordinates).
407 | Default is True.
408 |
409 | Returns:
410 | A [1, H, 2 * W, C] uint8 tensor. The subimage on the left corresponds to
411 | detections, while the subimage on the right corresponds to groundtruth.
412 | """
413 | detection_fields = fields.DetectionResultFields()
414 | input_data_fields = fields.InputDataFields()
415 | instance_masks = None
416 | if detection_fields.detection_masks in eval_dict:
417 | instance_masks = tf.cast(
418 | tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0),
419 | tf.uint8)
420 | keypoints = None
421 | if detection_fields.detection_keypoints in eval_dict:
422 | keypoints = tf.expand_dims(
423 | eval_dict[detection_fields.detection_keypoints], axis=0)
424 | groundtruth_instance_masks = None
425 | if input_data_fields.groundtruth_instance_masks in eval_dict:
426 | groundtruth_instance_masks = tf.cast(
427 | tf.expand_dims(
428 | eval_dict[input_data_fields.groundtruth_instance_masks], axis=0),
429 | tf.uint8)
430 | images_with_detections = draw_bounding_boxes_on_image_tensors(
431 | eval_dict[input_data_fields.original_image],
432 | tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0),
433 | tf.expand_dims(eval_dict[detection_fields.detection_classes], axis=0),
434 | tf.expand_dims(eval_dict[detection_fields.detection_scores], axis=0),
435 | category_index,
436 | instance_masks=instance_masks,
437 | keypoints=keypoints,
438 | max_boxes_to_draw=max_boxes_to_draw,
439 | min_score_thresh=min_score_thresh,
440 | use_normalized_coordinates=use_normalized_coordinates)
441 | images_with_groundtruth = draw_bounding_boxes_on_image_tensors(
442 | eval_dict[input_data_fields.original_image],
443 | tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0),
444 | tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0),
445 | tf.expand_dims(
446 | tf.ones_like(
447 | eval_dict[input_data_fields.groundtruth_classes],
448 | dtype=tf.float32),
449 | axis=0),
450 | category_index,
451 | instance_masks=groundtruth_instance_masks,
452 | keypoints=None,
453 | max_boxes_to_draw=None,
454 | min_score_thresh=0.0,
455 | use_normalized_coordinates=use_normalized_coordinates)
456 | return tf.concat([images_with_detections, images_with_groundtruth], axis=2)
457 |
458 |
459 | def draw_keypoints_on_image_array(image,
460 | keypoints,
461 | color='red',
462 | radius=2,
463 | use_normalized_coordinates=True):
464 | """Draws keypoints on an image (numpy array).
465 |
466 | Args:
467 | image: a numpy array with shape [height, width, 3].
468 | keypoints: a numpy array with shape [num_keypoints, 2].
469 | color: color to draw the keypoints with. Default is red.
470 | radius: keypoint radius. Default value is 2.
471 | use_normalized_coordinates: if True (default), treat keypoint values as
472 | relative to the image. Otherwise treat them as absolute.
473 | """
474 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
475 | draw_keypoints_on_image(image_pil, keypoints, color, radius,
476 | use_normalized_coordinates)
477 | np.copyto(image, np.array(image_pil))
478 |
479 |
480 | def draw_keypoints_on_image(image,
481 | keypoints,
482 | color='red',
483 | radius=2,
484 | use_normalized_coordinates=True):
485 | """Draws keypoints on an image.
486 |
487 | Args:
488 | image: a PIL.Image object.
489 | keypoints: a numpy array with shape [num_keypoints, 2].
490 | color: color to draw the keypoints with. Default is red.
491 | radius: keypoint radius. Default value is 2.
492 | use_normalized_coordinates: if True (default), treat keypoint values as
493 | relative to the image. Otherwise treat them as absolute.
494 | """
495 | draw = ImageDraw.Draw(image)
496 | im_width, im_height = image.size
497 | keypoints_x = [k[1] for k in keypoints]
498 | keypoints_y = [k[0] for k in keypoints]
499 | if use_normalized_coordinates:
500 | keypoints_x = tuple([im_width * x for x in keypoints_x])
501 | keypoints_y = tuple([im_height * y for y in keypoints_y])
502 | for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
503 | draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
504 | (keypoint_x + radius, keypoint_y + radius)],
505 | outline=color, fill=color)
506 |
507 |
508 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
509 | """Draws mask on an image.
510 |
511 | Args:
512 | image: uint8 numpy array with shape (img_height, img_width, 3)
513 | mask: a uint8 numpy array of shape (img_height, img_width) with
514 | values of either 0 or 1.
515 | color: color to draw the mask with. Default is red.
516 | alpha: transparency value between 0 and 1. (default: 0.4)
517 |
518 | Raises:
519 | ValueError: On incorrect data type for image or masks.
520 | """
521 | if image.dtype != np.uint8:
522 | raise ValueError('`image` not of type np.uint8')
523 | if mask.dtype != np.uint8:
524 | raise ValueError('`mask` not of type np.uint8')
525 | if np.any(np.logical_and(mask != 1, mask != 0)):
526 | raise ValueError('`mask` elements should be in [0, 1]')
527 | if image.shape[:2] != mask.shape:
528 | raise ValueError('The image has spatial dimensions %s but the mask has '
529 | 'dimensions %s' % (image.shape[:2], mask.shape))
530 | rgb = ImageColor.getrgb(color)
531 | pil_image = Image.fromarray(image)
532 |
533 | solid_color = np.expand_dims(
534 | np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
535 | pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
536 | pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
537 | pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
538 | np.copyto(image, np.array(pil_image.convert('RGB')))
539 |
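# Editor's note: an illustrative, commented-out call (not part of the upstream
# file); the mask here is a simple square.
#
#   img = np.zeros((64, 64, 3), dtype=np.uint8)
#   mask = np.zeros((64, 64), dtype=np.uint8)
#   mask[16:48, 16:48] = 1
#   draw_mask_on_image_array(img, mask, color='Blue', alpha=0.4)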
540 |
541 | def visualize_boxes_and_labels_on_image_array(
542 | image,
543 | boxes,
544 | classes,
545 | scores,
546 | category_index,
547 | instance_masks=None,
548 | instance_boundaries=None,
549 | keypoints=None,
550 | use_normalized_coordinates=False,
551 | max_boxes_to_draw=20,
552 | min_score_thresh=.5,
553 | agnostic_mode=False,
554 | line_thickness=4,
555 | groundtruth_box_visualization_color='black',
556 | skip_scores=False,
557 | skip_labels=False):
558 | """Overlay labeled boxes on an image with formatted scores and label names.
559 |
560 | This function groups boxes that correspond to the same location
561 | and creates a display string for each detection, then overlays these
562 | on the image. Note that this function modifies the image in place, and returns
563 | that same image.
564 |
565 | Args:
566 | image: uint8 numpy array with shape (img_height, img_width, 3)
567 | boxes: a numpy array of shape [N, 4]
568 | classes: a numpy array of shape [N]. Note that class indices are 1-based,
569 | and match the keys in the label map.
570 | scores: a numpy array of shape [N] or None. If scores=None, then
571 | this function assumes that the boxes to be plotted are groundtruth
572 | boxes and plots all boxes as black with no classes or scores.
573 | category_index: a dict containing category dictionaries (each holding
574 | category index `id` and category name `name`) keyed by category indices.
575 | instance_masks: a numpy array of shape [N, image_height, image_width] with
576 | values ranging between 0 and 1, can be None.
577 | instance_boundaries: a numpy array of shape [N, image_height, image_width]
578 | with values ranging between 0 and 1, can be None.
579 | keypoints: a numpy array of shape [N, num_keypoints, 2], can
580 | be None
581 | use_normalized_coordinates: whether boxes is to be interpreted as
582 | normalized coordinates or not.
583 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw
584 | all boxes.
585 | min_score_thresh: minimum score threshold for a box to be visualized
586 | agnostic_mode: boolean (default: False) controlling whether to evaluate in
587 | class-agnostic mode or not. This mode will display scores but ignore
588 | classes.
589 | line_thickness: integer (default: 4) controlling line width of the boxes.
590 | groundtruth_box_visualization_color: box color for visualizing groundtruth
591 | boxes
592 | skip_scores: whether to skip score when drawing a single detection
593 | skip_labels: whether to skip label when drawing a single detection
594 |
595 | Returns:
596 | uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
597 | """
598 | # Create a display string (and color) for every box location, group any boxes
599 | # that correspond to the same location.
600 | box_to_display_str_map = collections.defaultdict(list)
601 | box_to_color_map = collections.defaultdict(str)
602 | box_to_instance_masks_map = {}
603 | box_to_instance_boundaries_map = {}
604 | box_to_keypoints_map = collections.defaultdict(list)
605 | if not max_boxes_to_draw:
606 | max_boxes_to_draw = boxes.shape[0]
607 | for i in range(min(max_boxes_to_draw, boxes.shape[0])):
608 | if scores is None or scores[i] > min_score_thresh:
609 | box = tuple(boxes[i].tolist())
610 | if instance_masks is not None:
611 | box_to_instance_masks_map[box] = instance_masks[i]
612 | if instance_boundaries is not None:
613 | box_to_instance_boundaries_map[box] = instance_boundaries[i]
614 | if keypoints is not None:
615 | box_to_keypoints_map[box].extend(keypoints[i])
616 | if scores is None:
617 | box_to_color_map[box] = groundtruth_box_visualization_color
618 | else:
619 | display_str = ''
620 | if not skip_labels:
621 | if not agnostic_mode:
622 | if classes[i] in category_index.keys():
623 | class_name = category_index[classes[i]]['name']
624 | else:
625 | class_name = 'N/A'
626 | display_str = str(class_name)
627 | if not skip_scores:
628 | if not display_str:
629 | display_str = '{}%'.format(int(100*scores[i]))
630 | else:
631 | display_str = '{}: {}%'.format(display_str, int(100*scores[i]))
632 | box_to_display_str_map[box].append(display_str)
633 | if agnostic_mode:
634 | box_to_color_map[box] = 'DarkOrange'
635 | else:
636 | box_to_color_map[box] = STANDARD_COLORS[
637 | classes[i] % len(STANDARD_COLORS)]
638 |
639 | # Draw all boxes onto image.
640 | for box, color in box_to_color_map.items():
641 | ymin, xmin, ymax, xmax = box
642 | if instance_masks is not None:
643 | draw_mask_on_image_array(
644 | image,
645 | box_to_instance_masks_map[box],
646 | color=color
647 | )
648 | if instance_boundaries is not None:
649 | draw_mask_on_image_array(
650 | image,
651 | box_to_instance_boundaries_map[box],
652 | color='red',
653 | alpha=1.0
654 | )
655 | draw_bounding_box_on_image_array(
656 | image,
657 | ymin,
658 | xmin,
659 | ymax,
660 | xmax,
661 | color=color,
662 | thickness=line_thickness,
663 | display_str_list=box_to_display_str_map[box],
664 | use_normalized_coordinates=use_normalized_coordinates)
665 | if keypoints is not None:
666 | draw_keypoints_on_image_array(
667 | image,
668 | box_to_keypoints_map[box],
669 | color=color,
670 | radius=line_thickness / 2,
671 | use_normalized_coordinates=use_normalized_coordinates)
672 |
673 | return image
674 |
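# Editor's note: a commented-out sketch of the typical single-image call (not
# part of the upstream file); `img` is assumed to be a uint8 HxWx3 numpy array
# and the category_index entries are illustrative.
#
#   category_index = {1: {'id': 1, 'name': 'dog'}}
#   boxes = np.array([[0.1, 0.1, 0.6, 0.6]])
#   classes = np.array([1], dtype=np.int32)
#   scores = np.array([0.92])
#   visualize_boxes_and_labels_on_image_array(
#       img, boxes, classes, scores, category_index,
#       use_normalized_coordinates=True, line_thickness=2)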
675 |
676 | def add_cdf_image_summary(values, name):
677 | """Adds a tf.summary.image for a CDF plot of the values.
678 |
679 | Normalizes `values` such that they sum to 1, plots the cumulative distribution
680 | function and creates a tf image summary.
681 |
682 | Args:
683 | values: a 1-D float32 tensor containing the values.
684 | name: name for the image summary.
685 | """
686 | def cdf_plot(values):
687 | """Numpy function to plot CDF."""
688 | normalized_values = values / np.sum(values)
689 | sorted_values = np.sort(normalized_values)
690 | cumulative_values = np.cumsum(sorted_values)
691 | fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32)
692 | / cumulative_values.size)
693 | fig = plt.figure(frameon=False)
694 | ax = fig.add_subplot('111')
695 | ax.plot(fraction_of_examples, cumulative_values)
696 | ax.set_ylabel('cumulative normalized values')
697 | ax.set_xlabel('fraction of examples')
698 | fig.canvas.draw()
699 | width, height = fig.get_size_inches() * fig.get_dpi()
700 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape(
701 | 1, int(height), int(width), 3)
702 | return image
703 | cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8)
704 | tf.summary.image(name, cdf_plot)
705 |
706 |
707 | def add_hist_image_summary(values, bins, name):
708 | """Adds a tf.summary.image for a histogram plot of the values.
709 |
710 | Plots the histogram of values and creates a tf image summary.
711 |
712 | Args:
713 | values: a 1-D float32 tensor containing the values.
714 | bins: bin edges which will be directly passed to np.histogram.
715 | name: name for the image summary.
716 | """
717 |
718 | def hist_plot(values, bins):
719 | """Numpy function to plot hist."""
720 | fig = plt.figure(frameon=False)
721 | ax = fig.add_subplot('111')
722 | y, x = np.histogram(values, bins=bins)
723 | ax.plot(x[:-1], y)
724 | ax.set_ylabel('count')
725 | ax.set_xlabel('value')
726 | fig.canvas.draw()
727 | width, height = fig.get_size_inches() * fig.get_dpi()
728 | image = np.fromstring(
729 | fig.canvas.tostring_rgb(), dtype='uint8').reshape(
730 | 1, int(height), int(width), 3)
731 | return image
732 | hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8)
733 | tf.summary.image(name, hist_plot)
734 |
--------------------------------------------------------------------------------
/src/main/object_detection/utils/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A module for helper tensorflow ops."""
17 | import math
18 | import numpy as np
19 | import six
20 |
21 | import tensorflow as tf
22 |
23 | from object_detection.core import box_list
24 | from object_detection.core import box_list_ops
25 | from object_detection.core import standard_fields as fields
26 | from object_detection.utils import shape_utils
27 | from object_detection.utils import static_shape
28 |
29 |
30 | def expanded_shape(orig_shape, start_dim, num_dims):
31 | """Inserts multiple ones into a shape vector.
32 |
33 | Inserts an all-1 vector of length num_dims at position start_dim into a shape.
34 | Can be combined with tf.reshape to generalize tf.expand_dims.
35 |
36 | Args:
37 | orig_shape: the shape into which the all-1 vector is added (int32 vector)
38 | start_dim: insertion position (int scalar)
39 | num_dims: length of the inserted all-1 vector (int scalar)
40 | Returns:
41 | An int32 vector of length tf.size(orig_shape) + num_dims.
42 | """
43 | with tf.name_scope('ExpandedShape'):
44 | start_dim = tf.expand_dims(start_dim, 0) # scalar to rank-1
45 | before = tf.slice(orig_shape, [0], start_dim)
46 | add_shape = tf.ones(tf.reshape(num_dims, [1]), dtype=tf.int32)
47 | after = tf.slice(orig_shape, start_dim, [-1])
48 | new_shape = tf.concat([before, add_shape, after], 0)
49 | return new_shape
50 |
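# Editor's note: a one-line commented-out example (not part of the upstream
# file) of the insertion described above.
#
#   expanded_shape(tf.constant([2, 3]), 1, 2)  # -> [2, 1, 1, 3]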
51 |
52 | def normalized_to_image_coordinates(normalized_boxes, image_shape,
53 | parallel_iterations=32):
54 | """Converts a batch of boxes from normal to image coordinates.
55 |
56 | Args:
57 | normalized_boxes: a float32 tensor of shape [None, num_boxes, 4] in
58 | normalized coordinates.
59 | image_shape: a float32 tensor of shape [4] containing the image shape.
60 | parallel_iterations: parallelism for the map_fn op.
61 |
62 | Returns:
63 | absolute_boxes: a float32 tensor of shape [None, num_boxes, 4] containing the
64 | boxes in image coordinates.
65 | """
66 | def _to_absolute_coordinates(normalized_boxes):
67 | return box_list_ops.to_absolute_coordinates(
68 | box_list.BoxList(normalized_boxes),
69 | image_shape[1], image_shape[2], check_range=False).get()
70 |
71 | absolute_boxes = shape_utils.static_or_dynamic_map_fn(
72 | _to_absolute_coordinates,
73 | elems=(normalized_boxes),
74 | dtype=tf.float32,
75 | parallel_iterations=parallel_iterations,
76 | back_prop=True)
77 | return absolute_boxes
78 |
79 |
80 | def meshgrid(x, y):
81 | """Tiles the contents of x and y into a pair of grids.
82 |
83 | Multidimensional analog of numpy.meshgrid, giving the same behavior if x and y
84 | are vectors. Generally, this will give:
85 |
86 | xgrid(i_1, ..., i_m, j_1, ..., j_n) = x(j_1, ..., j_n)
87 | ygrid(i_1, ..., i_m, j_1, ..., j_n) = y(i_1, ..., i_m)
88 |
89 | Keep in mind that the order of the arguments and outputs is reverse relative
90 | to the order of the indices they go into, done for compatibility with numpy.
91 | The output tensors have the same shapes. Specifically:
92 |
93 | xgrid.get_shape() = y.get_shape().concatenate(x.get_shape())
94 | ygrid.get_shape() = y.get_shape().concatenate(x.get_shape())
95 |
96 | Args:
97 | x: A tensor of arbitrary shape and rank. xgrid will contain these values
98 | varying in its last dimensions.
99 | y: A tensor of arbitrary shape and rank. ygrid will contain these values
100 | varying in its first dimensions.
101 | Returns:
102 | A tuple of tensors (xgrid, ygrid).
103 | """
104 | with tf.name_scope('Meshgrid'):
105 | x = tf.convert_to_tensor(x)
106 | y = tf.convert_to_tensor(y)
107 | x_exp_shape = expanded_shape(tf.shape(x), 0, tf.rank(y))
108 | y_exp_shape = expanded_shape(tf.shape(y), tf.rank(y), tf.rank(x))
109 |
110 | xgrid = tf.tile(tf.reshape(x, x_exp_shape), y_exp_shape)
111 | ygrid = tf.tile(tf.reshape(y, y_exp_shape), x_exp_shape)
112 | new_shape = y.get_shape().concatenate(x.get_shape())
113 | xgrid.set_shape(new_shape)
114 | ygrid.set_shape(new_shape)
115 |
116 | return xgrid, ygrid
117 |
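# Editor's note: a small commented-out example (not part of the upstream file)
# with 1-D inputs, matching numpy.meshgrid behavior.
#
#   xgrid, ygrid = meshgrid(tf.constant([1., 2., 3.]), tf.constant([10., 20.]))
#   # xgrid -> [[1., 2., 3.], [1., 2., 3.]]          (shape [2, 3])
#   # ygrid -> [[10., 10., 10.], [20., 20., 20.]]    (shape [2, 3])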
118 |
119 | def fixed_padding(inputs, kernel_size, rate=1):
120 | """Pads the input along the spatial dimensions independently of input size.
121 |
122 | Args:
123 | inputs: A tensor of size [batch, height_in, width_in, channels].
124 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
125 | Should be a positive integer.
126 | rate: An integer, rate for atrous convolution.
127 |
128 | Returns:
129 | output: A tensor of size [batch, height_out, width_out, channels] with the
130 | input, either intact (if kernel_size == 1) or padded (if kernel_size > 1).
131 | """
132 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
133 | pad_total = kernel_size_effective - 1
134 | pad_beg = pad_total // 2
135 | pad_end = pad_total - pad_beg
136 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
137 | [pad_beg, pad_end], [0, 0]])
138 | return padded_inputs
139 |
140 |
141 | def pad_to_multiple(tensor, multiple):
142 | """Returns the tensor zero padded to the specified multiple.
143 |
144 | Appends 0s to the end of the first and second dimension (height and width) of
145 | the tensor until both dimensions are a multiple of the input argument
146 | 'multiple'. E.g. given an input tensor of shape [1, 3, 5, 1] and an input
147 | multiple of 4, PadToMultiple will append 0s so that the resulting tensor will
148 | be of shape [1, 4, 8, 1].
149 |
150 | Args:
151 | tensor: rank 4 float32 tensor, where
152 | tensor -> [batch_size, height, width, channels].
153 | multiple: the multiple to pad to.
154 |
155 | Returns:
156 | padded_tensor: the tensor zero padded to the specified multiple.
157 | """
158 | tensor_shape = tensor.get_shape()
159 | batch_size = static_shape.get_batch_size(tensor_shape)
160 | tensor_height = static_shape.get_height(tensor_shape)
161 | tensor_width = static_shape.get_width(tensor_shape)
162 | tensor_depth = static_shape.get_depth(tensor_shape)
163 |
164 | if batch_size is None:
165 | batch_size = tf.shape(tensor)[0]
166 |
167 | if tensor_height is None:
168 | tensor_height = tf.shape(tensor)[1]
169 | padded_tensor_height = tf.to_int32(
170 | tf.ceil(tf.to_float(tensor_height) / tf.to_float(multiple))) * multiple
171 | else:
172 | padded_tensor_height = int(
173 | math.ceil(float(tensor_height) / multiple) * multiple)
174 |
175 | if tensor_width is None:
176 | tensor_width = tf.shape(tensor)[2]
177 | padded_tensor_width = tf.to_int32(
178 | tf.ceil(tf.to_float(tensor_width) / tf.to_float(multiple))) * multiple
179 | else:
180 | padded_tensor_width = int(
181 | math.ceil(float(tensor_width) / multiple) * multiple)
182 |
183 | if tensor_depth is None:
184 | tensor_depth = tf.shape(tensor)[3]
185 |
186 | # Use tf.concat instead of tf.pad to preserve static shape
187 | if padded_tensor_height != tensor_height:
188 | height_pad = tf.zeros([
189 | batch_size, padded_tensor_height - tensor_height, tensor_width,
190 | tensor_depth
191 | ])
192 | tensor = tf.concat([tensor, height_pad], 1)
193 | if padded_tensor_width != tensor_width:
194 | width_pad = tf.zeros([
195 | batch_size, padded_tensor_height, padded_tensor_width - tensor_width,
196 | tensor_depth
197 | ])
198 | tensor = tf.concat([tensor, width_pad], 2)
199 |
200 | return tensor
201 |
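# Editor's note: a commented-out restatement (not part of the upstream file)
# of the docstring's own example.
#
#   pad_to_multiple(tf.zeros([1, 3, 5, 1]), 4)  # -> shape [1, 4, 8, 1]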
202 |
203 | def padded_one_hot_encoding(indices, depth, left_pad):
204 | """Returns a zero padded one-hot tensor.
205 |
206 | This function converts a sparse representation of indices (e.g., [3]) to a
207 | zero padded one-hot representation (e.g., [0, 0, 0, 0, 1] with depth = 4 and
208 | left_pad = 1). If `indices` is empty, the result will simply be a tensor of
209 | shape (0, depth + left_pad). If depth = 0, then this function just returns
210 | `None`.
211 |
212 | Args:
213 | indices: an integer tensor of shape [num_indices].
214 | depth: depth for the one-hot tensor (integer).
215 | left_pad: number of zeros to left pad the one-hot tensor with (integer).
216 |
217 | Returns:
218 | padded_onehot: a tensor with shape (num_indices, depth + left_pad). Returns
219 | `None` if the depth is zero.
220 |
221 | Raises:
222 | ValueError: if `indices` does not have rank 1, or if `left_pad` or `depth`
223 | is either negative or a non-integer.
224 |
225 | TODO(rathodv): add runtime checks for depth and indices.
226 | """
227 | if depth < 0 or not isinstance(depth, six.integer_types):
228 | raise ValueError('`depth` must be a non-negative integer.')
229 | if left_pad < 0 or not isinstance(left_pad, six.integer_types):
230 | raise ValueError('`left_pad` must be a non-negative integer.')
231 | if depth == 0:
232 | return None
233 |
234 | rank = len(indices.get_shape().as_list())
235 | if rank != 1:
236 | raise ValueError('`indices` must have rank 1, but has rank=%s' % rank)
237 |
238 | def one_hot_and_pad():
239 | one_hot = tf.cast(tf.one_hot(tf.cast(indices, tf.int64), depth,
240 | on_value=1, off_value=0), tf.float32)
241 | return tf.pad(one_hot, [[0, 0], [left_pad, 0]], mode='CONSTANT')
242 | result = tf.cond(tf.greater(tf.size(indices), 0), one_hot_and_pad,
243 | lambda: tf.zeros((depth + left_pad, 0)))
244 | return tf.reshape(result, [-1, depth + left_pad])
245 |
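# Editor's note: an illustrative, commented-out call (not part of the upstream
# file); the left_pad column comes first in the output.
#
#   padded_one_hot_encoding(tf.constant([3]), depth=4, left_pad=1)
#   # -> [[0., 0., 0., 0., 1.]]  (1 pad column + one-hot of index 3)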
246 |
247 | def dense_to_sparse_boxes(dense_locations, dense_num_boxes, num_classes):
248 | """Converts bounding boxes from dense to sparse form.
249 |
250 | Args:
251 | dense_locations: a [max_num_boxes, 4] tensor in which only the first k rows
252 | are valid bounding box location coordinates, where k is the sum of
253 | elements in dense_num_boxes.
254 | dense_num_boxes: a [max_num_classes] tensor indicating the counts of
255 | various bounding box classes e.g. [1, 0, 0, 2] means that the first
256 | bounding box is of class 0 and the second and third bounding boxes are
257 | of class 3. The sum of elements in this tensor is the number of valid
258 | bounding boxes.
259 | num_classes: number of classes
260 |
261 | Returns:
262 | box_locations: a [num_boxes, 4] tensor containing only valid bounding
263 | boxes (i.e. the first num_boxes rows of dense_locations)
264 | box_classes: a [num_boxes] tensor containing the classes of each bounding
265 | box (e.g. dense_num_boxes = [1, 0, 0, 2] => box_classes = [0, 3, 3]).
266 | """
267 |
268 | num_valid_boxes = tf.reduce_sum(dense_num_boxes)
269 | box_locations = tf.slice(dense_locations,
270 | tf.constant([0, 0]), tf.stack([num_valid_boxes, 4]))
271 | tiled_classes = [tf.tile([i], tf.expand_dims(dense_num_boxes[i], 0))
272 | for i in range(num_classes)]
273 | box_classes = tf.concat(tiled_classes, 0)
274 | box_locations.set_shape([None, 4])
275 | return box_locations, box_classes
276 |
277 |
278 | def indices_to_dense_vector(indices,
279 | size,
280 | indices_value=1.,
281 | default_value=0,
282 | dtype=tf.float32):
283 | """Creates dense vector with indices set to specific value and rest to zeros.
284 |
285 | This function exists because it is unclear if it is safe to use
286 | tf.sparse_to_dense(indices, [size], 1, validate_indices=False)
287 | with indices which are not ordered.
288 | This function accepts a dynamic size (e.g. tf.shape(tensor)[0])
289 |
290 | Args:
291 | indices: 1d Tensor with integer indices which are to be set to
292 | indices_value.
293 | size: scalar with size (integer) of output Tensor.
294 | indices_value: values of elements specified by indices in the output vector
295 | default_value: values of other elements in the output vector.
296 | dtype: data type.
297 |
298 | Returns:
299 | dense 1D Tensor of shape [size] with indices set to indices_value and the
300 | rest set to default_value.
301 | """
302 | size = tf.to_int32(size)
303 | zeros = tf.ones([size], dtype=dtype) * default_value
304 | values = tf.ones_like(indices, dtype=dtype) * indices_value
305 |
306 | return tf.dynamic_stitch([tf.range(size), tf.to_int32(indices)],
307 | [zeros, values])
308 |
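# Editor's note: a minimal commented-out example (not part of the upstream
# file).
#
#   indices_to_dense_vector(tf.constant([1, 3]), size=5)
#   # -> [0., 1., 0., 1., 0.]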
309 |
310 | def reduce_sum_trailing_dimensions(tensor, ndims):
311 | """Computes sum across all dimensions following first `ndims` dimensions."""
312 | return tf.reduce_sum(tensor, axis=tuple(range(ndims, tensor.shape.ndims)))
313 |
314 |
315 | def retain_groundtruth(tensor_dict, valid_indices):
316 | """Retains groundtruth by valid indices.
317 |
318 | Args:
319 | tensor_dict: a dictionary of following groundtruth tensors -
320 | fields.InputDataFields.groundtruth_boxes
321 | fields.InputDataFields.groundtruth_classes
322 | fields.InputDataFields.groundtruth_keypoints
323 | fields.InputDataFields.groundtruth_instance_masks
324 | fields.InputDataFields.groundtruth_is_crowd
325 | fields.InputDataFields.groundtruth_area
326 | fields.InputDataFields.groundtruth_label_types
327 | fields.InputDataFields.groundtruth_difficult
328 | valid_indices: a tensor with valid indices for the box-level groundtruth.
329 |
330 | Returns:
331 | a dictionary of tensors containing only the groundtruth for valid_indices.
332 |
333 | Raises:
334 | ValueError: If the shape of valid_indices is invalid.
335 | ValueError: field fields.InputDataFields.groundtruth_boxes is
336 | not present in tensor_dict.
337 | """
338 | input_shape = valid_indices.get_shape().as_list()
339 | if not (len(input_shape) == 1 or
340 | (len(input_shape) == 2 and input_shape[1] == 1)):
341 | raise ValueError('The shape of valid_indices is invalid.')
342 | valid_indices = tf.reshape(valid_indices, [-1])
343 | valid_dict = {}
344 | if fields.InputDataFields.groundtruth_boxes in tensor_dict:
345 | # Prevents reshape failure when num_boxes is 0.
346 | num_boxes = tf.maximum(tf.shape(
347 | tensor_dict[fields.InputDataFields.groundtruth_boxes])[0], 1)
348 | for key in tensor_dict:
349 | if key in [fields.InputDataFields.groundtruth_boxes,
350 | fields.InputDataFields.groundtruth_classes,
351 | fields.InputDataFields.groundtruth_keypoints,
352 | fields.InputDataFields.groundtruth_instance_masks]:
353 | valid_dict[key] = tf.gather(tensor_dict[key], valid_indices)
354 | # Input decoder returns empty tensor when these fields are not provided.
355 | # Needs to reshape into [num_boxes, -1] for tf.gather() to work.
356 | elif key in [fields.InputDataFields.groundtruth_is_crowd,
357 | fields.InputDataFields.groundtruth_area,
358 | fields.InputDataFields.groundtruth_difficult,
359 | fields.InputDataFields.groundtruth_label_types]:
360 | valid_dict[key] = tf.reshape(
361 | tf.gather(tf.reshape(tensor_dict[key], [num_boxes, -1]),
362 | valid_indices), [-1])
363 | # Fields that are not associated with boxes.
364 | else:
365 | valid_dict[key] = tensor_dict[key]
366 | else:
367 | raise ValueError('%s not present in input tensor dict.' % (
368 | fields.InputDataFields.groundtruth_boxes))
369 | return valid_dict
370 |
371 |
372 | def retain_groundtruth_with_positive_classes(tensor_dict):
373 | """Retains only groundtruth with positive class ids.
374 |
375 | Args:
376 | tensor_dict: a dictionary of following groundtruth tensors -
377 | fields.InputDataFields.groundtruth_boxes
378 | fields.InputDataFields.groundtruth_classes
379 | fields.InputDataFields.groundtruth_keypoints
380 | fields.InputDataFields.groundtruth_instance_masks
381 | fields.InputDataFields.groundtruth_is_crowd
382 | fields.InputDataFields.groundtruth_area
383 | fields.InputDataFields.groundtruth_label_types
384 | fields.InputDataFields.groundtruth_difficult
385 |
386 | Returns:
387 | a dictionary of tensors containing only the groundtruth with positive
388 | classes.
389 |
390 | Raises:
391 | ValueError: If groundtruth_classes tensor is not in tensor_dict.
392 | """
393 | if fields.InputDataFields.groundtruth_classes not in tensor_dict:
394 | raise ValueError('`groundtruth_classes` not in tensor_dict.')
395 | keep_indices = tf.where(tf.greater(
396 | tensor_dict[fields.InputDataFields.groundtruth_classes], 0))
397 | return retain_groundtruth(tensor_dict, keep_indices)
398 |
399 |
400 | def replace_nan_groundtruth_label_scores_with_ones(label_scores):
401 | """Replaces nan label scores with 1.0.
402 |
403 | Args:
404 | label_scores: a tensor containing object annotation label scores.
405 |
406 | Returns:
407 | a tensor where NaN label scores have been replaced by ones.
408 | """
409 | return tf.where(
410 | tf.is_nan(label_scores), tf.ones(tf.shape(label_scores)), label_scores)
411 |
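# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). NaN scores typically come from datasets without annotated
# confidences and default to full confidence:
#
#   scores = tf.constant([float('nan'), 0.7])
#   replace_nan_groundtruth_label_scores_with_ones(scores)  # -> [1.0, 0.7]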
412 |
413 | def filter_groundtruth_with_crowd_boxes(tensor_dict):
414 | """Filters out groundtruth with boxes corresponding to crowd.
415 |
416 | Args:
417 | tensor_dict: a dictionary of following groundtruth tensors -
418 | fields.InputDataFields.groundtruth_boxes
419 | fields.InputDataFields.groundtruth_classes
420 | fields.InputDataFields.groundtruth_keypoints
421 | fields.InputDataFields.groundtruth_instance_masks
422 | fields.InputDataFields.groundtruth_is_crowd
423 | fields.InputDataFields.groundtruth_area
424 | fields.InputDataFields.groundtruth_label_types
425 |
426 | Returns:
427 |     a dictionary of tensors containing only the groundtruth that is not
428 |     marked as crowd.
429 | """
430 | if fields.InputDataFields.groundtruth_is_crowd in tensor_dict:
431 | is_crowd = tensor_dict[fields.InputDataFields.groundtruth_is_crowd]
432 | is_not_crowd = tf.logical_not(is_crowd)
433 | is_not_crowd_indices = tf.where(is_not_crowd)
434 | tensor_dict = retain_groundtruth(tensor_dict, is_not_crowd_indices)
435 | return tensor_dict
436 |
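# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). Only the non-crowd row is kept:
#
#   d = {fields.InputDataFields.groundtruth_boxes: tf.zeros([2, 4]),
#        fields.InputDataFields.groundtruth_classes: tf.constant([1, 2]),
#        fields.InputDataFields.groundtruth_is_crowd:
#            tf.constant([False, True])}
#   filtered = filter_groundtruth_with_crowd_boxes(d)
#   # filtered[fields.InputDataFields.groundtruth_classes] -> [1]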
437 |
438 | def filter_groundtruth_with_nan_box_coordinates(tensor_dict):
439 |   """Filters out groundtruth with NaN bounding box coordinates.
440 |
441 | Args:
442 | tensor_dict: a dictionary of following groundtruth tensors -
443 | fields.InputDataFields.groundtruth_boxes
444 | fields.InputDataFields.groundtruth_classes
445 | fields.InputDataFields.groundtruth_keypoints
446 | fields.InputDataFields.groundtruth_instance_masks
447 | fields.InputDataFields.groundtruth_is_crowd
448 | fields.InputDataFields.groundtruth_area
449 | fields.InputDataFields.groundtruth_label_types
450 |
451 | Returns:
452 |     a dictionary of tensors containing only the groundtruth with valid
453 |     (non-NaN) box coordinates.
454 | """
455 | groundtruth_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes]
456 | nan_indicator_vector = tf.greater(tf.reduce_sum(tf.to_int32(
457 | tf.is_nan(groundtruth_boxes)), reduction_indices=[1]), 0)
458 | valid_indicator_vector = tf.logical_not(nan_indicator_vector)
459 | valid_indices = tf.where(valid_indicator_vector)
460 |
461 | return retain_groundtruth(tensor_dict, valid_indices)
462 |
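# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). The second box contains a NaN coordinate, so only the first row
# survives:
#
#   nan = float('nan')
#   d = {fields.InputDataFields.groundtruth_boxes:
#            tf.constant([[0., 0., 1., 1.], [nan, 0., 1., 1.]]),
#        fields.InputDataFields.groundtruth_classes: tf.constant([1, 2])}
#   filtered = filter_groundtruth_with_nan_box_coordinates(d)
#   # filtered[fields.InputDataFields.groundtruth_classes] -> [1]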
463 |
464 | def normalize_to_target(inputs,
465 | target_norm_value,
466 | dim,
467 | epsilon=1e-7,
468 | trainable=True,
469 | scope='NormalizeToTarget',
470 | summarize=True):
471 | """L2 normalizes the inputs across the specified dimension to a target norm.
472 |
473 | This op implements the L2 Normalization layer introduced in
474 | Liu, Wei, et al. "SSD: Single Shot MultiBox Detector."
475 | and Liu, Wei, Andrew Rabinovich, and Alexander C. Berg.
476 | "Parsenet: Looking wider to see better." and is useful for bringing
477 | activations from multiple layers in a convnet to a standard scale.
478 |
479 | Note that the rank of `inputs` must be known and the dimension to which
480 | normalization is to be applied should be statically defined.
481 |
482 | TODO(jonathanhuang): Add option to scale by L2 norm of the entire input.
483 |
484 | Args:
485 | inputs: A `Tensor` of arbitrary size.
486 | target_norm_value: A float value that specifies an initial target norm or
487 | a list of floats (whose length must be equal to the depth along the
488 | dimension to be normalized) specifying a per-dimension multiplier
489 | after normalization.
490 | dim: The dimension along which the input is normalized.
491 | epsilon: A small value to add to the inputs to avoid dividing by zero.
492 |     trainable: Whether the norm is trainable or not.
493 | scope: Optional scope for variable_scope.
494 | summarize: Whether or not to add a tensorflow summary for the op.
495 |
496 | Returns:
497 | The input tensor normalized to the specified target norm.
498 |
499 | Raises:
500 |     ValueError: If dim is negative or not smaller than the rank of 'inputs'.
501 | ValueError: If target_norm_value is not a float or a list of floats with
502 | length equal to the depth along the dimension to be normalized.
503 | """
504 | with tf.variable_scope(scope, 'NormalizeToTarget', [inputs]):
505 | if not inputs.get_shape():
506 | raise ValueError('The input rank must be known.')
507 | input_shape = inputs.get_shape().as_list()
508 | input_rank = len(input_shape)
509 | if dim < 0 or dim >= input_rank:
510 | raise ValueError(
511 | 'dim must be non-negative but smaller than the input rank.')
512 | if not input_shape[dim]:
513 | raise ValueError('input shape should be statically defined along '
514 | 'the specified dimension.')
515 | depth = input_shape[dim]
516 | if not (isinstance(target_norm_value, float) or
517 | (isinstance(target_norm_value, list) and
518 | len(target_norm_value) == depth) and
519 | all([isinstance(val, float) for val in target_norm_value])):
520 | raise ValueError('target_norm_value must be a float or a list of floats '
521 | 'with length equal to the depth along the dimension to '
522 | 'be normalized.')
523 | if isinstance(target_norm_value, float):
524 | initial_norm = depth * [target_norm_value]
525 | else:
526 | initial_norm = target_norm_value
527 | target_norm = tf.contrib.framework.model_variable(
528 | name='weights', dtype=tf.float32,
529 | initializer=tf.constant(initial_norm, dtype=tf.float32),
530 | trainable=trainable)
531 | if summarize:
532 | mean = tf.reduce_mean(target_norm)
533 | mean = tf.Print(mean, ['NormalizeToTarget:', mean])
534 | tf.summary.scalar(tf.get_variable_scope().name, mean)
535 | lengths = epsilon + tf.sqrt(tf.reduce_sum(tf.square(inputs), dim, True))
536 | mult_shape = input_rank*[1]
537 | mult_shape[dim] = depth
538 | return tf.reshape(target_norm, mult_shape) * tf.truediv(inputs, lengths)
539 |
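# Example (sketch, not from the original TensorFlow file; assumes TF 1.x and
# a variable scope in which the target-norm variable can be created). Typical
# SSD/ParseNet-style use: rescale conv4_3-like activations to a learnable
# target norm of 20 along the channel axis; the channel depth must be static:
#
#   features = tf.placeholder(tf.float32, [None, 38, 38, 512])
#   scaled = normalize_to_target(features, target_norm_value=20.0, dim=3)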
540 |
541 | def position_sensitive_crop_regions(image,
542 | boxes,
543 | box_ind,
544 | crop_size,
545 | num_spatial_bins,
546 | global_pool,
547 | extrapolation_value=None):
548 | """Position-sensitive crop and pool rectangular regions from a feature grid.
549 |
550 | The output crops are split into `spatial_bins_y` vertical bins
551 | and `spatial_bins_x` horizontal bins. For each intersection of a vertical
552 | and a horizontal bin the output values are gathered by performing
553 | `tf.image.crop_and_resize` (bilinear resampling) on a separate subset of
554 | channels of the image. This reduces `depth` by a factor of
555 | `(spatial_bins_y * spatial_bins_x)`.
556 |
557 | When global_pool is True, this function implements a differentiable version
558 | of position-sensitive RoI pooling used in
559 | [R-FCN detection system](https://arxiv.org/abs/1605.06409).
560 |
561 | When global_pool is False, this function implements a differentiable version
562 | of position-sensitive assembling operation used in
563 | [instance FCN](https://arxiv.org/abs/1603.08678).
564 |
565 | Args:
566 | image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
567 | `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
568 | A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
569 | Both `image_height` and `image_width` need to be positive.
570 | boxes: A `Tensor` of type `float32`.
571 | A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
572 | specifies the coordinates of a box in the `box_ind[i]` image and is
573 | specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized
574 | coordinate value of `y` is mapped to the image coordinate at
575 |       `y * (image_height - 1)`, so the `[0, 1]` interval of normalized image
576 |       height is mapped to `[0, image_height - 1]` in image height coordinates.
577 | We do allow y1 > y2, in which case the sampled crop is an up-down flipped
578 | version of the original image. The width dimension is treated similarly.
579 | Normalized coordinates outside the `[0, 1]` range are allowed, in which
580 | case we use `extrapolation_value` to extrapolate the input image values.
581 | box_ind: A `Tensor` of type `int32`.
582 | A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`.
583 | The value of `box_ind[i]` specifies the image that the `i`-th box refers
584 | to.
585 | crop_size: A list of two integers `[crop_height, crop_width]`. All
586 | cropped image patches are resized to this size. The aspect ratio of the
587 | image content is not preserved. Both `crop_height` and `crop_width` need
588 | to be positive.
589 | num_spatial_bins: A list of two integers `[spatial_bins_y, spatial_bins_x]`.
590 | Represents the number of position-sensitive bins in y and x directions.
591 | Both values should be >= 1. `crop_height` should be divisible by
592 | `spatial_bins_y`, and similarly for width.
593 | The number of image channels should be divisible by
594 | (spatial_bins_y * spatial_bins_x).
595 | Suggested value from R-FCN paper: [3, 3].
596 | global_pool: A boolean variable.
597 | If True, we perform average global pooling on the features assembled from
598 | the position-sensitive score maps.
599 | If False, we keep the position-pooled features without global pooling
600 | over the spatial coordinates.
601 | Note that using global_pool=True is equivalent to but more efficient than
602 | running the function with global_pool=False and then performing global
603 | average pooling.
604 | extrapolation_value: An optional `float`. Defaults to `0`.
605 | Value used for extrapolation, when applicable.
606 | Returns:
607 | position_sensitive_features: A 4-D tensor of shape
608 | `[num_boxes, K, K, crop_channels]`,
609 | where `crop_channels = depth / (spatial_bins_y * spatial_bins_x)`,
610 | where K = 1 when global_pool is True (Average-pooled cropped regions),
611 | and K = crop_size when global_pool is False.
612 | Raises:
613 | ValueError: Raised in four situations:
614 | `num_spatial_bins` is not >= 1;
615 | `num_spatial_bins` does not divide `crop_size`;
616 | `(spatial_bins_y*spatial_bins_x)` does not divide `depth`;
617 | `bin_crop_size` is not square when global_pool=False due to the
618 | constraint in function space_to_depth.
619 | """
620 | total_bins = 1
621 | bin_crop_size = []
622 |
623 | for (num_bins, crop_dim) in zip(num_spatial_bins, crop_size):
624 | if num_bins < 1:
625 | raise ValueError('num_spatial_bins should be >= 1')
626 |
627 | if crop_dim % num_bins != 0:
628 | raise ValueError('crop_size should be divisible by num_spatial_bins')
629 |
630 | total_bins *= num_bins
631 | bin_crop_size.append(crop_dim // num_bins)
632 |
633 | if not global_pool and bin_crop_size[0] != bin_crop_size[1]:
634 | raise ValueError('Only support square bin crop size for now.')
635 |
636 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1)
637 | spatial_bins_y, spatial_bins_x = num_spatial_bins
638 |
639 | # Split each box into spatial_bins_y * spatial_bins_x bins.
640 | position_sensitive_boxes = []
641 | for bin_y in range(spatial_bins_y):
642 | step_y = (ymax - ymin) / spatial_bins_y
643 | for bin_x in range(spatial_bins_x):
644 | step_x = (xmax - xmin) / spatial_bins_x
645 | box_coordinates = [ymin + bin_y * step_y,
646 | xmin + bin_x * step_x,
647 | ymin + (bin_y + 1) * step_y,
648 | xmin + (bin_x + 1) * step_x,
649 | ]
650 | position_sensitive_boxes.append(tf.stack(box_coordinates, axis=1))
651 |
652 | image_splits = tf.split(value=image, num_or_size_splits=total_bins, axis=3)
653 |
654 | image_crops = []
655 | for (split, box) in zip(image_splits, position_sensitive_boxes):
656 | crop = tf.image.crop_and_resize(split, box, box_ind, bin_crop_size,
657 | extrapolation_value=extrapolation_value)
658 | image_crops.append(crop)
659 |
660 | if global_pool:
661 | # Average over all bins.
662 | position_sensitive_features = tf.add_n(image_crops) / len(image_crops)
663 | # Then average over spatial positions within the bins.
664 | position_sensitive_features = tf.reduce_mean(
665 | position_sensitive_features, [1, 2], keep_dims=True)
666 | else:
667 | # Reorder height/width to depth channel.
668 | block_size = bin_crop_size[0]
669 | if block_size >= 2:
670 | image_crops = [tf.space_to_depth(
671 | crop, block_size=block_size) for crop in image_crops]
672 |
673 |   # Pack image_crops so that first dimension is for position-sensitive boxes.
674 | position_sensitive_features = tf.stack(image_crops, axis=0)
675 |
676 | # Unroll the position-sensitive boxes to spatial positions.
677 | position_sensitive_features = tf.squeeze(
678 | tf.batch_to_space_nd(position_sensitive_features,
679 | block_shape=[1] + num_spatial_bins,
680 | crops=tf.zeros((3, 2), dtype=tf.int32)),
681 | squeeze_dims=[0])
682 |
683 | # Reorder back the depth channel.
684 | if block_size >= 2:
685 | position_sensitive_features = tf.depth_to_space(
686 | position_sensitive_features, block_size=block_size)
687 |
688 | return position_sensitive_features
689 |
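# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). R-FCN-style pooling over 3x3 position-sensitive bins: the depth
# of 36 splits into 9 bins of 4 channels each, and global pooling collapses
# the spatial dimensions to 1x1:
#
#   score_maps = tf.random_normal([1, 32, 32, 9 * 4])
#   boxes = tf.constant([[0.1, 0.1, 0.9, 0.9]])
#   pooled = position_sensitive_crop_regions(
#       score_maps, boxes, box_ind=tf.constant([0]), crop_size=[6, 6],
#       num_spatial_bins=[3, 3], global_pool=True)
#   # pooled shape: [1, 1, 1, 4]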
690 |
691 | def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
692 | image_width):
693 | """Transforms the box masks back to full image masks.
694 |
695 | Embeds masks in bounding boxes of larger masks whose shapes correspond to
696 | image shape.
697 |
698 | Args:
699 | box_masks: A tf.float32 tensor of size [num_masks, mask_height, mask_width].
700 | boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
701 | corners. Row i contains [ymin, xmin, ymax, xmax] of the box
702 | corresponding to mask i. Note that the box corners are in
703 | normalized coordinates.
704 | image_height: Image height. The output mask will have the same height as
705 | the image height.
706 | image_width: Image width. The output mask will have the same width as the
707 | image width.
708 |
709 | Returns:
710 | A tf.float32 tensor of size [num_masks, image_height, image_width].
711 | """
712 | # TODO(rathodv): Make this a public function.
713 | def reframe_box_masks_to_image_masks_default():
714 | """The default function when there are more than 0 box masks."""
715 | def transform_boxes_relative_to_boxes(boxes, reference_boxes):
716 | boxes = tf.reshape(boxes, [-1, 2, 2])
717 | min_corner = tf.expand_dims(reference_boxes[:, 0:2], 1)
718 | max_corner = tf.expand_dims(reference_boxes[:, 2:4], 1)
719 | transformed_boxes = (boxes - min_corner) / (max_corner - min_corner)
720 | return tf.reshape(transformed_boxes, [-1, 4])
721 |
722 | box_masks_expanded = tf.expand_dims(box_masks, axis=3)
723 | num_boxes = tf.shape(box_masks_expanded)[0]
724 | unit_boxes = tf.concat(
725 | [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1)
726 | reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes)
727 | return tf.image.crop_and_resize(
728 | image=box_masks_expanded,
729 | boxes=reverse_boxes,
730 | box_ind=tf.range(num_boxes),
731 | crop_size=[image_height, image_width],
732 | extrapolation_value=0.0)
733 | image_masks = tf.cond(
734 | tf.shape(box_masks)[0] > 0,
735 | reframe_box_masks_to_image_masks_default,
736 | lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32))
737 | return tf.squeeze(image_masks, axis=3)
738 |
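# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). Paste two 33x33 instance masks back into a 100x100 canvas at
# their normalized box locations:
#
#   masks = tf.ones([2, 33, 33], tf.float32)
#   boxes = tf.constant([[0., 0., .5, .5], [.25, .25, 1., 1.]])
#   image_masks = reframe_box_masks_to_image_masks(masks, boxes, 100, 100)
#   # image_masks shape: [2, 100, 100]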
739 |
740 | def merge_boxes_with_multiple_labels(boxes, classes, num_classes):
741 | """Merges boxes with same coordinates and returns K-hot encoded classes.
742 |
743 | Args:
744 | boxes: A tf.float32 tensor with shape [N, 4] holding N boxes.
745 | classes: A tf.int32 tensor with shape [N] holding class indices.
746 | The class index starts at 0.
747 | num_classes: total number of classes to use for K-hot encoding.
748 |
749 | Returns:
750 | merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
751 | where N' <= N.
752 | class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
753 | k-hot encodings for the merged boxes.
754 | merged_box_indices: A tf.int32 tensor with shape [N'] holding original
755 | indices of the boxes.
756 | """
757 | def merge_numpy_boxes(boxes, classes, num_classes):
758 | """Python function to merge numpy boxes."""
759 | if boxes.size < 1:
760 | return (np.zeros([0, 4], dtype=np.float32),
761 | np.zeros([0, num_classes], dtype=np.int32),
762 | np.zeros([0], dtype=np.int32))
763 | box_to_class_indices = {}
764 | for box_index in range(boxes.shape[0]):
765 | box = tuple(boxes[box_index, :].tolist())
766 | class_index = classes[box_index]
767 | if box not in box_to_class_indices:
768 | box_to_class_indices[box] = [box_index, np.zeros([num_classes])]
769 | box_to_class_indices[box][1][class_index] = 1
770 | merged_boxes = np.vstack(box_to_class_indices.keys()).astype(np.float32)
771 | class_encodings = [item[1] for item in box_to_class_indices.values()]
772 | class_encodings = np.vstack(class_encodings).astype(np.int32)
773 | merged_box_indices = [item[0] for item in box_to_class_indices.values()]
774 | merged_box_indices = np.array(merged_box_indices).astype(np.int32)
775 | return merged_boxes, class_encodings, merged_box_indices
776 |
777 | merged_boxes, class_encodings, merged_box_indices = tf.py_func(
778 | merge_numpy_boxes, [boxes, classes, num_classes],
779 | [tf.float32, tf.int32, tf.int32])
780 | merged_boxes = tf.reshape(merged_boxes, [-1, 4])
781 | class_encodings = tf.reshape(class_encodings, [-1, num_classes])
782 | merged_box_indices = tf.reshape(merged_box_indices, [-1])
783 | return merged_boxes, class_encodings, merged_box_indices
784 |
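# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). Two identical boxes labeled with classes 0 and 2 collapse into a
# single box with a K-hot encoding over num_classes=3:
#
#   boxes = tf.constant([[0., 0., 1., 1.], [0., 0., 1., 1.]])
#   classes = tf.constant([0, 2], dtype=tf.int32)
#   merged, khot, idx = merge_boxes_with_multiple_labels(boxes, classes, 3)
#   # merged -> [[0., 0., 1., 1.]], khot -> [[1, 0, 1]], idx -> [0]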
785 |
786 | def nearest_neighbor_upsampling(input_tensor, scale):
787 | """Nearest neighbor upsampling implementation.
788 |
789 | Nearest neighbor upsampling function that maps input tensor with shape
790 | [batch_size, height, width, channels] to [batch_size, height * scale
791 | , width * scale, channels]. This implementation only uses reshape and
792 | broadcasting to make it TPU compatible.
793 |
794 | Args:
795 | input_tensor: A float32 tensor of size [batch, height_in, width_in,
796 | channels].
797 | scale: An integer multiple to scale resolution of input data.
798 | Returns:
799 | data_up: A float32 tensor of size
800 | [batch, height_in*scale, width_in*scale, channels].
801 | """
802 | with tf.name_scope('nearest_neighbor_upsampling'):
803 | (batch_size, height, width,
804 | channels) = shape_utils.combined_static_and_dynamic_shape(input_tensor)
805 | output_tensor = tf.reshape(
806 | input_tensor, [batch_size, height, 1, width, 1, channels]) * tf.ones(
807 | [1, 1, scale, 1, scale, 1], dtype=input_tensor.dtype)
808 | return tf.reshape(output_tensor,
809 | [batch_size, height * scale, width * scale, channels])
810 |
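# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). Each pixel is replicated into a scale x scale block, so a 2x2
# map becomes 4x4 with scale=2:
#
#   x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 2, 2, 1])
#   up = nearest_neighbor_upsampling(x, scale=2)  # shape [1, 4, 4, 1]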
811 |
812 | def matmul_gather_on_zeroth_axis(params, indices, scope=None):
813 | """Matrix multiplication based implementation of tf.gather on zeroth axis.
814 |
815 | TODO(rathodv, jonathanhuang): enable sparse matmul option.
816 |
817 | Args:
818 | params: A float32 Tensor. The tensor from which to gather values.
819 | Must be at least rank 1.
820 | indices: A Tensor. Must be one of the following types: int32, int64.
821 | Must be in range [0, params.shape[0])
822 | scope: A name for the operation (optional).
823 |
824 | Returns:
825 | A Tensor. Has the same type as params. Values from params gathered
826 | from indices given by indices, with shape indices.shape + params.shape[1:].
827 | """
828 | with tf.name_scope(scope, 'MatMulGather'):
829 | params_shape = shape_utils.combined_static_and_dynamic_shape(params)
830 | indices_shape = shape_utils.combined_static_and_dynamic_shape(indices)
831 | params2d = tf.reshape(params, [params_shape[0], -1])
832 | indicator_matrix = tf.one_hot(indices, params_shape[0])
833 | gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
834 | return tf.reshape(gathered_result_flattened,
835 | tf.stack(indices_shape + params_shape[1:]))
836 |
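# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). Behaves like tf.gather on axis 0, but the one-hot matmul form is
# friendlier to XLA/TPU compilation:
#
#   params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
#   matmul_gather_on_zeroth_axis(params, tf.constant([2, 0]))
#   # -> [[5., 6.], [1., 2.]]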
837 |
838 | def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
839 | """Matrix multiplication based implementation of the crop and resize op.
840 |
841 | Extracts crops from the input image tensor and bilinearly resizes them
842 | (possibly with aspect ratio change) to a common output size specified by
843 | crop_size. This is more general than the crop_to_bounding_box op which
844 | extracts a fixed size slice from the input image and does not allow
845 | resizing or aspect ratio change.
846 |
847 | Returns a tensor with crops from the input image at positions defined at
848 | the bounding box locations in boxes. The cropped boxes are all resized
849 | (with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
850 | The result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`.
851 |
852 | Running time complexity:
853 | O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number
854 | of pixels of the longer edge of the image.
855 |
856 | Note that this operation is meant to replicate the behavior of the standard
857 | tf.image.crop_and_resize operation but there are a few differences.
858 | Specifically:
859 | 1) The extrapolation value (the values that are interpolated from outside
860 | the bounds of the image window) is always zero
861 | 2) Only XLA supported operations are used (e.g., matrix multiplication).
862 | 3) There is no `box_indices` argument --- to run this op on multiple images,
863 | one must currently call this op independently on each image.
864 | 4) All shapes and the `crop_size` parameter are assumed to be statically
865 | defined. Moreover, the number of boxes must be strictly nonzero.
866 |
867 | Args:
868 | image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
869 | `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
870 | A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
871 | Both `image_height` and `image_width` need to be positive.
872 | boxes: A `Tensor` of type `float32`.
873 | A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
874 |       specifies the coordinates of a box in the input image and is
875 | specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized
876 | coordinate value of `y` is mapped to the image coordinate at
877 |       `y * (image_height - 1)`, so the `[0, 1]` interval of normalized image
878 |       height is mapped to `[0, image_height - 1]` in image height coordinates.
879 | We do allow y1 > y2, in which case the sampled crop is an up-down flipped
880 | version of the original image. The width dimension is treated similarly.
881 |       Normalized coordinates outside the `[0, 1]` range are allowed; in that
882 |       case the input image values are extrapolated with zeros (see above).
883 | crop_size: A list of two integers `[crop_height, crop_width]`. All
884 | cropped image patches are resized to this size. The aspect ratio of the
885 | image content is not preserved. Both `crop_height` and `crop_width` need
886 | to be positive.
887 | scope: A name for the operation (optional).
888 |
889 | Returns:
890 | A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`
891 |
892 | Raises:
893 | ValueError: if image tensor does not have shape
894 | `[1, image_height, image_width, depth]` and all dimensions statically
895 | defined.
896 | ValueError: if boxes tensor does not have shape `[num_boxes, 4]` where
897 | num_boxes > 0.
898 | ValueError: if crop_size is not a list of two positive integers
899 | """
900 | img_shape = image.shape.as_list()
901 | boxes_shape = boxes.shape.as_list()
902 | _, img_height, img_width, _ = img_shape
903 |   if not isinstance(crop_size, list) or len(crop_size) != 2:
904 |     raise ValueError('`crop_size` must be a list of length 2')
905 |   dimensions = img_shape + crop_size + boxes_shape
906 |   if not all([isinstance(dim, int) for dim in dimensions]):
907 |     raise ValueError('all input shapes must be statically defined')
908 |   if len(boxes_shape) != 2 or boxes_shape[1] != 4:
909 |     raise ValueError('`boxes` should have shape `[num_boxes, 4]`')
910 |   # Only a single image is supported: the batch dimension must be
911 |   # exactly 1 and all four image dimensions must be known statically.
912 |   if len(img_shape) != 4 or img_shape[0] != 1:
913 |     raise ValueError('image should have shape '
914 |                      '`[1, image_height, image_width, depth]`')
915 | num_crops = boxes_shape[0]
916 | if not num_crops > 0:
917 | raise ValueError('number of boxes must be > 0')
918 | if not (crop_size[0] > 0 and crop_size[1] > 0):
919 | raise ValueError('`crop_size` must be a list of two positive integers.')
920 |
921 | def _lin_space_weights(num, img_size):
922 | if num > 1:
923 | alpha = (img_size - 1) / float(num - 1)
924 | indices = np.reshape(np.arange(num), (1, num))
925 | start_weights = alpha * (num - 1 - indices)
926 | stop_weights = alpha * indices
927 | else:
928 | start_weights = num * [.5 * (img_size - 1)]
929 | stop_weights = num * [.5 * (img_size - 1)]
930 | return (tf.constant(start_weights, dtype=tf.float32),
931 | tf.constant(stop_weights, dtype=tf.float32))
932 |
933 | with tf.name_scope(scope, 'MatMulCropAndResize'):
934 | y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
935 | x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
936 | [y1, x1, y2, x2] = tf.split(value=boxes, num_or_size_splits=4, axis=1)
937 |
938 | # Pixel centers of input image and grid points along height and width
939 | image_idx_h = tf.constant(
940 | np.reshape(np.arange(img_height), (1, 1, img_height)), dtype=tf.float32)
941 | image_idx_w = tf.constant(
942 | np.reshape(np.arange(img_width), (1, 1, img_width)), dtype=tf.float32)
943 | grid_pos_h = tf.expand_dims(y1 * y1_weights + y2 * y2_weights, 2)
944 | grid_pos_w = tf.expand_dims(x1 * x1_weights + x2 * x2_weights, 2)
945 |
946 | # Create kernel matrices of pairwise kernel evaluations between pixel
947 | # centers of image and grid points.
948 | kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
949 | kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))
950 |
951 | # TODO(jonathanhuang): investigate whether all channels can be processed
952 | # without the explicit unstack --- possibly with a permute and map_fn call.
953 | result_channels = []
954 | for channel in tf.unstack(image, axis=3):
955 | result_channels.append(
956 | tf.matmul(
957 | tf.matmul(kernel_h, tf.tile(channel, [num_crops, 1, 1])),
958 | kernel_w, transpose_b=True))
959 | return tf.stack(result_channels, axis=3)
960 |
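# Example (sketch, not from the original TensorFlow file; assumes a TF 1.x
# session). Single-image crop-and-resize with fully static shapes, as the
# restrictions in the docstring require:
#
#   image = tf.random_normal([1, 64, 64, 3])
#   boxes = tf.constant([[0.1, 0.1, 0.9, 0.9]])
#   crops = matmul_crop_and_resize(image, boxes, crop_size=[8, 8])
#   # crops shape: [1, 8, 8, 3]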
--------------------------------------------------------------------------------