├── .gitignore
├── LICENSE
├── README.md
├── detr_tf
│   ├── bbox.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── coco.py
│   │   ├── processing.py
│   │   ├── tfcsv.py
│   │   ├── transformation.py
│   │   └── voc.py
│   ├── inference.py
│   ├── logger
│   │   ├── training_logging.py
│   │   └── wandb_logging.py
│   ├── loss
│   │   ├── compute_map.py
│   │   ├── hungarian_matching.py
│   │   └── loss.py
│   ├── networks
│   │   ├── custom_layers.py
│   │   ├── detr.py
│   │   ├── position_embeddings.py
│   │   ├── resnet_backbone.py
│   │   ├── transformer.py
│   │   └── weights.py
│   ├── optimizers.py
│   ├── training.py
│   └── training_config.py
├── eval.py
├── finetune_coco.py
├── finetune_hardhat.py
├── finetune_voc.py
├── images
│   ├── datasetsupport.png
│   ├── detr-figure.png
│   ├── hardhatdataset.jpg
│   ├── tutorials
│   │   ├── data-pipeline.png
│   │   └── download_hardhat_dataset.png
│   ├── wandb_logging.png
│   └── webcam_detr.png
├── notebooks
│   ├── DETR Tensorflow - Finetuning tutorial.ipynb
│   ├── DETR Tensorflow - How to setup a custom dataset.ipynb
│   └── How to load a dataset.ipynb
├── requirements.txt
├── train_coco.py
└── webcam_inference.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | wandb
10 | weights
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Visual-Behavior
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DETR : End-to-End Object Detection with Transformers (Tensorflow)
2 |
3 | Tensorflow implementation of DETR : Object Detection with Transformers, including code for inference, training, and finetuning. DETR is a promising model that brings widely adopted transformers to vision models. We believe that models based on convolution and transformers will soon become the default choice for most practitioners because of the simplicity of the training procedure: no NMS and no anchors! This repository is therefore a step toward making this type of architecture widely available.
4 |
5 | * [1. Install](#install)
6 | * [2. Datasets](#datasets)
7 | * [3. Tutorials](#tutorials)
8 | * [4. Finetuning](#finetuning)
9 | * [5. Training](#training)
10 | * [6. Inference](#inference)
11 | * [7. Acknowledgement](#acknowledgement)
12 |
13 |
14 | DETR paper: https://arxiv.org/pdf/2005.12872.pdf
15 | Torch implementation: https://github.com/facebookresearch/detr
16 |
17 |
18 |
19 | About this implementation: This repository includes code to run inference with the original model's weights (based on the PyTorch weights), to train the model from scratch (multi-GPU training support coming soon), as well as examples to finetune the model on your own dataset. Unlike the PyTorch implementation, the training uses fixed image sizes and a standard Adam optimizer with gradient norm clipping.
20 |
21 | Additionally, our logging system is based on https://www.wandb.com/ so that you can get a great visualization of your model performance!
22 |
23 | - Check out our logging board with the reports here: https://wandb.ai/thibault-neveu/detr-tensorflow-log
24 |
25 |
26 |
27 | ## Install
28 |
29 | The code has currently been tested with TensorFlow 2.3.0 and Python 3.7. The following dependencies are required.
30 |
31 |
32 | ```
33 | wandb
34 | matplotlib
35 | numpy
36 | pycocotools
37 | scikit-image
38 | imageio
39 | pandas
40 | ```
41 |
42 | ```
43 | pip install -r requirements.txt
44 | ```
45 |
46 |
47 |
48 | ## Datasets
49 |
50 |
51 | This repository currently supports three dataset formats: **COCO**, **VOC**, and **Tensorflow Object Detection CSV**. The easiest way to get started is to set up your dataset based on one of these formats. Along with the datasets, we provide code examples to finetune your model.
52 | Finally, we provide Jupyter notebooks to help you understand how to load a dataset, set up a custom dataset, and finetune your model.
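
For reference, here is a minimal sketch of how the three loaders exposed in `detr_tf/data` can be called from a script. The `TrainingConfig` class name and the `config.data.*` fields are assumptions based on how the loaders read their defaults (see `training_config.py`), and every path below is a placeholder:

```python
import os

from detr_tf.training_config import TrainingConfig  # assumed config class
from detr_tf.data import load_coco_dataset, load_voc_dataset, load_tfcsv_dataset

config = TrainingConfig()

# COCO format: coco.py uses ann_file and img_dir as given, so pass full paths here.
coco_dir = "/path/to/coco"
train_dt, class_names = load_coco_dataset(
    config, batch_size=8, augmentation=True,
    img_dir=os.path.join(coco_dir, "val2017"),
    ann_file=os.path.join(coco_dir, "annotations/instances_val2017.json"))

# VOC format: voc.py joins these fields with config.data.data_dir internally.
config.data.data_dir = "/path/to/VOCdevkit/VOC2012"
config.data.img_dir = "JPEGImages"
config.data.ann_dir = "Annotations"
train_dt, class_names = load_voc_dataset(config, batch_size=8, augmentation=True)

# Tensorflow Object Detection CSV format: paths are relative to config.data.data_dir
# (the csv file name below is a placeholder).
config.data.data_dir = "/path/to/hardhat"
train_dt, class_names = load_tfcsv_dataset(
    config, batch_size=8, augmentation=True,
    img_dir="train", ann_file="train/_annotations.csv")

# Each loader returns a batched tf.data.Dataset yielding (images, target_bbox, target_class)
# and the list of class names; it also sets config.background_class accordingly.
```

The `eval.py` and `finetune_*.py` scripts wire these same loaders to their command line arguments.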
53 |
54 |
55 |
56 | ## Tutorials
57 |
58 | To get started with the repository you can check the following Jupyter notebooks:
59 |
60 | - ✍ [DETR Tensorflow - How to load a dataset.ipynb](https://github.com/Visual-Behavior/detr-tensorflow/blob/main/notebooks/How%20to%20load%20a%20dataset.ipynb)
61 | - ✍ [DETR Tensorflow - Finetuning tutorial.ipynb](https://github.com/Visual-Behavior/detr-tensorflow/blob/main/notebooks/DETR%20Tensorflow%20-%20%20Finetuning%20tutorial.ipynb)
62 | - ✍ [DETR Tensorflow - How to setup a custom dataset.ipynb](https://github.com/Visual-Behavior/detr-tensorflow/blob/main/notebooks/DETR%20Tensorflow%20-%20%20How%20to%20setup%20a%20custom%20dataset.ipynb)
63 |
64 | As well as the logging board on wandb https://wandb.ai/thibault-neveu/detr-tensorflow-log and this report:
65 |
66 | - 🚀 [Finetuning DETR on Tensorflow - A step by step guide](https://wandb.ai/thibault-neveu/detr-tensorflow-log/reports/Finetuning-DETR-on-Tensorflow-A-step-by-step-tutorial--VmlldzozOTYyNzQ)
67 |
68 |
69 | ## Evaluation
70 |
71 | Run the following to evaluate the model using the pre-trained weights.
72 | - **data_dir** is your coco dataset folder
73 | - **img_dir** is the image folder relative to the data_dir
74 | - **ann_file** is the validation annotation file relative to the data_dir
75 |
76 | Check out ✍ [DETR Tensorflow - How to load a dataset.ipynb](https://github.com/Visual-Behavior/detr-tensorflow/blob/main/notebooks/How%20to%20load%20a%20dataset.ipynb) for more information about the supported datasets and their usage.
77 |
78 | ```
79 | python eval.py --data_dir /path/to/coco/dataset --img_dir val2017 --ann_file annotations/instances_val2017.json
80 | ```
81 |
82 | Outputs:
83 |
84 | ```
85 | | all | .50 | .55 | .60 | .65 | .70 | .75 | .80 | .85 | .90 | .95 |
86 | -------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+
87 | box | 36.53 | 55.38 | 53.13 | 50.46 | 47.11 | 43.07 | 38.11 | 32.10 | 25.01 | 16.20 | 4.77 |
88 | mask | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
89 | -------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+
90 |
91 | ```
92 |
93 | The results differ from those reported in the paper because this evaluation is run on the original image sizes rather than on the larger, rescaled images. The original implementation resizes each image so that the shorter side is at least 800 pixels and the longer side at most 1333 pixels.
94 |
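As a point of reference, here is a minimal sketch of that resizing rule (the helper name and the example image size are illustrative only; the TensorFlow code in this repository keeps a fixed `image_size` instead):

```python
def paper_eval_size(height: int, width: int, min_side: int = 800, max_side: int = 1333):
    # Scale so the shorter side reaches min_side, unless that would push
    # the longer side past max_side, in which case the longer side is capped.
    scale = min(min_side / min(height, width), max_side / max(height, width))
    return int(round(height * scale)), int(round(width * scale))

print(paper_eval_size(480, 640))  # -> (800, 1067)
```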
95 |
96 | ## Finetuning
97 |
98 | To finetune the model on a new dataset we simply need to set the number of classes to detect in the new dataset (**nb_class**). The method removes the last layers that predict the box classes and positions and adds new layers to finetune.
99 |
100 | ```python
101 | # Load the pretrained model
102 | detr = get_detr_model(config, include_top=False, nb_class=3, weights="detr", num_decoder_layers=6, num_encoder_layers=6)
103 | detr.summary()
104 |
105 | # Load your dataset
106 | train_dt, class_names = load_tfcsv_dataset(config, config.batch_size, augmentation=True)
107 |
108 | # Setup the optimizers and the trainable variables
109 | optimizers = setup_optimizers(detr, config)
110 |
111 | # Train the model
112 | training.fit(detr, train_dt, optimizers, config, epoch_nb, class_names)
113 | ```
114 | The following commands give some examples of finetuning the model on new datasets (Pascal VOC and the Hard Hat dataset), with a real ```batch_size``` of 8 and a virtual ```target_batch``` size (gradient aggregation) of 32. ```--log``` is used to log the training to wandb.
115 |
116 | - **data_dir** is your voc dataset folder
117 | - **img_dir** is the image folder relative to the data_dir
118 | - **ann_dir** is the annotation folder relative to the data_dir
119 |
120 | ```
121 | python finetune_voc.py --data_dir /home/thibault/data/VOCdevkit/VOC2012 --img_dir JPEGImages --ann_dir Annotations --batch_size 8 --target_batch 32 --log
122 |
123 | ```
124 | - **data_dir** is the hardhatcsv dataset folder
125 | - **img_dir** and **ann_file** are set in the training script to load the training and validation splits differently
126 |
127 | Check out ✍ [DETR Tensorflow - How to load a dataset.ipynb](https://github.com/Visual-Behavior/detr-tensorflow/blob/main/notebooks/How%20to%20load%20a%20dataset.ipynb) for more information about the supported datasets and their usage.
128 |
129 | ```
130 | python finetune_hardhat.py --data_dir /home/thibault/data/hardhat --batch_size 8 --target_batch 32 --log
131 | ```
132 |
133 | ## Training
134 |
135 | (Multi-GPU training coming soon)
136 |
137 |
138 | - **data_dir** is the coco dataset folder
139 | - **img_dir** and **ann_file** are set in the training script to load the training and validation splits differently.
140 |
141 | ```
142 | python train_coco.py --data_dir /path/to/COCO --batch_size 8 --target_batch 32 --log
143 | ```
144 |
145 | ## Inference
146 |
147 | Here is an example of running inference with the model on your webcam.
148 |
149 | ```
150 | python webcam_inference.py
151 | ```
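
If you want to script inference yourself rather than use the webcam demo, the sketch below shows the rough flow using the helpers from this repository. The `TrainingConfig` class name, the `include_top=True` / `weights="detr"` arguments, the background class default, and the `test.jpg` path are assumptions; check `webcam_inference.py` and `training_config.py` for the exact setup.

```python
import numpy as np
import cv2

from detr_tf.training_config import TrainingConfig   # assumed config class
from detr_tf.networks.detr import get_detr_model
from detr_tf.data import COCO_CLASS_NAME
from detr_tf.data.processing import normalized_images
from detr_tf.inference import get_model_inference, numpy_bbox_to_image

config = TrainingConfig()
# Pretrained DETR with its original COCO head (assumed get_detr_model arguments)
detr = get_detr_model(config, include_top=True, weights="detr")

# Read an image (placeholder path) and normalize it the same way as during training
image = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)
model_input = normalized_images(image.astype(np.float32), config)

# Forward pass, then drop everything predicted as the background class
# (config.background_class is assumed to default to the COCO background index)
m_outputs = detr(np.expand_dims(model_input, axis=0), training=False)
p_bbox, p_labels, p_scores = get_model_inference(
    m_outputs, config.background_class, bbox_format="xy_center")

# Draw the predictions back on the normalized image and display it
result = numpy_bbox_to_image(model_input, np.array(p_bbox), labels=np.array(p_labels),
                             scores=np.array(p_scores), class_name=COCO_CLASS_NAME, config=config)
cv2.imshow("detr", result[..., ::-1])
cv2.waitKey(0)
```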
152 |
153 |
154 |
155 |
156 | ## Acknowledgement
157 |
158 | The pretrained weights of this model originally come from the Facebook repository https://github.com/facebookresearch/detr and were made available in TensorFlow in this repository: https://github.com/Leonardo-Blanger/detr_tensorflow
159 |
--------------------------------------------------------------------------------
/detr_tf/bbox.py:
--------------------------------------------------------------------------------
1 | """
2 | This file defines all the functions related to the manipulation
3 | and comparison of bboxes
4 | """
5 |
6 | from typing import Union,Dict,Tuple
7 | import matplotlib.pyplot as plt
8 | import tensorflow as tf
9 | import numpy as np
10 | import random
11 | import cv2
12 |
13 |
14 | def bbox_xcycwh_to_x1y1x2y2(bbox_xcycwh: np.array):
15 | """
16 | Convert a list of bbox from the [xc, yc, w, h] format
17 | to the [x1, y1, x2, y2] format.
18 | @bbox_xcycwh: [[xc, yc, w, h], ...]
19 | """
20 | bbox_x1y1x2y2 = np.zeros_like(bbox_xcycwh)
21 | bbox_x1y1x2y2[:,0] = bbox_xcycwh[:,0] - (bbox_xcycwh[:,2] / 2)
22 | bbox_x1y1x2y2[:,2] = bbox_xcycwh[:,0] + (bbox_xcycwh[:,2] / 2)
23 | bbox_x1y1x2y2[:,1] = bbox_xcycwh[:,1] - (bbox_xcycwh[:,3] / 2)
24 | bbox_x1y1x2y2[:,3] = bbox_xcycwh[:,1] + (bbox_xcycwh[:,3] / 2)
25 | bbox_x1y1x2y2 = bbox_x1y1x2y2.astype(np.int32)
26 | return bbox_x1y1x2y2
27 |
28 |
29 | def intersect(box_a: tf.Tensor, box_b: tf.Tensor) -> tf.Tensor:
30 | """
31 | Compute the intersection area between two sets of boxes.
32 | Args:
33 | box_a: A (tf.Tensor) list of bbox (a, 4) with a the number of bbox
34 | box_b: A (tf.Tensor) list of bbox (b, 4) with b the number of bbox
35 | Returns:
36 | The intersection area [a, b] between each pair of bbox; zero if there is no intersection
37 | """
38 | # resize both tensors to [A,B,2] with the tile function to compare
39 | # each bbox with the anchors:
40 | # [A,2] -> [A,1,2] -> [A,B,2]
41 | # [B,2] -> [1,B,2] -> [A,B,2]
42 | # Then we compute the area of intersect between box_a and box_b.
43 | # box_a: (tensor) bounding boxes, Shape: [n, A, 4].
44 | # box_b: (tensor) bounding boxes, Shape: [n, B, 4].
45 | # Return: (tensor) intersection area, Shape: [n,A,B].
46 |
47 | A = tf.shape(box_a)[0] # Number of possible bbox
48 | B = tf.shape(box_b)[0] # Number of anchors
49 |
50 | #print(A, B, box_a.shape, box_b.shape)
51 | # Above Right Corner of Intersect Area
52 | # (b, A, 2) -> (b, A, B, 2)
53 | tiled_box_a_xymax = tf.tile(tf.expand_dims(box_a[:, 2:], axis=1), [1, B, 1])
54 | # (b, B, 2) -> (b, A, B, 2)
55 | tiled_box_b_xymax = tf.tile(tf.expand_dims(box_b[:, 2:], axis=0), [A, 1, 1])
56 | # Select the lower right corner of the intersect area
57 | above_right_corner = tf.math.minimum(tiled_box_a_xymax, tiled_box_b_xymax)
58 |
59 |
60 | # Upper Left Corner of Intersect Area
61 | # (b, A, 2) -> (b, A, B, 2)
62 | tiled_box_a_xymin = tf.tile(tf.expand_dims(box_a[:, :2], axis=1), [1, B, 1])
63 | # (b, B, 2) -> (b, A, B, 2)
64 | tiled_box_b_xymin = tf.tile(tf.expand_dims(box_b[:, :2], axis=0), [A, 1, 1])
65 | # Select the upper left corner of the intersect area
66 | upper_left_corner = tf.math.maximum(tiled_box_a_xymin, tiled_box_b_xymin)
67 |
68 |
69 | # If there is some intersection, both must be > 0
70 | inter = tf.nn.relu(above_right_corner - upper_left_corner)
71 | inter = inter[:, :, 0] * inter[:, :, 1]
72 | return inter
73 |
74 |
75 | def jaccard(box_a: tf.Tensor, box_b: tf.Tensor, return_union=False) -> tf.Tensor:
76 | """
77 | Compute the IoU of two sets of boxes.
78 | Args:
79 | box_a: A (tf.Tensor) list a bbox (a, 4) with a the number of bbox
80 | box_b: A (tf.Tensor) list a bbox (b, 4) with b the number of bbox
81 | Returns:
82 | The Jaccard overlap [a, b] between each bbox
83 | """
84 | # Get the intersection area
85 | inter = intersect(box_a, box_b)
86 |
87 | # Compute the A area
88 | # (xmax - xmin) * (ymax - ymin)
89 | area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
90 | # Tile the area to match the anchors area
91 | area_a = tf.tile(tf.expand_dims(area_a, axis=-1), [1, tf.shape(inter)[-1]])
92 |
93 | # Compute the B area
94 | # (xmax - xmin) * (ymax - ymin)
95 | area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])
96 | # Tile the area to match the gt areas
97 | area_b = tf.tile(tf.expand_dims(area_b, axis=-2), [tf.shape(inter)[-2], 1])
98 |
99 | union = area_a + area_b - inter
100 |
101 | if return_union is False:
102 | # Return the intersect over union
103 | return inter / union
104 | else:
105 | return inter / union, union
106 |
107 | def merge(box_a: tf.Tensor, box_b: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
108 | """
109 | Merge two sets of boxes so that operations can be run to compare them
110 | Args:
111 | box_a: A (tf.Tensor) list a bbox (a, 4) with a the number of bbox
112 | box_b: A (tf.Tensor) list a bbox (b, 4) with b the number of bbox
113 | Returns:
114 | The two tensors tiled to the same shape: (a, b, 4)
115 | """
116 | A = tf.shape(box_a)[0] # Number of bbox in box_a
117 | B = tf.shape(box_b)[0] # Number of bbox in box b
118 | # Tile both sets of boxes so they can be compared pairwise
119 | # (A, 4) -> (A, B, 4)
120 | tiled_box_a = tf.tile(tf.expand_dims(box_a, axis=1), [1, B, 1])
121 | # (B, 4) -> (A, B, 4)
122 | tiled_box_b = tf.tile(tf.expand_dims(box_b, axis=0), [A, 1, 1])
123 |
124 | return tiled_box_a, tiled_box_b
125 |
126 | def xy_min_xy_max_to_yx_min_yx_max(bbox: tf.Tensor) -> tf.Tensor:
127 | """
128 | Convert bbox from shape [xmin, ymin, xmax, ymax] to [ymin, xmin, ymax, xmax]
129 | Args:
130 | bbox A (tf.Tensor) list a bbox (n, 4) with n the number of bbox to convert
131 | Returns:
132 | The converted bbox
133 | """
134 | return tf.concat([
135 | bbox[:,1:2],
136 | bbox[:,0:1],
137 | bbox[:,3:4],
138 | bbox[:,2:3]
139 | ], axis=-1)
140 |
141 | def yx_min_yx_max_to_xy_min_xy_max(bbox: tf.Tensor) -> tf.Tensor:
142 | """
143 | Convert bbox from shape [ymin, xmin, ymax, xmax] to [xmin, ymin, xmax, ymax]
144 | Args:
145 | bbox A (tf.Tensor) list a bbox (n, 4) with n the number of bbox to convert
146 | Returns:
147 | The converted bbox
148 | """
149 | return tf.concat([
150 | bbox[:,1:2],
151 | bbox[:,0:1],
152 | bbox[:,3:4],
153 | bbox[:,2:3]
154 | ], axis=-1)
155 |
156 |
157 | def xy_min_xy_max_to_xcycwh(bbox: tf.Tensor) -> tf.Tensor:
158 | """
159 | Convert bbox from shape [xmin, ymin, xmax, ymax] to [xc, yc, w, h]
160 | Args:
161 | bbox A (tf.Tensor) list a bbox (n, 4) with n the number of bbox to convert
162 | Returns:
163 | The converted bbox
164 | """
165 | # convert the bbox from [xmin, ymin, xmax, ymax] to [x_center, y_center, w, h]
166 | bbox_xcycwh = tf.concat([bbox[:, :2] + ((bbox[:, 2:] - bbox[:, :2]) / 2), bbox[:, 2:] - bbox[:, :2]], axis=-1)
167 | return bbox_xcycwh
168 |
169 |
170 |
171 | def xcycwh_to_xy_min_xy_max(bbox: tf.Tensor) -> tf.Tensor:
172 | """
173 | Convert bbox from shape [xc, yc, w, h] to [xmin, ymin, xmax, ymax]
174 | Args:
175 | bbox A (tf.Tensor) list a bbox (n, 4) with n the number of bbox to convert
176 | Returns:
177 | The converted bbox
178 | """
179 | # convert the bbox from [xc, yc, w, h] to [xmin, ymin, xmax, ymax].
180 | bbox_xyxy = tf.concat([bbox[:, :2] - (bbox[:, 2:] / 2), bbox[:, :2] + (bbox[:, 2:] / 2)], axis=-1)
181 | # Be sure to keep the values btw 0 and 1
182 | bbox_xyxy = tf.clip_by_value(bbox_xyxy, 0.0, 1.0)
183 | return bbox_xyxy
184 |
185 |
186 | def xcycwh_to_yx_min_yx_max(bbox: tf.Tensor) -> tf.Tensor:
187 | """
188 | Convert bbox from shape [xc, yc, w, h] to [ymin, xmin, ymax, xmax]
189 | Args:
190 | bbox A (tf.Tensor) list a bbox (n, 4) with n the number of bbox to convert
191 | Returns:
192 | The converted bbox
193 | """
194 | bbox = xcycwh_to_xy_min_xy_max(bbox)
195 | bbox = xy_min_xy_max_to_yx_min_yx_max(bbox)
196 | return bbox
197 |
198 |
199 | def yx_min_yx_max_to_xcycwh(bbox: tf.Tensor) -> tf.Tensor:
200 | """
201 | Convert bbox from shape [ymin, xmin, ymax, xmax] to [xc, yc, w, h]
202 | Args:
203 | bbox A (tf.Tensor) list a bbox (n, 4) with n the number of bbox to convert
204 | Returns:
205 | The converted bbox
206 | """
207 | bbox = yx_min_yx_max_to_xy_min_xy_max(bbox)
208 | bbox = xy_min_xy_max_to_xcycwh(bbox)
209 | return bbox
210 |
211 |
212 |
213 | """
214 | Numpy Transformations
215 | """
216 |
217 | def xy_min_xy_max_to_xcycwh(bbox: np.array) -> np.array:
218 | """
219 | Convert bbox from shape [xmin, ymin, xmax, ymax] to [xc, yc, w, h]
220 | Args:
221 | bbox A (np.array) list a bbox (n, 4) with n the number of bbox to convert
222 | Returns:
223 | The converted bbox
224 | """
225 | # convert the bbox from [xmin, ymin, xmax, ymax] to [x_center, y_center, w, h]
226 | bbox_xcycwh = np.concatenate([bbox[:, :2] + ((bbox[:, 2:] - bbox[:, :2]) / 2), bbox[:, 2:] - bbox[:, :2]], axis=-1)
227 | return bbox_xcycwh
228 |
229 |
230 | def np_xcycwh_to_xy_min_xy_max(bbox: np.array) -> np.array:
231 | """
232 | Convert bbox from shape [xc, yc, w, h] to [xmin, ymin, xmax, ymax]
233 | Args:
234 | bbox A (np.array) list of bbox (n, 4) with n the number of bbox to convert
235 | Returns:
236 | The converted bbox
237 | """
238 | # convert the bbox from [xc, yc, w, h] to [xmin, ymin, xmax, ymax].
239 | bbox_xy = np.concatenate([bbox[:, :2] - (bbox[:, 2:] / 2), bbox[:, :2] + (bbox[:, 2:] / 2)], axis=-1)
240 | return bbox_xy
241 |
242 |
243 |
244 | def np_yx_min_yx_max_to_xy_min_xy_max(bbox: np.array) -> np.array:
245 | """
246 | Convert bbox from shape [ymin, xmin, ymax, xmax] to [xmin, ymin, xmax, ymax]
247 | Args:
248 | bbox A (np.array) list a bbox (n, 4) with n the number of bbox to convert
249 | Returns:
250 | The converted bbox
251 | """
252 | return np.concatenate([
253 | bbox[:,1:2],
254 | bbox[:,0:1],
255 | bbox[:,3:4],
256 | bbox[:,2:3]
257 | ], axis=-1)
258 |
259 |
260 |
261 | def np_rescale_bbox_xcycwh(bbox_xcycwh: np.array, img_size: tuple):
262 | """
263 | Rescale a list of bbox to the image size
264 | @bbox_xcycwh: [[xc, yc, w, h], ...]
265 | @img_size (height, width)
266 | """
267 | bbox_xcycwh = np.array(bbox_xcycwh) # Be sure to work with a numpy array
268 | scale = np.array([img_size[1], img_size[0], img_size[1], img_size[0]])
269 | bbox_xcycwh_rescaled = bbox_xcycwh * scale
270 | return bbox_xcycwh_rescaled
271 |
272 |
273 | def np_rescale_bbox_yx_min_yx_max(bbox_xcycwh: np.array, img_size: tuple):
274 | """
275 | Rescale a list of bbox to the image size
276 | @bbox_xcycwh: [[y_min, x_min, y_max, x_max], ...]
277 | @img_size (height, width)
278 | """
279 | bbox_xcycwh = np.array(bbox_xcycwh) # Be sure to work with a numpy array
280 | scale = np.array([img_size[0], img_size[1], img_size[0], img_size[1]])
281 | bbox_xcycwh_rescaled = bbox_xcycwh * scale
282 | return bbox_xcycwh_rescaled
283 |
284 |
285 | def np_rescale_bbox_xy_min_xy_max(bbox: np.array, img_size: tuple):
286 | """
287 | Rescale a list of bbox to the image size
288 | @bbox: [[x_min, y_min, x_max, y_max], ...]
289 | @img_size (height, width)
290 | """
291 | bbox = np.array(bbox) # Be sure to work with a numpy array
292 | scale = np.array([img_size[1], img_size[0], img_size[1], img_size[0]])
293 | bbox_rescaled = bbox * scale
294 | return bbox_rescaled
295 |
296 |
--------------------------------------------------------------------------------
/detr_tf/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .coco import load_coco_dataset, COCO_CLASS_NAME
2 | from .voc import load_voc_dataset
3 | from .tfcsv import load_tfcsv_dataset
4 | #import processing
5 | #import transformation
--------------------------------------------------------------------------------
/detr_tf/data/coco.py:
--------------------------------------------------------------------------------
1 | from pycocotools.coco import COCO
2 | import tensorflow as tf
3 | import numpy as np
4 | import imageio
5 | from skimage.color import gray2rgb
6 | from random import sample, shuffle
7 | import os
8 |
9 | from . import transformation
10 | from . import processing
11 | import matplotlib.pyplot as plt
12 |
13 | COCO_CLASS_NAME = [
14 | 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
15 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
16 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
17 | 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
18 | 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
19 | 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
20 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
21 | 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
22 | 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
23 | 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
24 | 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
25 | 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
26 | 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
27 | 'toothbrush', "back"
28 | ]
29 |
30 | def get_coco_labels(coco, img_id, image_shape, augmentation):
31 | # Load the instance annotations
32 | ann_ids = coco.getAnnIds(imgIds=img_id)
33 | anns = coco.loadAnns(ann_ids)
34 | # Setup bbox
35 | bbox = []
36 | t_class = []
37 | crowd_bbox = 0
38 | for a, ann in enumerate(anns):
39 | bbox_x, bbox_y, bbox_w, bbox_h = ann['bbox']
40 | # target class
41 | t_cls = ann["category_id"]
42 | if ann["iscrowd"]:
43 | crowd_bbox = 1
44 | # Convert bbox to xc, yc, w, h format
45 | x_center = bbox_x + (bbox_w / 2)
46 | y_center = bbox_y + (bbox_h / 2)
47 | x_center = x_center / float(image_shape[1])
48 | y_center = y_center / float(image_shape[0])
49 | bbox_w = bbox_w / float(image_shape[1])
50 | bbox_h = bbox_h / float(image_shape[0])
51 | # Add bbox and class
52 | bbox.append([x_center, y_center, bbox_w, bbox_h])
53 | t_class.append([t_cls])
54 | # Set bbox header
55 | bbox = np.array(bbox)
56 | t_class = np.array(t_class)
57 | return bbox.astype(np.float32), t_class.astype(np.int32), crowd_bbox
58 |
59 |
60 | def get_coco_from_id(coco_id, coco, augmentation, config, img_dir):
61 | # Load the image metadata
62 | img = coco.loadImgs([coco_id])[0]
63 | # Load the image
64 | #data_type = "train2017" if train_val == "train" else "val2017"
65 | file_name = img['file_name']
66 | image_path = os.path.join(img_dir, file_name)
67 | image = imageio.imread(image_path)
68 | # Grayscale to RGB if needed
69 | if len(image.shape) == 2: image = gray2rgb(image)
70 | # Retrieve the image label
71 | t_bbox, t_class, is_crowd = get_coco_labels(coco, img['id'], image.shape, augmentation)
72 | # Apply augmentations
73 | if len(t_bbox) > 0 and augmentation is not None:
74 | image, t_bbox, t_class = transformation.detr_transform(image, t_bbox, t_class, config, augmentation)
75 | # Normalized images
76 | image = processing.normalized_images(image, config)
77 | # Set type for tensorflow
78 | image = image.astype(np.float32)
79 | t_bbox = t_bbox.astype(np.float32)
80 | t_class = t_class.astype(np.int64)
81 | is_crowd = np.array(is_crowd, dtype=np.int64)
82 | return image, t_bbox, t_class, is_crowd
83 |
84 |
85 | def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None):
86 | """ Load a coco dataset
87 | """
88 | ann_dir = config.data.ann_dir if ann_dir is None else ann_dir
89 | ann_file = config.data.ann_file if ann_file is None else ann_file
90 | img_dir = config.data.img_dir if img_dir is None else img_dir
91 |
92 |
93 |
94 | coco = COCO(ann_file)
95 |
96 | # Extract CLASS names
97 | cats = coco.loadCats(coco.getCatIds())
98 | # Get the max class ID
99 | max_id = np.array([cat["id"] for cat in cats]).max()
100 | class_names = ["N/A"] * (max_id + 2) # + 2 for the background class
101 | # Add the background class at the end
102 | class_names[-1] = "back"
103 | config.background_class = max_id + 1
104 | for cat in cats:
105 | class_names[cat["id"]] = cat["name"]
106 |
107 | # Setup the data pipeline
108 | img_ids = coco.getImgIds()
109 | shuffle(img_ids)
110 | dataset = tf.data.Dataset.from_tensor_slices(img_ids)
111 | # Shuffle the dataset
112 | dataset = dataset.shuffle(1000)
113 | # Retrieve img and labels
114 | outputs_types=(tf.float32, tf.float32, tf.int64, tf.int64)
115 | dataset = dataset.map(lambda idx: processing.numpy_fc(
116 | idx, get_coco_from_id, outputs_types=outputs_types, coco=coco, augmentation=augmentation, config=config, img_dir=img_dir)
117 | , num_parallel_calls=tf.data.experimental.AUTOTUNE)
118 | dataset = dataset.filter(lambda imgs, tbbox, tclass, iscrowd: tf.logical_and(tf.shape(tbbox)[0] > 0, iscrowd != 1))
119 | dataset = dataset.map(lambda imgs, tbbox, tclass, iscrowd: (imgs, tbbox, tclass), num_parallel_calls=tf.data.experimental.AUTOTUNE)
120 |
121 | # Pad bbox and labels
122 | dataset = dataset.map(processing.pad_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
123 |
124 | dataset = dataset.batch(batch_size, drop_remainder=True)
125 | dataset = dataset.prefetch(32)
126 |
127 | return dataset, class_names
--------------------------------------------------------------------------------
/detr_tf/data/processing.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 |
6 | def normalized_images(image, config):
7 | """ Normalized images. torch_resnet is used on finetuning
8 | since the weights are based on the original paper training code
9 | from pytorch. tf_resnet is used when training from scratch with a
10 | resnet50 traine don tensorflow.
11 | """
12 | if config.normalized_method == "torch_resnet":
13 | channel_avg = np.array([0.485, 0.456, 0.406])
14 | channel_std = np.array([0.229, 0.224, 0.225])
15 | image = (image / 255.0 - channel_avg) / channel_std
16 | return image.astype(np.float32)
17 | elif config.normalized_method == "tf_resnet":
18 | mean = [103.939, 116.779, 123.68]
19 | image = image[..., ::-1]
20 | image = image - mean
21 | return image.astype(np.float32)
22 | else:
23 | raise Exception("Can't handler thid normalized method")
24 |
25 |
26 | def numpy_fc(idx, fc, outputs_types=(tf.float32, tf.float32, tf.int64), **params):
27 | """
28 | Call a numpy function on each given ID (`idx`) and load the associated image and labels (bbox and cls)
29 | """
30 | def _np_function(_idx):
31 | return fc(_idx, **params)
32 | return tf.numpy_function(_np_function, [idx], outputs_types)
33 |
34 |
35 | def pad_labels(images: tf.Tensor, t_bbox: tf.Tensor, t_class: tf.Tensor):
36 | """ Pad the bbox by adding [0, 0, 0, 0] at the end
37 | and one header to indicate how maby bbox are set.
38 | Do the same with the labels.
39 | """
40 | nb_bbox = tf.shape(t_bbox)[0]
41 |
42 | bbox_header = tf.expand_dims(nb_bbox, axis=0)
43 | bbox_header = tf.expand_dims(bbox_header, axis=0)
44 | bbox_header = tf.pad(bbox_header, [[0, 0], [0, 3]])
45 | bbox_header = tf.cast(bbox_header, tf.float32)
46 | cls_header = tf.constant([[0]], dtype=tf.int64)
47 |
48 | # Pad bbox and class
49 | t_bbox = tf.pad(t_bbox, [[0, 100 - 1 - nb_bbox], [0, 0]], mode='CONSTANT', constant_values=0)
50 | t_class = tf.pad(t_class, [[0, 100 - 1 - nb_bbox], [0, 0]], mode='CONSTANT', constant_values=0)
51 |
52 | t_bbox = tf.concat([bbox_header, t_bbox], axis=0)
53 | t_class = tf.concat([cls_header, t_class], axis=0)
54 |
55 | return images, t_bbox, t_class
56 |
57 |
--------------------------------------------------------------------------------
/detr_tf/data/tfcsv.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from random import shuffle
3 | import pandas as pd
4 | import numpy as np
5 | import imageio
6 | import os
7 |
8 | from . import processing
9 | from .transformation import detr_transform
10 | from .. import bbox
11 |
12 | def load_data_from_index(index, class_names, filenames, anns, config, augmentation, img_dir):
13 | # Open the image
14 | image = imageio.imread(os.path.join(config.data.data_dir, img_dir, filenames[index]))
15 | # Select all the annotations (bbox and class) on this image
16 | image_anns = anns[anns["filename"] == filenames[index]]
17 |
18 | # Convert all string class to number (the target class)
19 | t_class = image_anns["class"].map(lambda x: class_names.index(x)).to_numpy()
20 | # Select the width & height of each annotation row (they should all be the same since they belong to the same image)
21 | width = image_anns["width"].to_numpy()
22 | height = image_anns["height"].to_numpy()
23 | # Select the xmin, ymin, xmax and ymax of each bbox, then normalize the bbox to be between 0 and 1
24 | # Finally, convert the bbox from xmin,ymin,xmax,ymax to x_center,y_center,width,height
25 | bbox_list = image_anns[["xmin", "ymin", "xmax", "ymax"]].to_numpy()
26 | bbox_list = bbox_list / [width[0], height[0], width[0], height[0]]
27 | t_bbox = bbox.xy_min_xy_max_to_xcycwh(bbox_list)
28 |
29 | # Transform and augment image with bbox and class if needed
30 | image, t_bbox, t_class = detr_transform(image, t_bbox, t_class, config, augmentation=augmentation)
31 |
32 | # Normalized image
33 | image = processing.normalized_images(image, config)
34 |
35 | return image.astype(np.float32), t_bbox.astype(np.float32), np.expand_dims(t_class, axis=-1).astype(np.int64)
36 |
37 |
38 | def load_tfcsv_dataset(config, batch_size, augmentation=False, exclude=[], ann_dir=None, ann_file=None, img_dir=None):
39 | """ Load the hardhat dataset
40 | """
41 | ann_dir = config.data.ann_dir if ann_dir is None else ann_dir
42 | ann_file = config.data.ann_file if ann_file is None else ann_file
43 | img_dir = config.data.img_dir if img_dir is None else img_dir
44 |
45 | anns = pd.read_csv(os.path.join(config.data.data_dir, ann_file))
46 | for name in exclude:
47 | anns = anns[anns["class"] != name]
48 |
49 | unique_class = anns["class"].unique()
50 | unique_class.sort()
51 |
52 |
53 | # Set the background class to 0
54 | config.background_class = 0
55 | class_names = ["background"] + unique_class.tolist()
56 |
57 |
58 | filenames = anns["filename"].unique().tolist()
59 | indexes = list(range(0, len(filenames)))
60 | shuffle(indexes)
61 |
62 | dataset = tf.data.Dataset.from_tensor_slices(indexes)
63 | dataset = dataset.map(lambda idx: processing.numpy_fc(
64 | idx, load_data_from_index,
65 | class_names=class_names, filenames=filenames, anns=anns, config=config, augmentation=augmentation, img_dir=img_dir)
66 | ,num_parallel_calls=tf.data.experimental.AUTOTUNE)
67 |
68 |
69 | # Filter labels to be sure to keep only sample with at least one bbox
70 | dataset = dataset.filter(lambda imgs, tbbox, tclass: tf.shape(tbbox)[0] > 0)
71 | # Pad bbox and labels
72 | dataset = dataset.map(processing.pad_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
73 | # Batch images
74 | dataset = dataset.batch(batch_size, drop_remainder=True)
75 |
76 | return dataset, class_names
77 |
78 |
--------------------------------------------------------------------------------
/detr_tf/data/transformation.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import imgaug as ia
3 | import imgaug.augmenters as iaa
4 | import numpy as np
5 |
6 | from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
7 | from imgaug.augmentables.segmaps import SegmentationMapsOnImage
8 |
9 | import tensorflow as tf
10 |
11 | def bbox_xcyc_wh_to_imgaug_bbox(bbox, target_class, height, width):
12 |
13 | img_aug_bbox = []
14 |
15 | for b in range(0, len(bbox)):
16 | bbox_xcyc_wh = bbox[b]
17 | # Convert from normalized [0, 1] coordinates to pixel coordinates
18 | bbox_xcyc_wh = [
19 | bbox_xcyc_wh[0] * width,
20 | bbox_xcyc_wh[1] * height,
21 | bbox_xcyc_wh[2] * width,
22 | bbox_xcyc_wh[3] * height
23 | ]
24 | x1 = bbox_xcyc_wh[0] - (bbox_xcyc_wh[2] / 2)
25 | x2 = bbox_xcyc_wh[0] + (bbox_xcyc_wh[2] / 2)
26 | y1 = bbox_xcyc_wh[1] - (bbox_xcyc_wh[3] / 2)
27 | y2 = bbox_xcyc_wh[1] + (bbox_xcyc_wh[3] / 2)
28 |
29 | n_bbox = BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=target_class[b])
30 |
31 | img_aug_bbox.append(n_bbox)
32 |
33 | return img_aug_bbox
34 |
35 |
36 | def prepare_aug_inputs(image, bbox, t_class):
37 |
38 | images_batch = []
39 | bbox_batch = []
40 |
41 | images_batch.append(image)
42 |
43 | # Create the Imgaug bbox
44 | bbs_original = bbox_xcyc_wh_to_imgaug_bbox(bbox, t_class, image.shape[0], image.shape[1])
45 | bbs_original = BoundingBoxesOnImage(bbs_original, shape=image.shape)
46 | bbox_batch.append(bbs_original)
47 |
48 | for i in range(len(images_batch)):
49 | images_batch[i] = images_batch[i].astype(np.uint8)
50 |
51 | return images_batch, bbox_batch
52 |
53 |
54 | def detr_aug_seq(image, config, augmentation):
55 |
56 |
57 | sometimes = lambda aug: iaa.Sometimes(0.5, aug)
58 |
59 | target_min_side_size = 480
60 |
61 | # According to the paper
62 | min_side_min = 480
63 | min_side_max = 800
64 | max_side_max = 1333
65 |
66 | image_size = config.image_size
67 | if augmentation:
68 |
69 | seq = iaa.Sequential([
70 | iaa.Fliplr(0.5), # horizontal flips
71 | sometimes(iaa.OneOf([
72 | # Resize complety the image
73 | iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL),
74 | # Crop into the image
75 | iaa.CropToFixedSize(image_size[1], image_size[0]),
76 | # Affine transform
77 | iaa.Affine(
78 | scale={"x": (0.5, 1.5), "y": (0.5, 1.5)},
79 | )
80 | ])),
81 | # Be sure to resize to the target image size
82 | iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL)
83 | ], random_order=False) # apply augmenters in the defined order
84 |
85 | return seq
86 |
87 | else:
88 |
89 | seq = iaa.Sequential([
90 | # Be sure to resize to the target image size
91 | iaa.Resize({"width": image_size[1], "height": image_size[0]})
92 | ], random_order=False) # apply augmenters in the defined order
93 |
94 | return seq
95 |
96 | """ Mode paper evaluation
97 | # Evaluation mode, we took the largest min side the model is trained on
98 | target_min_side_size = 480
99 | image_min_side = min(float(image.shape[0]), float(image.shape[1]))
100 | image_max_side = max(float(image.shape[0]), float(image.shape[1]))
101 |
102 | min_side_scaling = target_min_side_size / image_min_side
103 | max_side_scaling = max_side_max / image_max_side
104 | scaling = min(min_side_scaling, max_side_scaling)
105 |
106 | n_height = int(scaling * image.shape[0])
107 | n_width = int(scaling * image.shape[1])
108 |
109 | seq = iaa.Sequential([
110 | iaa.Resize({"height": n_height, "width": n_width}),
111 | ])
112 | """
113 |
114 | return seq
115 |
116 |
117 | def imgaug_bbox_to_xcyc_wh(bbs_aug, height, width):
118 |
119 | bbox_xcyc_wh = []
120 | t_class = []
121 |
122 | nb_bbox = 0
123 |
124 | for b, bbox in enumerate(bbs_aug):
125 |
126 | h = bbox.y2 - bbox.y1
127 | w = bbox.x2 - bbox.x1
128 | xc = bbox.x1 + (w/2)
129 | yc = bbox.y1 + (h/2)
130 |
131 | assert bbox.label is not None
132 |
133 | bbox_xcyc_wh.append([xc / width, yc / height, w / width, h / height])
134 | t_class.append(bbox.label)
135 |
136 | nb_bbox += 1
137 |
138 | #bbox_xcyc_wh[0][0] = nb_bbox
139 | bbox_xcyc_wh = np.array(bbox_xcyc_wh)
140 |
141 | return bbox_xcyc_wh, t_class
142 |
143 |
144 | def retrieve_outputs(augmented_images, augmented_bbox):
145 |
146 | outputs_dict = {}
147 | image_shape = None
148 |
149 |
150 | # We expect only one image here for now
151 | image = augmented_images[0].astype(np.float32)
152 | augmented_bbox = augmented_bbox[0]
153 |
154 | bbox, t_class = imgaug_bbox_to_xcyc_wh(augmented_bbox, image.shape[0], image.shape[1])
155 |
156 | bbox = np.array(bbox)
157 | t_class = np.array(t_class)
158 |
159 | return image, bbox, t_class
160 |
161 |
162 |
163 | def detr_transform(image, bbox, t_class, config, augmentation):
164 |
165 |
166 | images_batch, bbox_batch = prepare_aug_inputs(image, bbox, t_class)
167 |
168 |
169 | seq = detr_aug_seq(image, config, augmentation)
170 |
171 | # Run the pipeline in a deterministic manner
172 | seq_det = seq.to_deterministic()
173 |
174 | augmented_images = []
175 | augmented_bbox = []
176 | augmented_class = []
177 |
178 | for img, bbox, t_cls in zip(images_batch, bbox_batch, t_class):
179 |
180 | img_aug = seq_det.augment_image(img)
181 | bbox_aug = seq_det.augment_bounding_boxes(bbox)
182 |
183 |
184 | for b, bbox_instance in enumerate(bbox_aug.items):
185 | setattr(bbox_instance, "instance_id", b+1)
186 |
187 | bbox_aug = bbox_aug.remove_out_of_image_fraction(0.7)
188 | segmap_aug = None
189 | bbox_aug = bbox_aug.clip_out_of_image()
190 |
191 |
192 | augmented_images.append(img_aug)
193 | augmented_bbox.append(bbox_aug)
194 |
195 | return retrieve_outputs(augmented_images, augmented_bbox)
196 |
197 |
198 |
--------------------------------------------------------------------------------
/detr_tf/data/voc.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import xml.etree.ElementTree as ET
3 | from random import sample, shuffle
4 | import numpy as np
5 | import imageio
6 | import numpy as np
7 | import os
8 |
9 |
10 | from . import processing
11 | from . import transformation
12 |
13 |
14 | VOC_CLASS_NAME = [
15 | "back", 'aeroplane', 'bicycle', 'bird', 'boat',
16 | 'bottle', 'bus', 'car', 'cat', 'chair',
17 | 'cow', 'diningtable', 'dog', 'horse',
18 | 'motorbike', 'person', 'pottedplant',
19 | 'sheep', 'sofa', 'train', 'tvmonitor'
20 | ]
21 |
22 | def load_voc_labels(img_id, class_names, voc_dir, augmentation, config):
23 |
24 | anno_path = os.path.join(voc_dir, config.data.ann_dir, img_id + '.xml')
25 | objects = ET.parse(anno_path).findall('object')
26 | size = ET.parse(anno_path).find('size')
27 | width = float(size.find("width").text)
28 | height = float(size.find("height").text)
29 |
30 | # Set up bbox with headers
31 | t_bbox = []
32 | t_class = []
33 |
34 | for obj in objects:
35 | # Open bbox and retrieve info
36 | name = obj.find('name').text.lower().strip()
37 | bndbox = obj.find('bndbox')
38 | xmin = (float(bndbox.find('xmin').text) - 1) / width
39 | ymin = (float(bndbox.find('ymin').text) - 1) / height
40 | xmax = (float(bndbox.find('xmax').text) - 1) / width
41 | ymax = (float(bndbox.find('ymax').text) - 1) / height
42 | # Convert bbox to xc, yc center
43 | xc = xmin + ((xmax - xmin) / 2)
44 | yc = ymin + ((ymax - ymin) / 2)
45 | w = xmax - xmin
46 | h = ymax - ymin
47 | # Add bbox
48 | t_bbox.append([xc, yc, w, h])
49 | # Add target class
50 | t_class.append([class_names.index(name)])
51 |
52 | t_bbox = np.array(t_bbox)
53 | t_class = np.array(t_class)
54 |
55 | return t_bbox, t_class
56 |
57 |
58 | def load_voc_from_id(img_id, class_names, voc_dir, augmentation, config, img_dir):
59 | img_id = str(img_id.decode())
60 | # Load image
61 | img_path = os.path.join(voc_dir, config.data.img_dir, img_id + '.jpg')
62 | image = imageio.imread(img_path)
63 | # Load labels
64 | t_bbox, t_class = load_voc_labels(img_id, class_names, voc_dir, augmentation, config)
65 | # Apply augmentations
66 | if augmentation is not None:
67 | image, t_bbox, t_class = transformation.detr_transform(image, t_bbox, t_class, config, augmentation)
68 | # Normalized images
69 | image = processing.normalized_images(image, config)
70 | # Set type for tensorflow
71 | image = image.astype(np.float32)
72 | t_bbox = t_bbox.astype(np.float32)
73 | t_class = t_class.astype(np.int64)
74 |
75 |
76 | return (image, t_bbox, t_class)
77 |
78 |
79 | def load_voc_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None):
80 | """
81 | """
82 | ann_dir = config.data.ann_dir if ann_dir is None else ann_dir
83 | ann_file = config.data.ann_file if ann_file is None else ann_file
84 | img_dir = config.data.img_dir if img_dir is None else img_dir
85 |
86 | # Set the background class to 0
87 | config.background_class = 0
88 |
89 | image_dir = os.path.join(config.data.data_dir, img_dir)
90 | anno_dir = os.path.join(config.data.data_dir, ann_dir)
91 | # ids lists
92 | ids = list(map(lambda x: x[:-4], os.listdir(image_dir)))
93 |
94 | # Retrieve the class names in the dataset
95 | class_names = ['back']
96 | for img_id in ids:
97 | anno_path = os.path.join(config.data.data_dir, anno_dir, img_id + '.xml')
98 | for obj in ET.parse(anno_path).findall('object'):
99 | # Open bbox and retrieve info
100 | name = obj.find('name').text.lower().strip()
101 | # Add the class name if it has not been seen yet
102 | if name not in class_names:
103 | class_names.append(name)
106 |
107 | ids = list(map(lambda x: x[:-4], os.listdir(image_dir)))
108 |
109 | #ids = ids[:int(len(ids) * 0.75)] if train_val == "train" else ids[int(len(ids) * 0.75):]
110 | # Shuffle all the dataset
111 | shuffle(ids)
112 |
113 | # Setup data pipeline
114 | dataset = tf.data.Dataset.from_tensor_slices(ids)
115 | dataset = dataset.shuffle(1000)
116 | # Retrieve img and labels
117 | dataset = dataset.map(lambda idx: processing.numpy_fc(idx, load_voc_from_id,
118 | class_names=class_names, voc_dir=config.data.data_dir, augmentation=augmentation, config=config, img_dir=img_dir)
119 | , num_parallel_calls=tf.data.experimental.AUTOTUNE)
120 | # Filter labels to be sure to keep only sample with at least one bbox
121 | dataset = dataset.filter(lambda imgs, tbbox, tclass: tf.shape(tbbox)[0] > 0)
122 | # Pad bbox and labels
123 | dataset = dataset.map(processing.pad_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
124 | # Batch images
125 | dataset = dataset.batch(batch_size, drop_remainder=True)
126 | # Prefetch
127 | dataset = dataset.prefetch(32)
128 | return dataset, class_names
--------------------------------------------------------------------------------
/detr_tf/inference.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | CLASS_COLOR_MAP = np.random.randint(0, 255, (100, 3))
7 |
8 | from detr_tf import bbox
9 |
10 | def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[], config=None):
11 | """ Numpy function used to display the bbox (target or prediction)
12 | """
13 | assert image.dtype == np.float32 and len(image.shape) == 3
14 |
15 | if config is not None and config.normalized_method == "torch_resnet":
16 | channel_avg = np.array([0.485, 0.456, 0.406])
17 | channel_std = np.array([0.229, 0.224, 0.225])
18 | image = (image * channel_std) + channel_avg
19 | image = (image*255).astype(np.uint8)
20 | elif config is not None and config.normalized_method == "tf_resnet":
21 | image = image + [103.939, 116.779, 123.68]  # undo the tf_resnet mean subtraction
22 | image = image[..., ::-1]
23 | image = image / 255
24 |
25 | bbox_xcycwh = bbox.np_rescale_bbox_xcycwh(bbox_list, (image.shape[0], image.shape[1]))
26 | bbox_x1y1x2y2 = bbox.np_xcycwh_to_xy_min_xy_max(bbox_xcycwh)
27 |
28 | # Set the labels if not defined
29 | if labels is None: labels = np.zeros((bbox_x1y1x2y2.shape[0]))
30 |
31 | bbox_area = []
32 | # Go through each bbox
33 | for b in range(0, bbox_x1y1x2y2.shape[0]):
34 | x1, y1, x2, y2 = bbox_x1y1x2y2[b]
35 | bbox_area.append((x2-x1)*(y2-y1))
36 |
37 | # Go through each bbox
38 | for b in np.argsort(bbox_area)[::-1]:
39 | # Take a new color at random for this instance
40 | instance_color = np.random.randint(0, 255, (3))
41 |
42 |
43 | x1, y1, x2, y2 = bbox_x1y1x2y2[b]
44 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
45 | x1, y1, x2, y2 = max(0, x1), max(0, y1), min(image.shape[1], x2), min(image.shape[0], y2)
46 |
47 | # Select the class associated with this bbox
48 | class_id = labels[int(b)]
49 |
50 | if scores is not None and len(scores) > 0:
51 | label_name = class_name[int(class_id)]
52 | label_name = "%s:%.2f" % (label_name, scores[b])
53 | else:
54 | label_name = class_name[int(class_id)]
55 |
56 | class_color = CLASS_COLOR_MAP[int(class_id)]
57 |
58 | color = instance_color
59 |
60 | multiplier = image.shape[0] / 500
61 | cv2.rectangle(image, (x1, y1), (x1 + int(multiplier*15)*len(label_name), y1 + 20), class_color.tolist(), -10)
62 | cv2.putText(image, label_name, (x1+2, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6 * multiplier, (0, 0, 0), 1)
63 | cv2.rectangle(image, (x1, y1), (x2, y2), tuple(class_color.tolist()), 2)
64 |
65 | return image
66 |
67 |
68 | def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center"):
69 |
70 | predicted_bbox = m_outputs["pred_boxes"][0]
71 | predicted_labels = m_outputs["pred_logits"][0]
72 |
73 | softmax = tf.nn.softmax(predicted_labels)
74 | predicted_scores = tf.reduce_max(softmax, axis=-1)
75 | predicted_labels = tf.argmax(softmax, axis=-1)
76 |
77 |
78 | indices = tf.where(predicted_labels != background_class)
79 | indices = tf.squeeze(indices, axis=-1)
80 |
81 | predicted_scores = tf.gather(predicted_scores, indices)
82 | predicted_labels = tf.gather(predicted_labels, indices)
83 | predicted_bbox = tf.gather(predicted_bbox, indices)
84 |
85 |
86 | if bbox_format == "xy_center":
87 | predicted_bbox = predicted_bbox
88 | elif bbox_format == "xyxy":
89 | predicted_bbox = bbox.xcycwh_to_xy_min_xy_max(predicted_bbox)
90 | elif bbox_format == "yxyx":
91 | predicted_bbox = bbox.xcycwh_to_yx_min_yx_max(predicted_bbox)
92 | else:
93 | raise NotImplementedError()
94 |
95 | return predicted_bbox, predicted_labels, predicted_scores
96 |
--------------------------------------------------------------------------------
/detr_tf/logger/training_logging.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is used to process the results of the training (loss, model outputs and targets)
3 | in order to send everything to wandb.
4 | """
5 | from typing import Union,Dict,Tuple
6 | import tensorflow as tf
7 |
8 | from ..loss.compute_map import cal_map, calc_map, APDataObject
9 | from .wandb_logging import WandbSender
10 | from ..inference import get_model_inference
11 |
12 |
13 | import numpy as np
14 | import cv2
15 |
16 | from ..import bbox
17 |
18 | if int(tf.__version__.split('.')[1]) >= 4:
19 | RAGGED = True
20 | else:
21 | RAGGED = False
22 |
23 |
24 | def tf_send_batch_log_to_wandb(images, target_bbox, target_class, m_outputs: dict, config, class_name=[], step=None, prefix=""):
25 |
26 | # Warning: In graph mode, this class is init only once. In eager mode, this class is init at each step.
27 | img_sender = WandbSender()
28 |
29 | predicted_bbox = m_outputs["pred_boxes"]
30 | for b in range(predicted_bbox.shape[0]):
31 | # Select within the batch the element at index b
32 | image = images[b]
33 |
34 | elem_m_outputs = {key:m_outputs[key][b:b+1] if (m_outputs[key] is not None and not isinstance(m_outputs[key], list)) else m_outputs[key] for key in m_outputs}
35 |
36 | # Target
37 | t_bbox, t_class = target_bbox[b], target_class[b]
38 |
39 | if not RAGGED:
40 | size = tf.cast(t_bbox[0][0], tf.int32)
41 | t_bbox = tf.slice(t_bbox, [1, 0], [size, 4])
42 | t_bbox = bbox.xcycwh_to_xy_min_xy_max(t_bbox)
43 | t_class = tf.slice(t_class, [1, 0], [size, -1])
44 | t_class = tf.squeeze(t_class, axis=-1)
45 |
46 | # Predictions
47 | predicted_bbox, predicted_labels, predicted_scores = get_model_inference(elem_m_outputs, config.background_class, bbox_format="xyxy")
48 |
49 | np_func_params = {
50 | "image": image, "p_bbox": np.array(predicted_bbox), "p_scores": np.array(predicted_scores), "t_bbox": np.array(t_bbox),
51 | "p_labels": np.array(predicted_labels), "t_labels": np.array(t_class), "class_name": class_name
52 | }
53 | img_sender.gather_inference(**np_func_params)
54 |
55 | img_sender.send(step=step, prefix=prefix)
56 |
57 |
58 |
59 | def compute_map_on_batch(images, target_bbox, target_class, m_outputs: dict, config, class_name=[], step=None, send=True, prefix=""):
60 | predicted_bbox = m_outputs["pred_boxes"]
61 | batch_size = predicted_bbox.shape[0]
62 | for b in range(batch_size):
63 |
64 | image = images[b]
65 | elem_m_outputs = {key:m_outputs[key][b:b+1] if (m_outputs[key] is not None and not isinstance(m_outputs[key], list)) else m_outputs[key] for key in m_outputs}
66 |
67 | # Target
68 | t_bbox, t_class = target_bbox[b], target_class[b]
69 |
70 | if not RAGGED:
71 | size = tf.cast(t_bbox[0][0], tf.int32)
72 | t_bbox = tf.slice(t_bbox, [1, 0], [size, 4])
73 | t_bbox = bbox.xcycwh_to_yx_min_yx_max(t_bbox)
74 | t_class = tf.slice(t_class, [1, 0], [size, -1])
75 | t_class = tf.squeeze(t_class, axis=-1)
76 |
77 | # Inference ops
78 | predicted_bbox, predicted_labels, predicted_scores = get_model_inference(elem_m_outputs, config.background_class, bbox_format="yxyx")
79 | pred_mask = None
80 |
81 | pred_mask = np.zeros((138, 138, len(predicted_bbox)))
82 | target_mask = np.zeros((138, 138, len(t_bbox)))
83 | WandbSender.compute_map(
84 | np.array(predicted_bbox),
85 | np.array(predicted_labels), np.array(predicted_scores),
86 | np.array(t_bbox),
87 | np.array(t_class),
88 | b, batch_size, prefix, step, send, pred_mask, target_mask)
89 |
90 |
91 |
92 | def train_log(images, t_bbox, t_class, m_outputs: dict, config, step, class_name=[], prefix="train/"):
93 | # Every 100 steps, log some progress of the training
94 | # (images with bbox overlays and image logs)
95 | if step % 100 == 0:
96 | tf_send_batch_log_to_wandb(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=step, prefix=prefix)
97 |
98 |
99 | def valid_log(images, t_bbox, t_class, m_outputs: dict, config, step, global_step, class_name=[], evaluation_step=200, prefix="train/"):
100 |
101 | # Set the number of class
102 | WandbSender.init_ap_data(nb_class=len(class_name))
103 | map_list = compute_map_on_batch(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=global_step, send=(step+1==evaluation_step), prefix="val/")
104 |
105 | if step == 0:
106 | tf_send_batch_log_to_wandb(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=global_step, prefix="val/")
107 |
--------------------------------------------------------------------------------
/detr_tf/logger/wandb_logging.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is used to send training logs to Wandb.
3 | """
4 | from typing import Union,Dict,Tuple
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 | try:
9 | # Should be optional
10 | import wandb
11 | except:
12 | wandb = None
13 |
14 | import cv2
15 |
16 | from ..loss.compute_map import cal_map, calc_map, APDataObject
17 |
18 | class WandbSender(object):
19 | """
20 | Class used to send data to Wandb to
21 | log experiments.
22 | """
23 |
24 | IOU_THRESHOLDS = [x / 100. for x in range(50, 100, 5)]
25 | AP_DATA = None
26 | NB_CLASS = None
27 |
28 | def __init__(self):
29 | self.init_buffer()
30 |
31 | @staticmethod
32 | def init_ap_data(nb_class=None):
33 | """ Init the ap data used to compute the Map metrics.
34 | If nb_class is not provided, use the last provided nb_class.
35 | """
36 | if nb_class is not None:
37 | WandbSender.NB_CLASS = nb_class
38 |
39 | if WandbSender.NB_CLASS is None:
40 | raise ValueError("NB_CLASS is not sed in WandbSender")
41 |
42 | if WandbSender.AP_DATA is None:
43 | WandbSender.AP_DATA = {
44 | 'box' : [[APDataObject() for _ in [f"class_{i}" for i in range(WandbSender.NB_CLASS)]] for _ in [x / 100. for x in range(50, 100, 5)]],
45 | 'mask': [[APDataObject() for _ in [f"class_{i}" for i in range(WandbSender.NB_CLASS)]] for _ in [x / 100. for x in range(50, 100, 5)]]
46 | }
47 |
48 |
49 | def init_buffer(self):
50 | """ Init list used to store the information from a batch of data.
51 | Onced the list is filled, the send method
52 | send all images online.
53 | """
54 | self.images = []
55 | self.queries = []
56 | self.images_mask_ground_truth = []
57 | self.images_mask_prediction = []
58 | self.p_labels_batch = []
59 | self.t_labels_batch = []
60 | self.batch_mAP = []
61 |
62 |
63 | @staticmethod
64 | @tf.autograph.experimental.do_not_convert()
65 | def compute_map(p_bbox: np.array, p_labels: np.array, p_scores: np.array, t_bbox: np.array, t_labels: np.array,
66 | b: int, batch: int, prefix: str, step: int, send: bool,
67 | p_mask: np.array, t_mask: np.array):
68 | """
69 | For some reason, autograph tries to convert this method and fails.
70 | Thus, @tf.autograph.experimental.do_not_convert() is used to prevent autograph from scanning this method.
71 |
72 | Args:
73 | p_bbox/t_bbox: List of bbox (n, 4) [y1, x2, y2, x2]
74 | p_labels/t_labels: List of labels index (n)
75 | p_mask/t_mask: predicted/target mask (n, h, w) with h and w the size of the mask
76 | p_scores: List of predicted scores (n)
77 | b: Batch indice
78 | batch: size of a batch
79 | prefix: Prefix to use to log something on wandb
80 | step: Step number
81 | send: Whether to send the result of all computed map to wandb.
82 | """
83 |
84 | # Init Ap Data
85 | if WandbSender.AP_DATA is None:
86 | WandbSender.init_ap_data()
87 |
88 |         # Set fake class names (we do not really need the real name of each class at this point)
89 | class_name = [f"class_{i}" for i in range(WandbSender.NB_CLASS)]
90 |
91 | try:
92 |             # Compute the AP data for this batch element
93 | cal_map(p_bbox, p_labels, p_scores, p_mask, t_bbox, t_labels, t_mask, WandbSender.AP_DATA, WandbSender.IOU_THRESHOLDS)
94 |
95 | # Last element of the validation set.
96 |
97 | if send and b + 1 == batch:
98 |
99 | all_maps = calc_map(WandbSender.AP_DATA, WandbSender.IOU_THRESHOLDS, class_name, print_result=True)
100 | wandb.log({
101 | f"val/map50_bbox": all_maps["box"][50],
102 | f"val/map50_mask": all_maps["mask"][50],
103 | f"val/map_bbox": all_maps["box"]["all"],
104 | f"val/map_mask": all_maps["mask"]["all"]
105 | }, step=step)
106 | wandb.run.summary.update({
107 | f"val/map50_bbox": all_maps["box"][50],
108 | f"val/map50_mask": all_maps["mask"][50],
109 | f"val/map_bbox": all_maps["box"]["all"],
110 | f"val/map_mask": all_maps["mask"]["all"]
111 | })
112 |
113 |
114 | WandbSender.AP_DATA = None
115 | WandbSender.init_ap_data()
116 |
117 | return np.array([0.0, 0.0], np.float64)
118 |
119 | except Exception as e:
120 | print("compute_map error. e=", e)
121 | #raise e
122 | return np.array([0.0, 0.0], np.float64)
123 | return np.array([0.0, 0.0], np.float64)
124 |
125 |
126 | def get_wandb_bbox_mask_image(self, image: np.array, bbox: np.array, labels : np.array, masks=None, scores=None, class_name=[]) -> Tuple[list, np.array]:
127 | """
128 |         Serialize the model inference into a list of box dicts and an image ready to be sent to wandb.
129 |         Args:
130 |             image: (h, w, 3) image, normalized between 0 and 1
131 |             bbox: List of bbox (n, 4) [x1, y1, x2, y2], normalized
132 |             labels: List of label indices (n)
133 |             masks: predicted/target masks (h, w, n) with h and w the size of the mask
134 |             scores: List of predicted scores (n) (Optional)
135 |             class_name: List of class names, indexed by label
136 |         Return:
137 |             A list with the box data for wandb
138 |             and a copy of the original image with the instance masks drawn on top
139 | """
140 | height, width = image.shape[0], image.shape[1]
141 | image_mask = np.copy(image)
142 | instance_id = 1
143 | box_data = []
144 |
145 | for b in range(len(bbox)):
146 | # Sample a new color for the mask instance
147 | instance_color = np.random.uniform(0, 1, (3))
148 |             # Retrieve bbox coordinates
149 | x1, y1, x2, y2 = bbox[b]
150 | x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
151 |
152 | # Fill the mask
153 | if masks is not None:
154 | mask = masks[:,:,b]
155 | mask = cv2.resize(mask, (width, height))
156 | mask = mask[int(y1*height):int(y2*height),int(x1*width):int(x2*width)]
157 | image_mask[int(y1*height):int(y2*height),int(x1*width):int(x2*width)][mask > 0.5] = 0.5*image[int(y1*height):int(y2*height),int(x1*width):int(x2*width)][mask > 0.5] + 0.5*instance_color
158 |
159 | image_mask = cv2.rectangle(image_mask, (int(x1*width), int(y1*height)), (int(x2*width), int(y2*height)), (1, 1, 0), 3)
160 |
161 |
162 | #if scores is None:
163 | box_caption = "%s" % (class_name[int(labels[b])])
164 | #else:
165 | # box_caption = "%s-{:.2f}" % (class_name[int(labels[b])], float(scores[b]))
166 |
167 |
168 | box_dict = {
169 | "position": {"minX": x1, "maxX": x2, "minY": y1, "maxY": y2},
170 | "class_id" : int(labels[b]),
171 | "box_caption" : box_caption
172 | }
173 |             # Guard with b < len(scores): for some reason scores is sometimes shorter than bbox
174 | if scores is not None and b < len(scores):
175 | box_dict["scores"] = {"conf": float(scores[b])}
176 | #print("append", box_dict)
177 | box_data.append(box_dict)
178 | instance_id += 1
179 |
180 | return box_data, image_mask
181 |
182 | def gather_inference(self, image: np.array, p_bbox: np.array, p_scores: np.array, t_bbox: np.array,
183 | p_labels: np.array, t_labels: np.array, p_masks=None, t_masks=None, class_name=[]):
184 | self.class_name = class_name
185 |
186 |         # This is what wandb expects to get as input to display images with bbox.
187 | boxes = {
188 | "ground_truth": {"box_data": []},
189 | "predictions": {"box_data": []}
190 | }
191 |
192 | # Ground Truth
193 |         box_data, _ = self.get_wandb_bbox_mask_image(image, t_bbox, t_labels, t_masks, class_name=class_name, scores=None)
194 | boxes["ground_truth"]["box_data"] = box_data
195 | boxes["ground_truth"]["class_labels"] = {_id:str(label) for _id, label in enumerate(class_name)}
196 |
197 | # Predictions
198 | box_data, _ = self.get_wandb_bbox_mask_image(image, p_bbox, p_labels, p_masks, class_name=class_name, scores=p_scores)
199 | boxes["predictions"]["box_data"] = box_data
200 | boxes["predictions"]["class_labels"] = {_id:str(label) for _id, label in enumerate(class_name)}
201 |
202 |
203 | # Append the target and the predictions to the buffer
204 | self.images.append(wandb.Image(image, boxes=boxes))
205 |
206 | return np.array(0, dtype=np.int64)
207 |
208 | @tf.autograph.experimental.do_not_convert()
209 | def send(self, step: tf.Tensor, prefix=""):
210 | """
211 |         Autograph tries (and fails) to convert this method, so
212 |         @tf.autograph.experimental.do_not_convert() is used to keep it from being traced.
213 |
214 | Send the buffer to wandb
215 | Args:
216 | step: The global training step as eager tensor
217 | prefix: Prefix used before each log name.
218 | """
219 | step = int(step)
220 |
221 | wandb.log({f"{prefix}Images bbox": self.images}, step=step)
222 |
223 | if len(self.batch_mAP) > 0:
224 | wandb.log({f"{prefix}mAp": np.mean(self.batch_mAP)}, step=step)
225 |
226 | self.init_buffer()
227 |
228 | return np.array(0, dtype=np.int64)
229 |
230 | @staticmethod
231 | @tf.autograph.experimental.do_not_convert()
232 | def send_depth(depth_map, step: np.array, prefix=""):
233 | """
234 |         Autograph tries (and fails) to convert this method, so
235 |         @tf.autograph.experimental.do_not_convert() is used to keep it from being traced.
236 |
237 | Send the depth map to wandb
238 | Args:
239 | depth_map: (8, h, w, 1) Depth images used to train the model
240 | step: The global training step as eager tensor
241 | prefix: Prefix used before each log name
242 | """
243 | step = int(step)
244 | depth_map_images = []
245 | for depth in depth_map:
246 | depth_map_images.append(wandb.Image(depth))
247 | wandb.log({f"{prefix}Depth map": depth_map_images}, step=step)
248 | return np.array(0, dtype=np.int64)
249 |
250 |
251 | @staticmethod
252 | @tf.autograph.experimental.do_not_convert()
253 | def send_proto_sample(proto_map: np.array, proto_sample: np.array, proto_targets: np.array, step: np.array, prefix=""):
254 | """
255 |         Autograph tries (and fails) to convert this method, so
256 |         @tf.autograph.experimental.do_not_convert() is used to keep it from being traced.
257 |
258 | Send the proto images logs to wandb.
259 | Args:
260 | proto_map: The k (32 by default) proto map of the proto network (h, w, k)
261 | proto_sample: Some generated mask from the network for a batch (n, h, w) with n the number of mask
262 | proto_targets: The target mask for each generated mask. (n, h, w) with n the number of mask
263 | step: The global training step as eager tensor
264 | prefix: Prefix used before each log name
265 | """
266 | step = int(step)
267 |
268 | proto_map_images = []
269 | proto_sample_images = []
270 | proto_targets_images = []
271 |
272 | for p in range(proto_map.shape[-1]):
273 | proto_map_images.append(wandb.Image(np.clip(proto_map[:,:,p]*100, 0, 255)))
274 | for p in range(len(proto_sample)):
275 | proto_sample_images.append(wandb.Image(proto_sample[p]))
276 | proto_targets_images.append(wandb.Image(proto_targets[p]))
277 |
278 | wandb.log({f"{prefix}Proto Map": proto_map_images}, step=step)
279 | wandb.log({f"{prefix}Instance segmentation prediction": proto_sample_images}, step=step)
280 | wandb.log({f"{prefix}Instance segmentation target": proto_targets_images}, step=step)
281 | return np.array(0, dtype=np.int64)
282 |
283 |
284 |
285 | @staticmethod
286 | @tf.autograph.experimental.do_not_convert()
287 | def send_images(images, step: np.array, name: str, captions=None, masks_prediction=None, masks_target=None):
288 | """
289 |         Autograph tries (and fails) to convert this method, so
290 |         @tf.autograph.experimental.do_not_convert() is used to keep it from being traced.
291 |
292 | Send some images to wandb
293 | Args:
294 | images: (8, h, w, c) Images to log in wandb
295 | step: The global training step as eager tensor
296 | name: Image names
297 | """
298 | class_labels = {
299 | 0: "background",
300 | 1: "0",
301 | 2: "1",
302 | 3: "2",
303 | 4: "3",
304 | 5: "4",
305 | 6: "5",
306 | 7: "6",
307 | 8: "7",
308 | 9: "8",
309 | 10: "9"
310 | }
311 |
312 | step = int(step)
313 | images_list = []
314 | for i, img in enumerate(images):
315 | img_params = {}
316 | if captions is not None:
317 | img_params["caption"] = captions[i]
318 |
319 | if masks_prediction is not None:
320 | mask_pred = cv2.resize(masks_prediction[i], (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
321 | mask_pred = mask_pred.astype(np.int32)
322 | if "masks" not in img_params:
323 | img_params["masks"] = {}
324 |
325 | #seg = np.expand_dims(masks[i].astype(np.int32), axis=-1)
326 | img_params["masks"]["predictions"] = {
327 | "mask_data": mask_pred,
328 | "class_labels": class_labels
329 | }
330 |
331 |
332 | if masks_target is not None:
333 | if "masks" not in img_params:
334 | img_params["masks"] = {}
335 |
336 | mask_target = masks_target[i].astype(np.int32)
337 | #seg = np.expand_dims(masks[i].astype(np.int32), axis=-1)
338 |                 #print(mask_target.shape)
339 |                 img_params["masks"]["ground_truth"] = {
340 | "mask_data": mask_target,
341 | "class_labels": class_labels
342 | }
343 |
344 | images_list.append(wandb.Image(img, **img_params))
345 |
346 | wandb.log({name: images_list}, step=step)
347 | return np.array(0, dtype=np.int64)
348 |
349 |
--------------------------------------------------------------------------------
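A minimal usage sketch for the WandbSender buffering API defined in wandb_logging.py above. The project name, image, boxes and class names are made-up values; boxes are normalized [x1, y1, x2, y2]:

    import numpy as np
    import wandb
    from detr_tf.logger.wandb_logging import WandbSender

    wandb.init(project="detr-tf-demo")                    # hypothetical project name

    sender = WandbSender()
    image = np.random.uniform(size=(376, 672, 3))         # image normalized between 0 and 1
    p_bbox = np.array([[0.10, 0.10, 0.45, 0.50]])         # predicted boxes
    p_scores = np.array([0.92])
    p_labels = np.array([1])
    t_bbox = np.array([[0.12, 0.11, 0.47, 0.52]])         # ground-truth boxes
    t_labels = np.array([1])

    sender.gather_inference(image, p_bbox, p_scores, t_bbox, p_labels, t_labels,
                            class_name=["background", "person"])
    sender.send(step=100, prefix="val/")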
/detr_tf/loss/compute_map.py:
--------------------------------------------------------------------------------
1 | from scipy.special import softmax
2 | import matplotlib.pyplot as plt
3 | from itertools import product
4 | import tensorflow as tf
5 | import numpy as np
6 | import argparse
7 | import random
8 | import json
9 | import time
10 | import cv2
11 | import os
12 |
13 | from .. import bbox
14 | from collections import OrderedDict
15 |
16 |
17 | class APDataObject:
18 | """ Stores all the information necessary to calculate the AP for one IoU and one class.
19 | """
20 |
21 | def __init__(self):
22 | self.data_points = []
23 | self.num_gt_positives = 0
24 |
25 | def push(self, score:float, is_true:bool):
26 | self.data_points.append((score, is_true))
27 |
28 | def add_gt_positives(self, num_positives:int):
29 | """ Call this once per image. """
30 | self.num_gt_positives += num_positives
31 |
32 | def is_empty(self) -> bool:
33 | return len(self.data_points) == 0 and self.num_gt_positives == 0
34 |
35 | def get_ap(self) -> float:
36 | """ Warning: result not cached. """
37 |
38 | if self.num_gt_positives == 0:
39 | return 0
40 |
41 | # Sort descending by score
42 | self.data_points.sort(key=lambda x: -x[0])
43 |
44 | precisions = []
45 | recalls = []
46 | num_true = 0
47 | num_false = 0
48 |
49 | # Compute the precision-recall curve. The x axis is recalls and the y axis precisions.
50 | for datum in self.data_points:
51 |             # datum[1] is whether the detection is a true or false positive
52 | if datum[1]: num_true += 1
53 | else: num_false += 1
54 |
55 | precision = num_true / (num_true + num_false)
56 | recall = num_true / self.num_gt_positives
57 |
58 | precisions.append(precision)
59 | recalls.append(recall)
60 |
61 | # Smooth the curve by computing [max(precisions[i:]) for i in range(len(precisions))]
62 | # Basically, remove any temporary dips from the curve.
63 |         # That is at least the intent; COCOEval does the same thing, so we follow it here.
64 | for i in range(len(precisions)-1, 0, -1):
65 | if precisions[i] > precisions[i-1]:
66 | precisions[i-1] = precisions[i]
67 |
68 |         # Compute the integral of precision(recall) d_recall from recall=0->1 using fixed-length Riemann summation with 101 bars.
69 | y_range = [0] * 101 # idx 0 is recall == 0.0 and idx 100 is recall == 1.00
70 | x_range = np.array([x / 100 for x in range(101)])
71 | recalls = np.array(recalls)
72 |
73 | # I realize this is weird, but all it does is find the nearest precision(x) for a given x in x_range.
74 | # Basically, if the closest recall we have to 0.01 is 0.009 this sets precision(0.01) = precision(0.009).
75 | # I approximate the integral this way, because that's how COCOEval does it.
76 | indices = np.searchsorted(recalls, x_range, side='left')
77 | for bar_idx, precision_idx in enumerate(indices):
78 | if precision_idx < len(precisions):
79 | y_range[bar_idx] = precisions[precision_idx]
80 |
81 | # Finally compute the riemann sum to get our integral.
82 | # avg([precision(x) for x in 0:0.01:1])
83 | return sum(y_range) / len(y_range)
84 |
85 | def compute_overlaps_masks(masks1, masks2):
86 | """Computes IoU overlaps between two sets of masks.
87 | masks1, masks2: [Height, Width, instances]
88 | """
89 | # If either set of masks is empty return empty result
90 | if masks1.shape[-1] == 0 or masks2.shape[-1] == 0:
91 | return np.zeros((masks1.shape[-1], masks2.shape[-1]))
92 | # flatten masks and compute their areas
93 | masks1 = np.reshape(masks1 > .5, (-1, masks1.shape[-1])).astype(np.float32)
94 | masks2 = np.reshape(masks2 > .5, (-1, masks2.shape[-1])).astype(np.float32)
95 | area1 = np.sum(masks1, axis=0)
96 | area2 = np.sum(masks2, axis=0)
97 |
98 | # intersections and union
99 | intersections = np.dot(masks1.T, masks2)
100 | union = area1[:, None] + area2[None, :] - intersections
101 | overlaps = intersections / union
102 |
103 | return overlaps
104 |
105 | def compute_iou(box, boxes, box_area, boxes_area):
106 | """Calculates IoU of the given box with the array of the given boxes.
107 | box: 1D vector [y1, x1, y2, x2]
108 | boxes: [boxes_count, (y1, x1, y2, x2)]
109 | box_area: float. the area of 'box'
110 | boxes_area: array of length boxes_count.
111 | Note: the areas are passed in rather than calculated here for
112 | efficiency. Calculate once in the caller to avoid duplicate work.
113 | """
114 | # Calculate intersection areas
115 | y1 = np.maximum(box[0], boxes[:, 0])
116 | y2 = np.minimum(box[2], boxes[:, 2])
117 | x1 = np.maximum(box[1], boxes[:, 1])
118 | x2 = np.minimum(box[3], boxes[:, 3])
119 | intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
120 | union = box_area + boxes_area[:] - intersection[:]
121 | iou = intersection / union
122 | return iou
123 |
124 | def compute_overlaps(boxes1, boxes2):
125 | """Computes IoU overlaps between two sets of boxes.
126 | boxes1, boxes2: [N, (y1, x1, y2, x2)].
127 | For better performance, pass the largest set first and the smaller second.
128 | """
129 | # Areas of anchors and GT boxes
130 | area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
131 | area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
132 |
133 | # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
134 | # Each cell contains the IoU value.
135 | overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
136 | for i in range(overlaps.shape[1]):
137 | box2 = boxes2[i]
138 | overlaps[:, i] = compute_iou(box2, boxes1, area2[i], area1)
139 | return overlaps
140 |
141 | def calc_map(ap_data, iou_thresholds, class_name, print_result=False):
142 | #print('Calculating mAP...')
143 | aps = [{'box': [], 'mask': []} for _ in iou_thresholds]
144 |
145 | for _class in range(len(class_name)):
146 | for iou_idx in range(len(iou_thresholds)):
147 | for iou_type in ('box', 'mask'):
148 | ap_obj = ap_data[iou_type][iou_idx][_class]
149 |
150 | if not ap_obj.is_empty():
151 | aps[iou_idx][iou_type].append(ap_obj.get_ap())
152 |
153 | all_maps = {'box': OrderedDict(), 'mask': OrderedDict()}
154 |
155 | # Looking back at it, this code is really hard to read :/
156 | for iou_type in ('box', 'mask'):
157 | all_maps[iou_type]['all'] = 0 # Make this first in the ordereddict
158 | for i, threshold in enumerate(iou_thresholds):
159 | mAP = sum(aps[i][iou_type]) / len(aps[i][iou_type]) * 100 if len(aps[i][iou_type]) > 0 else 0
160 | all_maps[iou_type][int(threshold*100)] = mAP
161 | all_maps[iou_type]['all'] = (sum(all_maps[iou_type].values()) / (len(all_maps[iou_type].values())-1))
162 |
163 | if print_result:
164 | print_maps(all_maps)
165 |
166 | # Put in a prettier format so we can serialize it to json during training
167 | all_maps = {k: {j: round(u, 2) for j, u in v.items()} for k, v in all_maps.items()}
168 | return all_maps
169 |
170 | def print_maps(all_maps):
171 | # Warning: hacky
172 | make_row = lambda vals: (' %5s |' * len(vals)) % tuple(vals)
173 | make_sep = lambda n: ('-------+' * n)
174 |
175 | print()
176 | print(make_row([''] + [('.%d ' % x if isinstance(x, int) else x + ' ') for x in all_maps['box'].keys()]))
177 | print(make_sep(len(all_maps['box']) + 1))
178 | for iou_type in ('box', 'mask'):
179 | print(make_row([iou_type] + ['%.2f' % x if x < 100 else '%.1f' % x for x in all_maps[iou_type].values()]))
180 | print(make_sep(len(all_maps['box']) + 1))
181 | print()
182 |
183 | def cal_map(p_bbox, p_labels, p_scores, p_mask, t_bbox, gt_classes, t_mask, ap_data, iou_thresholds):
184 |
185 | #print("p_bbox", p_bbox.shape)
186 | #print("p_labels", p_labels.shape)
187 | #print("p_scores", p_scores.shape)
188 | #print("p_mask", p_mask.shape)
189 | #print("t_bbox", t_bbox.shape)
190 | #print("gt_classes", gt_classes)
191 | #print("t_mask", t_mask.shape)
192 |
193 | num_crowd = 0
194 |
195 | classes = list(np.array(p_labels).astype(int))
196 | scores = list(np.array(p_scores).astype(float))
197 |
198 | box_scores = scores
199 | mask_scores = scores
200 |
201 | masks = p_mask
202 |
203 | num_pred = len(classes)
204 | num_gt = len(gt_classes)
205 |
206 | mask_iou_cache = compute_overlaps_masks(masks, t_mask)
207 | bbox_iou_cache = compute_overlaps(p_bbox, t_bbox)
208 |
209 | crowd_mask_iou_cache = None
210 | crowd_bbox_iou_cache = None
211 |
212 | box_indices = sorted(range(num_pred), key=lambda i: -box_scores[i])
213 | mask_indices = sorted(box_indices, key=lambda i: -mask_scores[i])
214 |
215 | iou_types = [
216 | ('box', lambda i,j: bbox_iou_cache[i, j].item(),
217 | lambda i,j: crowd_bbox_iou_cache[i,j].item(),
218 | lambda i: box_scores[i], box_indices),
219 | ('mask', lambda i,j: mask_iou_cache[i, j].item(),
220 | lambda i,j: crowd_mask_iou_cache[i,j].item(),
221 | lambda i: mask_scores[i], mask_indices)
222 | ]
223 | #print("run", list(classes), list(gt_classes))
224 | #print(classes + gt_classes)
225 | for _class in set(list(classes) + list(gt_classes)):
226 | ap_per_iou = []
227 | num_gt_for_class = sum([1 for x in gt_classes if x == _class])
228 |
229 | for iouIdx in range(len(iou_thresholds)):
230 | iou_threshold = iou_thresholds[iouIdx]
231 | for iou_type, iou_func, crowd_func, score_func, indices in iou_types:
232 | gt_used = [False] * len(gt_classes)
233 | ap_obj = ap_data[iou_type][iouIdx][_class]
234 | ap_obj.add_gt_positives(num_gt_for_class)
235 |
236 | for i in indices:
237 | if classes[i] != _class:
238 | continue
239 | max_iou_found = iou_threshold
240 | max_match_idx = -1
241 | for j in range(num_gt):
242 | if gt_used[j] or gt_classes[j] != _class:
243 | continue
244 | iou = iou_func(i, j)
245 |
246 | if iou > max_iou_found:
247 | max_iou_found = iou
248 | max_match_idx = j
249 |
250 | if max_match_idx >= 0:
251 | gt_used[max_match_idx] = True
252 | ap_obj.push(score_func(i), True)
253 | else:
254 | # If the detection matches a crowd, we can just ignore it
255 | matched_crowd = False
256 |
257 | if num_crowd > 0:
258 | for j in range(len(crowd_classes)):
259 | if crowd_classes[j] != _class:
260 | continue
261 |
262 | iou = crowd_func(i, j)
263 |
264 | if iou > iou_threshold:
265 | matched_crowd = True
266 | break
267 |
268 | # All this crowd code so that we can make sure that our eval code gives the
269 | # same result as COCOEval. There aren't even that many crowd annotations to
270 | # begin with, but accuracy is of the utmost importance.
271 | if not matched_crowd:
272 | ap_obj.push(score_func(i), False)
273 |
--------------------------------------------------------------------------------
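A minimal sketch of how APDataObject, cal_map and calc_map from compute_map.py fit together for a single image. The boxes, masks and class names are made-up values; boxes follow the [y1, x1, y2, x2] convention documented in compute_iou, and masks are (h, w, instances):

    import numpy as np
    from detr_tf.loss.compute_map import APDataObject, cal_map, calc_map

    iou_thresholds = [x / 100. for x in range(50, 100, 5)]
    class_name = ["background", "person"]

    # One APDataObject per (iou threshold, class) pair, for boxes and for masks
    ap_data = {
        "box":  [[APDataObject() for _ in class_name] for _ in iou_thresholds],
        "mask": [[APDataObject() for _ in class_name] for _ in iou_thresholds],
    }

    # A single prediction and a single ground truth, both of class 1
    p_bbox = np.array([[0.10, 0.10, 0.50, 0.40]])
    t_bbox = np.array([[0.12, 0.10, 0.52, 0.42]])
    p_mask = np.zeros((64, 64, 1)); p_mask[10:30, 10:40, 0] = 1.0
    t_mask = np.zeros((64, 64, 1)); t_mask[12:32, 10:42, 0] = 1.0

    cal_map(p_bbox, np.array([1]), np.array([0.9]), p_mask,
            t_bbox, np.array([1]), t_mask, ap_data, iou_thresholds)
    all_maps = calc_map(ap_data, iou_thresholds, class_name, print_result=True)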
/detr_tf/loss/hungarian_matching.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Dict, Tuple
2 | from itertools import product
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | from .. import bbox
7 | from scipy.optimize import linear_sum_assignment
8 |
9 |
10 | def get_offsets(anchors_xywh, target_bbox_xywh):
11 |     # Return the offsets between the boxes in anchors_xywh and the boxes
12 |     # in target_bbox_xywh
13 |
14 | variances = [0.1, 0.2]
15 |
16 | tiled_a_bbox, tiled_t_bbox = bbox.merge(anchors_xywh, target_bbox_xywh)
17 |
18 | g_cxcy = (tiled_t_bbox[:,:,:2] - tiled_a_bbox[:,:,:2])
19 | g_cxcy = g_cxcy / (variances[0] * tiled_a_bbox[:,:,2:])
20 |
21 | g_wh = tiled_t_bbox[:,:,2:] / tiled_a_bbox[:,:,2:]
22 | g_wh = tf.math.log(g_wh) / variances[1]
23 |
24 | return tf.concat([g_cxcy, g_wh], axis=-1)
25 |
26 |
27 | def np_tf_linear_sum_assignment(matrix):
28 |
29 | indices = linear_sum_assignment(matrix)
30 | target_indices = indices[0]
31 | pred_indices = indices[1]
32 |
33 | #print(matrix.shape, target_indices, pred_indices)
34 |
35 | target_selector = np.zeros(matrix.shape[0])
36 | target_selector[target_indices] = 1
37 |     target_selector = target_selector.astype(bool)
38 |
39 | pred_selector = np.zeros(matrix.shape[1])
40 | pred_selector[pred_indices] = 1
41 |     pred_selector = pred_selector.astype(bool)
42 |
43 | #print('target_indices', target_indices)
44 | #print("pred_indices", pred_indices)
45 |
46 | return [target_indices, pred_indices, target_selector, pred_selector]
47 |
48 |
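# NOTE: the helpers below, from box_cxcywh_to_xyxy down to loss_boxes, appear to be
# reference implementations carried over from the PyTorch DETR code base. They rely on
# torch, torch.nn.functional (F) and torchvision's box_area, none of which are imported
# in this file, and they are not used by the TensorFlow hungarian_matching implementation
# at the bottom of the file.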
49 | def box_cxcywh_to_xyxy(x):
50 | x_c, y_c, w, h = x.unbind(-1)
51 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
52 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
53 | return torch.stack(b, dim=-1)
54 |
55 |
56 | def generalized_box_iou(boxes1, boxes2):
57 | """
58 | Generalized IoU from https://giou.stanford.edu/
59 | The boxes should be in [x0, y0, x1, y1] format
60 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
61 | and M = len(boxes2)
62 | """
63 | # degenerate boxes gives inf / nan results
64 | # so do an early check
65 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
66 |
67 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
68 | iou, union = box_iou(boxes1, boxes2)
69 |
70 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
71 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
72 |
73 | wh = (rb - lt).clamp(min=0) # [N,M,2]
74 | area = wh[:, :, 0] * wh[:, :, 1]
75 |
76 | return iou - (area - union) / area
77 |
78 | def box_iou(boxes1, boxes2):
79 | area1 = box_area(boxes1)
80 | area2 = box_area(boxes2)
81 |
82 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
83 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
84 |
85 | wh = (rb - lt).clamp(min=0) # [N,M,2]
86 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
87 |
88 | union = area1[:, None] + area2 - inter
89 |
90 | iou = inter / union
91 | return iou, union
92 |
93 |
94 |
95 | def _get_src_permutation_idx(indices):
96 | # permute predictions following indices
97 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
98 | src_idx = torch.cat([src for (src, _) in indices])
99 | return batch_idx, src_idx
100 |
101 |
102 | def loss_labels(outputs, targets, indices, num_boxes, log=True):
103 | """Classification loss (NLL)
104 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
105 | """
106 | assert 'pred_logits' in outputs
107 | src_logits = outputs['pred_logits']
108 |
109 | idx = _get_src_permutation_idx(indices)
110 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
111 | target_classes = torch.full(src_logits.shape[:2], 0,
112 | dtype=torch.int64, device=src_logits.device)
113 | target_classes[idx] = target_classes_o
114 |
115 | empty_weight = torch.ones(81)
116 | empty_weight[0] = 0.1
117 |
118 | #print("log_softmax(input, 1)", F.softmax(src_logits, 1).mean())
119 | #print("src_logits", src_logits.shape)
120 | #print("target_classes", target_classes, target_classes.shape)
121 |
122 | #print("target_classes", target_classes)
123 |
124 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)
125 | #print('>loss_ce', loss_ce)
126 | losses = {'loss_ce': loss_ce}
127 |
128 | #if log:
129 | # # TODO this should probably be a separate loss, not hacked in this one here
130 | # losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
131 | return losses
132 |
133 |
134 |
135 | def loss_boxes(outputs, targets, indices, num_boxes):
136 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
137 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
138 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
139 | """
140 | #print("------")
141 | #assert 'pred_boxes' in outputs
142 | idx = _get_src_permutation_idx(indices)
143 | src_boxes = outputs['pred_boxes'][idx]
144 | target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
145 |
146 | #print("target_boxes", target_boxes)
147 | #print("src_boxes", src_boxes)
148 |
149 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
150 | #print("loss_bbox", loss_bbox)
151 | losses = {}
152 | losses['loss_bbox'] = loss_bbox.sum() / target_boxes.shape[0]
153 | #print(">loss_bbox", losses['loss_bbox'])
154 |
155 | loss_giou = 1 - torch.diag(generalized_box_iou(
156 | box_cxcywh_to_xyxy(src_boxes),
157 | box_cxcywh_to_xyxy(target_boxes)))
158 | #print('>loss_giou', loss_giou)
159 | losses['loss_giou'] = loss_giou.sum() / target_boxes.shape[0]
160 | #print(">loss_giou", losses['loss_giou'])
161 | return losses
162 |
163 | def hungarian_matching(t_bbox, t_class, p_bbox, p_class, fcost_class=1, fcost_bbox=5, fcost_giou=2, slice_preds=True) -> tuple:
164 |
165 | if slice_preds:
166 | size = tf.cast(t_bbox[0][0], tf.int32)
167 | t_bbox = tf.slice(t_bbox, [1, 0], [size, 4])
168 | t_class = tf.slice(t_class, [1, 0], [size, -1])
169 | t_class = tf.squeeze(t_class, axis=-1)
170 |
171 |     # Convert from [xc, yc, w, h] to [xmin, ymin, xmax, ymax]
172 | p_bbox_xy = bbox.xcycwh_to_xy_min_xy_max(p_bbox)
173 | t_bbox_xy = bbox.xcycwh_to_xy_min_xy_max(t_bbox)
174 |
175 | softmax = tf.nn.softmax(p_class)
176 |
177 |     # Classification cost for the Hungarian algorithm
178 |     # For each prediction, we select the probability of the expected class
179 | cost_class = -tf.gather(softmax, t_class, axis=1)
180 |
181 | # L1 cost for the hungarian algorithm
182 | _p_bbox, _t_bbox = bbox.merge(p_bbox, t_bbox)
183 | cost_bbox = tf.norm(_p_bbox - _t_bbox, ord=1, axis=-1)
184 |
185 | # Generalized IOU
186 | iou, union = bbox.jaccard(p_bbox_xy, t_bbox_xy, return_union=True)
187 | _p_bbox_xy, _t_bbox_xy = bbox.merge(p_bbox_xy, t_bbox_xy)
188 | top_left = tf.math.minimum(_p_bbox_xy[:,:,:2], _t_bbox_xy[:,:,:2])
189 | bottom_right = tf.math.maximum(_p_bbox_xy[:,:,2:], _t_bbox_xy[:,:,2:])
190 | size = tf.nn.relu(bottom_right - top_left)
191 | area = size[:,:,0] * size[:,:,1]
192 | cost_giou = -(iou - (area - union) / area)
193 |
194 | # Final hungarian cost matrix
195 | cost_matrix = fcost_bbox * cost_bbox + fcost_class * cost_class + fcost_giou * cost_giou
196 |
197 | selectors = tf.numpy_function(np_tf_linear_sum_assignment, [cost_matrix], [tf.int64, tf.int64, tf.bool, tf.bool] )
198 | target_indices = selectors[0]
199 | pred_indices = selectors[1]
200 | target_selector = selectors[2]
201 | pred_selector = selectors[3]
202 |
203 | return pred_indices, target_indices, pred_selector, target_selector, t_bbox, t_class
204 |
--------------------------------------------------------------------------------
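A minimal sketch of calling hungarian_matching on one sample with made-up tensors. With slice_preds=True the targets are expected in the padded format used elsewhere in this repo, where the first row of t_bbox stores the number of real boxes; the helpers of the bbox module are assumed to be available as used above:

    import tensorflow as tf
    from detr_tf.loss.hungarian_matching import hungarian_matching

    num_queries, nb_class = 100, 4
    p_bbox = tf.random.uniform((num_queries, 4), minval=0.2, maxval=0.6)   # [xc, yc, w, h], normalized
    p_class = tf.random.uniform((num_queries, nb_class))                   # raw logits

    # Padded targets: the first row stores the number of real boxes (2 here)
    t_bbox = tf.constant([[2.0, 0.0, 0.0, 0.0],
                          [0.5, 0.5, 0.2, 0.2],
                          [0.3, 0.7, 0.1, 0.3]])
    t_class = tf.constant([[0], [1], [3]], tf.int64)

    # Unpacked in the same order as loss.py: t_indices / p_indices pair each matched
    # target with the query it was assigned to.
    t_indices, p_indices, t_selector, p_selector, t_bbox, t_class = hungarian_matching(
        t_bbox, t_class, p_bbox, p_class, slice_preds=True)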
/detr_tf/loss/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from .. import bbox
3 | from .hungarian_matching import hungarian_matching
4 |
5 |
6 | def get_total_losss(losses):
7 | """
8 |     Get the model total loss, including the auxiliary losses
9 | """
10 | train_loss = ["label_cost", "giou_loss", "l1_loss"]
11 | loss_weights = [1, 2, 5]
12 |
13 | total_loss = 0
14 | for key in losses:
15 | selector = [l for l, loss_name in enumerate(train_loss) if loss_name in key]
16 | if len(selector) == 1:
17 | #print("Add to the total loss", key, losses[key], loss_weights[selector[0]])
18 | total_loss += losses[key]*loss_weights[selector[0]]
19 | return total_loss
20 |
21 |
22 | def get_losses(m_outputs, t_bbox, t_class, config):
23 | losses = get_detr_losses(m_outputs, t_bbox, t_class, config)
24 |
25 | # Get auxiliary loss for each auxiliary output
26 | if "aux" in m_outputs:
27 | for a, aux_m_outputs in enumerate(m_outputs["aux"]):
28 | aux_losses = get_detr_losses(aux_m_outputs, t_bbox, t_class, config, suffix="_{}".format(a))
29 | losses.update(aux_losses)
30 |
31 | # Compute the total loss
32 | total_loss = get_total_losss(losses)
33 |
34 | return total_loss, losses
35 |
36 |
37 | def loss_labels(p_bbox, p_class, t_bbox, t_class, t_indices, p_indices, t_selector, p_selector, background_class=0):
38 |
39 | neg_indices = tf.squeeze(tf.where(p_selector == False), axis=-1)
40 | neg_p_class = tf.gather(p_class, neg_indices)
41 | neg_t_class = tf.zeros((tf.shape(neg_p_class)[0],), tf.int64) + background_class
42 |
43 | neg_weights = tf.zeros((tf.shape(neg_indices)[0],)) + 0.1
44 | pos_weights = tf.zeros((tf.shape(t_indices)[0],)) + 1.0
45 | weights = tf.concat([neg_weights, pos_weights], axis=0)
46 |
47 | pos_p_class = tf.gather(p_class, p_indices)
48 | pos_t_class = tf.gather(t_class, t_indices)
49 |
50 | #############
51 | # Metrics
52 | #############
53 | # True negative
54 | cls_neg_p_class = tf.argmax(neg_p_class, axis=-1)
55 | true_neg = tf.reduce_mean(tf.cast(cls_neg_p_class == background_class, tf.float32))
56 | # True positive
57 | cls_pos_p_class = tf.argmax(pos_p_class, axis=-1)
58 | true_pos = tf.reduce_mean(tf.cast(cls_pos_p_class != background_class, tf.float32))
59 | # True accuracy
60 | cls_pos_p_class = tf.argmax(pos_p_class, axis=-1)
61 | pos_accuracy = tf.reduce_mean(tf.cast(cls_pos_p_class == pos_t_class, tf.float32))
62 |
63 | targets = tf.concat([neg_t_class, pos_t_class], axis=0)
64 | preds = tf.concat([neg_p_class, pos_p_class], axis=0)
65 |
66 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(targets, preds)
67 | loss = tf.reduce_sum(loss * weights) / tf.reduce_sum(weights)
68 |
69 | return loss, true_neg, true_pos, pos_accuracy
70 |
71 |
72 | def loss_boxes(p_bbox, p_class, t_bbox, t_class, t_indices, p_indices, t_selector, p_selector):
73 | #print("------")
74 | p_bbox = tf.gather(p_bbox, p_indices)
75 | t_bbox = tf.gather(t_bbox, t_indices)
76 |
77 |
78 | p_bbox_xy = bbox.xcycwh_to_xy_min_xy_max(p_bbox)
79 | t_bbox_xy = bbox.xcycwh_to_xy_min_xy_max(t_bbox)
80 |
81 | l1_loss = tf.abs(p_bbox-t_bbox)
82 | l1_loss = tf.reduce_sum(l1_loss) / tf.cast(tf.shape(p_bbox)[0], tf.float32)
83 |
84 | iou, union = bbox.jaccard(p_bbox_xy, t_bbox_xy, return_union=True)
85 |
86 | _p_bbox_xy, _t_bbox_xy = bbox.merge(p_bbox_xy, t_bbox_xy)
87 | top_left = tf.math.minimum(_p_bbox_xy[:,:,:2], _t_bbox_xy[:,:,:2])
88 | bottom_right = tf.math.maximum(_p_bbox_xy[:,:,2:], _t_bbox_xy[:,:,2:])
89 | size = tf.nn.relu(bottom_right - top_left)
90 | area = size[:,:,0] * size[:,:,1]
91 | giou = (iou - (area - union) / area)
92 | loss_giou = 1 - tf.linalg.diag_part(giou)
93 |
94 | loss_giou = tf.reduce_sum(loss_giou) / tf.cast(tf.shape(p_bbox)[0], tf.float32)
95 |
96 | return loss_giou, l1_loss
97 |
98 | def get_detr_losses(m_outputs, target_bbox, target_label, config, suffix=""):
99 |
100 | predicted_bbox = m_outputs["pred_boxes"]
101 | predicted_label = m_outputs["pred_logits"]
102 |
103 | all_target_bbox = []
104 | all_target_class = []
105 | all_predicted_bbox = []
106 | all_predicted_class = []
107 | all_target_indices = []
108 | all_predcted_indices = []
109 | all_target_selector = []
110 | all_predcted_selector = []
111 |
112 | t_offset = 0
113 | p_offset = 0
114 |
115 | for b in range(predicted_bbox.shape[0]):
116 |
117 | p_bbox, p_class, t_bbox, t_class = predicted_bbox[b], predicted_label[b], target_bbox[b], target_label[b]
118 | t_indices, p_indices, t_selector, p_selector, t_bbox, t_class = hungarian_matching(t_bbox, t_class, p_bbox, p_class, slice_preds=True)
119 |
120 | t_indices = t_indices + tf.cast(t_offset, tf.int64)
121 | p_indices = p_indices + tf.cast(p_offset, tf.int64)
122 |
123 | all_target_bbox.append(t_bbox)
124 | all_target_class.append(t_class)
125 | all_predicted_bbox.append(p_bbox)
126 | all_predicted_class.append(p_class)
127 | all_target_indices.append(t_indices)
128 | all_predcted_indices.append(p_indices)
129 | all_target_selector.append(t_selector)
130 | all_predcted_selector.append(p_selector)
131 |
132 | t_offset += tf.shape(t_bbox)[0]
133 | p_offset += tf.shape(p_bbox)[0]
134 |
135 | all_target_bbox = tf.concat(all_target_bbox, axis=0)
136 | all_target_class = tf.concat(all_target_class, axis=0)
137 | all_predicted_bbox = tf.concat(all_predicted_bbox, axis=0)
138 | all_predicted_class = tf.concat(all_predicted_class, axis=0)
139 | all_target_indices = tf.concat(all_target_indices, axis=0)
140 | all_predcted_indices = tf.concat(all_predcted_indices, axis=0)
141 | all_target_selector = tf.concat(all_target_selector, axis=0)
142 | all_predcted_selector = tf.concat(all_predcted_selector, axis=0)
143 |
144 |
145 | label_cost, true_neg, true_pos, pos_accuracy = loss_labels(
146 | all_predicted_bbox,
147 | all_predicted_class,
148 | all_target_bbox,
149 | all_target_class,
150 | all_target_indices,
151 | all_predcted_indices,
152 | all_target_selector,
153 | all_predcted_selector,
154 | background_class=config.background_class,
155 | )
156 |
157 | giou_loss, l1_loss = loss_boxes(
158 | all_predicted_bbox,
159 | all_predicted_class,
160 | all_target_bbox,
161 | all_target_class,
162 | all_target_indices,
163 | all_predcted_indices,
164 | all_target_selector,
165 | all_predcted_selector
166 | )
167 |
171 |
172 | return {
173 | "label_cost{}".format(suffix): label_cost,
174 | "true_neg{}".format(suffix): true_neg,
175 | "true_pos{}".format(suffix): true_pos,
176 | "pos_accuracy{}".format(suffix): pos_accuracy,
177 | "giou_loss{}".format(suffix): giou_loss,
178 | "l1_loss{}".format(suffix): l1_loss
179 | }
180 |
--------------------------------------------------------------------------------
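A minimal sketch of calling get_losses on a fake batch. The config object only needs a background_class attribute here, so a SimpleNamespace stands in for the real training config; targets use the padded [xc, yc, w, h] format where the first row of each sample stores the number of real boxes:

    import tensorflow as tf
    from types import SimpleNamespace
    from detr_tf.loss.loss import get_losses

    config = SimpleNamespace(background_class=0)   # stand-in for the real training config
    batch_size, num_queries, nb_class = 1, 100, 4

    m_outputs = {
        "pred_boxes":  tf.random.uniform((batch_size, num_queries, 4), minval=0.2, maxval=0.6),
        "pred_logits": tf.random.uniform((batch_size, num_queries, nb_class)),
    }
    t_bbox = tf.constant([[[2.0, 0.0, 0.0, 0.0],
                           [0.5, 0.5, 0.2, 0.2],
                           [0.3, 0.7, 0.1, 0.3]]])
    t_class = tf.constant([[[0], [1], [3]]], tf.int64)

    total_loss, losses = get_losses(m_outputs, t_bbox, t_class, config)
    print(losses["label_cost"], losses["giou_loss"], losses["l1_loss"])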
/detr_tf/networks/custom_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class FrozenBatchNorm2D(tf.keras.layers.Layer):
5 | def __init__(self, eps=1e-5, **kwargs):
6 | super().__init__(**kwargs)
7 | self.eps = eps
8 |
9 |
10 | def build(self, input_shape):
11 | self.weight = self.add_weight(name='weight', shape=[input_shape[-1]],
12 | initializer=tf.keras.initializers.GlorotUniform(), trainable=False)
13 | self.bias = self.add_weight(name='bias', shape=[input_shape[-1]],
14 | initializer=tf.keras.initializers.GlorotUniform(), trainable=False)
15 | self.running_mean = self.add_weight(name='running_mean', shape=[input_shape[-1]],
16 | initializer='zeros', trainable=False)
17 | self.running_var = self.add_weight(name='running_var', shape=[input_shape[-1]],
18 | initializer='ones', trainable=False)
19 |
20 |
21 | def call(self, x):
22 | scale = self.weight * tf.math.rsqrt(self.running_var + self.eps)
23 | shift = self.bias - self.running_mean * scale
24 | return x * scale + shift
25 |
26 |
27 | def compute_output_shape(self, input_shape):
28 | return input_shape
29 |
30 |
31 | class Linear(tf.keras.layers.Layer):
32 | '''
33 | Use this custom layer instead of tf.keras.layers.Dense to allow
34 | loading converted PyTorch Dense weights that have shape (output_dim, input_dim)
35 | '''
36 | def __init__(self, output_dim, **kwargs):
37 | super().__init__(**kwargs)
38 | self.output_dim = output_dim
39 |
40 |
41 | def build(self, input_shape):
42 | self.kernel = self.add_weight(name='kernel',
43 | shape=[self.output_dim, input_shape[-1]],
44 | initializer=tf.keras.initializers.GlorotUniform(), trainable=True)
45 | self.bias = self.add_weight(name='bias',
46 | shape=[self.output_dim],
47 | initializer=tf.keras.initializers.GlorotUniform(), trainable=True)
48 |
49 | def call(self, x):
50 | return tf.matmul(x, self.kernel, transpose_b=True) + self.bias
51 |
52 |
53 | def compute_output_shape(self, input_shape):
54 | return input_shape.as_list()[:-1] + [self.output_dim]
55 |
56 |
57 | class FixedEmbedding(tf.keras.layers.Layer):
58 | def __init__(self, embed_shape, **kwargs):
59 | super().__init__(**kwargs)
60 | self.embed_shape = embed_shape
61 |
62 | def build(self, input_shape):
63 | self.w = self.add_weight(name='kernel', shape=self.embed_shape,
64 | initializer=tf.keras.initializers.GlorotUniform(), trainable=True)
65 |
66 | def call(self, x=None):
67 | return self.w
68 |
--------------------------------------------------------------------------------
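A tiny sketch of the Linear layer above: its kernel is stored transposed, (output_dim, input_dim), so weights converted from PyTorch's nn.Linear can be loaded as-is:

    import tensorflow as tf
    from detr_tf.networks.custom_layers import Linear

    layer = Linear(8)
    out = layer(tf.random.uniform((2, 100, 256)))
    print(layer.kernel.shape, out.shape)   # (8, 256) (2, 100, 8)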
/detr_tf/networks/detr.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import tensorflow as tf
3 | import numpy as np
4 | import time
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | import os
8 | from pathlib import Path
9 |
10 |
11 | from .resnet_backbone import ResNet50Backbone
12 | from .custom_layers import Linear, FixedEmbedding
13 | from .position_embeddings import PositionEmbeddingSine
14 | from .transformer import Transformer
15 | from .. bbox import xcycwh_to_xy_min_xy_max
16 | from .weights import load_weights
17 |
18 |
19 | class DETR(tf.keras.Model):
20 | def __init__(self, num_classes=92, num_queries=100,
21 | backbone=None,
22 | pos_encoder=None,
23 | transformer=None,
24 | num_encoder_layers=6,
25 | num_decoder_layers=6,
26 | return_intermediate_dec=True,
27 | **kwargs):
28 | super().__init__(**kwargs)
29 | self.num_queries = num_queries
30 |
31 | self.backbone = ResNet50Backbone(name='backbone')
32 | self.transformer = transformer or Transformer(
33 | num_encoder_layers=num_encoder_layers,
34 | num_decoder_layers=num_decoder_layers,
35 | return_intermediate_dec=return_intermediate_dec,
36 | name='transformer'
37 | )
38 |
39 | self.model_dim = self.transformer.model_dim
40 |
41 | self.pos_encoder = pos_encoder or PositionEmbeddingSine(
42 | num_pos_features=self.model_dim // 2, normalize=True, name="position_embedding_sine")
43 |
44 | self.input_proj = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj')
45 |
46 | self.query_embed = FixedEmbedding((num_queries, self.model_dim),
47 | name='query_embed')
48 |
49 | self.class_embed = Linear(num_classes, name='class_embed')
50 |
51 | self.bbox_embed_linear1 = Linear(self.model_dim, name='bbox_embed_0')
52 | self.bbox_embed_linear2 = Linear(self.model_dim, name='bbox_embed_1')
53 | self.bbox_embed_linear3 = Linear(4, name='bbox_embed_2')
54 | self.activation = tf.keras.layers.ReLU(name='re_lu')
55 |
56 |
57 | def downsample_masks(self, masks, x):
58 | masks = tf.cast(masks, tf.int32)
59 | masks = tf.expand_dims(masks, -1)
60 | masks = tf.compat.v1.image.resize_nearest_neighbor(masks, tf.shape(x)[1:3], align_corners=False, half_pixel_centers=False)
61 | masks = tf.squeeze(masks, -1)
62 | masks = tf.cast(masks, tf.bool)
63 | return masks
64 |
65 | def call(self, inp, training=False, post_process=False):
66 | x, masks = inp
67 | x = self.backbone(x, training=training)
68 | masks = self.downsample_masks(masks, x)
69 |
70 | pos_encoding = self.pos_encoder(masks)
71 |
72 | hs = self.transformer(self.input_proj(x), masks, self.query_embed(None),
73 | pos_encoding, training=training)[0]
74 |
75 | outputs_class = self.class_embed(hs)
76 |
77 | box_ftmps = self.activation(self.bbox_embed_linear1(hs))
78 | box_ftmps = self.activation(self.bbox_embed_linear2(box_ftmps))
79 | outputs_coord = tf.sigmoid(self.bbox_embed_linear3(box_ftmps))
80 |
81 | output = {'pred_logits': outputs_class[-1],
82 | 'pred_boxes': outputs_coord[-1]}
83 |
84 | if post_process:
85 | output = self.post_process(output)
86 | return output
87 |
88 |
89 | def build(self, input_shape=None, **kwargs):
90 | if input_shape is None:
91 | input_shape = [(None, None, None, 3), (None, None, None)]
92 | super().build(input_shape, **kwargs)
93 |
94 | def add_heads_nlayers(config, detr, nb_class):
95 | image_input = tf.keras.Input((None, None, 3))
96 | # Setup the new layers
97 | cls_layer = tf.keras.layers.Dense(nb_class, name="cls_layer")
98 | pos_layer = tf.keras.models.Sequential([
99 | tf.keras.layers.Dense(256, activation="relu"),
100 | tf.keras.layers.Dense(256, activation="relu"),
101 | tf.keras.layers.Dense(4, activation="sigmoid"),
102 | ], name="pos_layer")
103 | config.add_nlayers([cls_layer, pos_layer])
104 |
105 | transformer_output = detr(image_input)
106 | cls_preds = cls_layer(transformer_output)
107 | pos_preds = pos_layer(transformer_output)
108 |
109 |     # Define the main outputs along with the auxiliary loss outputs
110 | outputs = {'pred_logits': cls_preds[-1], 'pred_boxes': pos_preds[-1]}
111 | outputs["aux"] = [ {"pred_logits": cls_preds[i], "pred_boxes": pos_preds[i]} for i in range(0, 5)]
112 |
113 | n_detr = tf.keras.Model(image_input, outputs, name="detr_finetuning")
114 | return n_detr
115 |
116 | def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_backbone=False, num_decoder_layers=6, num_encoder_layers=6):
117 | """ Get the DETR model
118 |
119 | Parameters
120 | ----------
121 | include_top: bool
122 |         If False, the last layers of the transformer used to predict the bbox position
123 |         and class will not be included, and can therefore be replaced for fine-tuning if the `weights` parameter
124 |         is set.
125 |     nb_class: int
126 |         If include_top is False and nb_class is set, this method will automatically add two new heads to predict
127 |         the bbox position and the bbox class on top of the decoder.
128 |     weights: str
129 |         Name of the weights to load. Only "detr" is available to get started for now.
130 |         More weights such as detr-r101 will be added soon.
131 |     tf_backbone:
132 |         When using the pretrained weights from PyTorch, the ResNet backbone does not use
133 |         tf.keras.applications to load its weights. If you want to use the TF backbone instead of
134 |         loading the weights from PyTorch, set this variable to True.
135 | """
136 | detr = DETR(num_decoder_layers=num_decoder_layers, num_encoder_layers=num_encoder_layers)
137 |
138 | if weights is not None:
139 | load_weights(detr, weights)
140 |
141 | image_input = tf.keras.Input((None, None, 3))
142 |
143 | # Backbone
144 | if not tf_backbone:
145 | backbone = detr.get_layer("backbone")
146 | else:
147 | config.normalized_method = "tf_resnet"
148 | backbone = tf.keras.applications.ResNet50(include_top=False, weights="imagenet", input_shape=(None, None, 3))
149 |
150 | # Transformer
151 | transformer = detr.get_layer("transformer")
152 | # Positional embedding of the feature map
153 | position_embedding_sine = detr.get_layer("position_embedding_sine")
154 |     # Used to project the feature map before feeding it into the encoder
155 | input_proj = detr.get_layer('input_proj')
156 | # Decoder objects query embedding
157 | query_embed = detr.get_layer('query_embed')
158 |
159 |
160 | # Used to project the output of the decoder into a class prediction
161 |     # This layer will be replaced for fine-tuning
162 | class_embed = detr.get_layer('class_embed')
163 |
164 | # Predict the bbox pos
165 | bbox_embed_linear1 = detr.get_layer('bbox_embed_0')
166 | bbox_embed_linear2 = detr.get_layer('bbox_embed_1')
167 | bbox_embed_linear3 = detr.get_layer('bbox_embed_2')
168 | activation = detr.get_layer("re_lu")
169 |
170 | x = backbone(image_input)
171 |
172 | masks = tf.zeros((tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]), tf.bool)
173 | pos_encoding = position_embedding_sine(masks)
174 |
175 | hs = transformer(input_proj(x), masks, query_embed(None), pos_encoding)[0]
176 |
177 | detr = tf.keras.Model(image_input, hs, name="detr")
178 | if include_top is False and nb_class is None:
179 | return detr
180 | elif include_top is False and nb_class is not None:
181 | return add_heads_nlayers(config, detr, nb_class)
182 |
183 | transformer_output = detr(image_input)
184 |
185 | outputs_class = class_embed(transformer_output)
186 | box_ftmps = activation(bbox_embed_linear1(transformer_output))
187 | box_ftmps = activation(bbox_embed_linear2(box_ftmps))
188 | outputs_coord = tf.sigmoid(bbox_embed_linear3(box_ftmps))
189 |
190 | outputs = {}
191 |
192 | output = {'pred_logits': outputs_class[-1],
193 | 'pred_boxes': outputs_coord[-1]}
194 |
195 | output["aux"] = []
196 | for i in range(0, num_decoder_layers - 1):
197 | out_class = outputs_class[i]
198 | pred_boxes = outputs_coord[i]
199 | output["aux"].append({
200 | "pred_logits": out_class,
201 | "pred_boxes": pred_boxes
202 | })
203 |
204 | return tf.keras.Model(image_input, output, name="detr_finetuning")
205 |
206 |
--------------------------------------------------------------------------------
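A minimal fine-tuning sketch for get_detr_model. It assumes the TrainingConfig class from detr_tf/training_config.py (which is expected to provide the add_nlayers method used by add_heads_nlayers), and weights="detr" is expected to fetch the converted pretrained weights on first use:

    import tensorflow as tf
    from detr_tf.training_config import TrainingConfig
    from detr_tf.networks.detr import get_detr_model

    config = TrainingConfig()

    # Pretrained DETR without its original heads, plus new heads for 3 classes
    # (background + 2 object classes), ready for fine-tuning.
    detr = get_detr_model(config, include_top=False, nb_class=3, weights="detr")

    m_outputs = detr(tf.random.uniform((1, 376, 672, 3)), training=False)
    print(m_outputs["pred_logits"].shape, m_outputs["pred_boxes"].shape)   # (1, 100, 3) (1, 100, 4)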
/detr_tf/networks/position_embeddings.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | class PositionEmbeddingSine(tf.keras.Model):
6 |
7 |
8 | def __init__(self, num_pos_features=64, temperature=10000,
9 | normalize=False, scale=None, eps=1e-6, **kwargs):
10 | super().__init__(**kwargs)
11 |
12 | self.num_pos_features = num_pos_features
13 | self.temperature = temperature
14 | self.normalize = normalize
15 | if scale is not None and normalize is False:
16 | raise ValueError('normalize should be True if scale is passed')
17 | if scale is None:
18 | scale = 2 * np.pi
19 | self.scale = scale
20 | self.eps = eps
21 |
22 |
23 | def call(self, mask):
24 | not_mask = tf.cast(~mask, tf.float32)
25 | y_embed = tf.math.cumsum(not_mask, axis=1)
26 | x_embed = tf.math.cumsum(not_mask, axis=2)
27 |
28 | if self.normalize:
29 | y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
30 | x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
31 |
32 | dim_t = tf.range(self.num_pos_features, dtype=tf.float32)
33 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_features)
34 |
35 | pos_x = x_embed[..., tf.newaxis] / dim_t
36 | pos_y = y_embed[..., tf.newaxis] / dim_t
37 |
38 | pos_x = tf.stack([tf.math.sin(pos_x[..., 0::2]),
39 | tf.math.cos(pos_x[..., 1::2])], axis=4)
40 |
41 | pos_y = tf.stack([tf.math.sin(pos_y[..., 0::2]),
42 | tf.math.cos(pos_y[..., 1::2])], axis=4)
43 |
44 |
45 | shape = [tf.shape(pos_x)[i] for i in range(3)] + [-1]
46 | pos_x = tf.reshape(pos_x, shape)
47 | pos_y = tf.reshape(pos_y, shape)
48 |
49 | pos_emb = tf.concat([pos_y, pos_x], axis=3)
50 | return pos_emb
51 |
--------------------------------------------------------------------------------
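A quick sketch of PositionEmbeddingSine on an all-False padding mask. With num_pos_features=128 (model_dim // 2, as in detr.py) the output has 2 * 128 = 256 channels:

    import tensorflow as tf
    from detr_tf.networks.position_embeddings import PositionEmbeddingSine

    pos_encoder = PositionEmbeddingSine(num_pos_features=128, normalize=True)
    masks = tf.zeros((1, 12, 21), tf.bool)   # (batch, h, w), True where the feature map is padded
    pos_encoding = pos_encoder(masks)
    print(pos_encoding.shape)                # (1, 12, 21, 256)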
/detr_tf/networks/resnet_backbone.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import ZeroPadding2D, Conv2D, ReLU, MaxPool2D
3 |
4 | from .custom_layers import FrozenBatchNorm2D
5 |
6 |
7 | class ResNetBase(tf.keras.Model):
8 | def __init__(self, **kwargs):
9 | super().__init__(**kwargs)
10 |
11 | self.pad1 = ZeroPadding2D(3, name='pad1')
12 | self.conv1 = Conv2D(64, kernel_size=7, strides=2, padding='valid',
13 | use_bias=False, name='conv1')
14 | self.bn1 = FrozenBatchNorm2D(name='bn1')
15 | self.relu = ReLU(name='relu')
16 | self.pad2 = ZeroPadding2D(1, name='pad2')
17 | self.maxpool = MaxPool2D(pool_size=3, strides=2, padding='valid')
18 |
19 |
20 | def call(self, x):
21 | x = self.pad1(x)
22 | x = self.conv1(x)
23 | x = self.bn1(x)
24 | x = self.relu(x)
25 | x = self.pad2(x)
26 | x = self.maxpool(x)
27 |
28 | x = self.layer1(x)
29 | x = self.layer2(x)
30 | x = self.layer3(x)
31 | x = self.layer4(x)
32 | return x
33 |
34 |
35 | class ResNet50Backbone(ResNetBase):
36 | def __init__(self, replace_stride_with_dilation=[False, False, False], **kwargs):
37 | super().__init__(**kwargs)
38 |
39 | self.layer1 = ResidualBlock(num_bottlenecks=3, dim1=64, dim2=256, strides=1,
40 | replace_stride_with_dilation=False, name='layer1')
41 | self.layer2 = ResidualBlock(num_bottlenecks=4, dim1=128, dim2=512, strides=2,
42 | replace_stride_with_dilation=replace_stride_with_dilation[0],
43 | name='layer2')
44 | self.layer3 = ResidualBlock(num_bottlenecks=6, dim1=256, dim2=1024, strides=2,
45 | replace_stride_with_dilation=replace_stride_with_dilation[1],
46 | name='layer3')
47 | self.layer4 = ResidualBlock(num_bottlenecks=3, dim1=512, dim2=2048, strides=2,
48 | replace_stride_with_dilation=replace_stride_with_dilation[2],
49 | name='layer4')
50 |
51 |
52 | class ResNet101Backbone(ResNetBase):
53 | def __init__(self, replace_stride_with_dilation=[False, False, False], **kwargs):
54 | super().__init__(**kwargs)
55 |
56 | self.layer1 = ResidualBlock(num_bottlenecks=3, dim1=64, dim2=256, strides=1,
57 | replace_stride_with_dilation=False, name='layer1')
58 | self.layer2 = ResidualBlock(num_bottlenecks=4, dim1=128, dim2=512, strides=2,
59 | replace_stride_with_dilation=replace_stride_with_dilation[0],
60 | name='layer2')
61 | self.layer3 = ResidualBlock(num_bottlenecks=23, dim1=256, dim2=1024, strides=2,
62 | replace_stride_with_dilation=replace_stride_with_dilation[1],
63 | name='layer3')
64 | self.layer4 = ResidualBlock(num_bottlenecks=3, dim1=512, dim2=2048, strides=2,
65 | replace_stride_with_dilation=replace_stride_with_dilation[2],
66 | name='layer4')
67 |
68 |
69 | class ResidualBlock(tf.keras.Model):
70 | def __init__(self, num_bottlenecks, dim1, dim2, strides=1,
71 | replace_stride_with_dilation=False, **kwargs):
72 | super().__init__(**kwargs)
73 |
74 | if replace_stride_with_dilation:
75 | strides = 1
76 | dilation = 2
77 | else:
78 | dilation = 1
79 |
80 | self.bottlenecks = [BottleNeck(dim1, dim2, strides=strides,
81 | downsample=True, name='0')]
82 |
83 | for idx in range(1, num_bottlenecks):
84 | self.bottlenecks.append(BottleNeck(dim1, dim2, name=str(idx),
85 | dilation=dilation))
86 |
87 |
88 | def call(self, x):
89 | for btn in self.bottlenecks:
90 | x = btn(x)
91 | return x
92 |
93 |
94 | class BottleNeck(tf.keras.Model):
95 | def __init__(self, dim1, dim2, strides=1, dilation=1, downsample=False, **kwargs):
96 | super().__init__(**kwargs)
97 | self.downsample = downsample
98 | self.pad = ZeroPadding2D(dilation)
99 | self.relu = ReLU(name='relu')
100 |
101 | self.conv1 = Conv2D(dim1, kernel_size=1, use_bias=False, name='conv1')
102 | self.bn1 = FrozenBatchNorm2D(name='bn1')
103 |
104 | self.conv2 = Conv2D(dim1, kernel_size=3, strides=strides, dilation_rate=dilation,
105 | use_bias=False, name='conv2')
106 | self.bn2 = FrozenBatchNorm2D(name='bn2')
107 |
108 | self.conv3 = Conv2D(dim2, kernel_size=1, use_bias=False, name='conv3')
109 | self.bn3 = FrozenBatchNorm2D(name='bn3')
110 |
111 | self.downsample_conv = Conv2D(dim2, kernel_size=1, strides=strides,
112 | use_bias=False, name='downsample_0')
113 | self.downsample_bn = FrozenBatchNorm2D(name='downsample_1')
114 |
115 |
116 | def call(self, x):
117 | identity = x
118 |
119 | out = self.conv1(x)
120 | out = self.bn1(out)
121 | out = self.relu(out)
122 |
123 | out = self.pad(out)
124 | out = self.conv2(out)
125 | out = self.bn2(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv3(out)
129 | out = self.bn3(out)
130 |
131 | if self.downsample:
132 | identity = self.downsample_bn(self.downsample_conv(x))
133 |
134 | out += identity
135 | out = self.relu(out)
136 |
137 | return out
--------------------------------------------------------------------------------
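A quick shape check for the frozen-BatchNorm ResNet-50 backbone above; the overall stride is 32 and the last stage outputs 2048 channels:

    import tensorflow as tf
    from detr_tf.networks.resnet_backbone import ResNet50Backbone

    backbone = ResNet50Backbone(name="backbone")
    features = backbone(tf.random.uniform((1, 480, 640, 3)))
    print(features.shape)   # (1, 15, 20, 2048)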
/detr_tf/networks/transformer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Dropout, Activation, LayerNormalization
3 |
4 | from .custom_layers import Linear
5 |
6 |
7 | class Transformer(tf.keras.Model):
8 | def __init__(self, model_dim=256, num_heads=8, num_encoder_layers=6,
9 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
10 | activation='relu', normalize_before=False,
11 | return_intermediate_dec=False, **kwargs):
12 | super().__init__(**kwargs)
13 |
14 | self.model_dim = model_dim
15 | self.num_heads = num_heads
16 |
17 | enc_norm = LayerNormalization(epsilon=1e-5, name='norm_pre') if normalize_before else None
18 | self.encoder = TransformerEncoder(model_dim, num_heads, dim_feedforward,
19 | dropout, activation, normalize_before, enc_norm,
20 | num_encoder_layers, name='encoder')
21 |
22 | dec_norm = LayerNormalization(epsilon=1e-5, name='norm')
23 | self.decoder = TransformerDecoder(model_dim, num_heads, dim_feedforward,
24 | dropout, activation, normalize_before, dec_norm,
25 | num_decoder_layers, name='decoder',
26 | return_intermediate=return_intermediate_dec)
27 |
28 |
29 | def call(self, source, mask, query_encoding, pos_encoding, training=False):
30 |
31 | batch_size, rows, cols = [tf.shape(source)[i] for i in range(3)]
32 | source = tf.reshape(source, [batch_size, -1, self.model_dim])
33 | source = tf.transpose(source, [1, 0, 2])
34 |
35 |
36 |
37 | pos_encoding = tf.reshape(pos_encoding, [batch_size, -1, self.model_dim])
38 | pos_encoding = tf.transpose(pos_encoding, [1, 0, 2])
39 |
40 | query_encoding = tf.expand_dims(query_encoding, axis=1)
41 | query_encoding = tf.tile(query_encoding, [1, batch_size, 1])
42 |
43 | mask = tf.reshape(mask, [batch_size, -1])
44 |
45 | target = tf.zeros_like(query_encoding)
46 |
47 | memory = self.encoder(source, source_key_padding_mask=mask,
48 | pos_encoding=pos_encoding, training=training)
49 | hs = self.decoder(target, memory, memory_key_padding_mask=mask,
50 | pos_encoding=pos_encoding, query_encoding=query_encoding,
51 | training=training)
52 |
53 | hs = tf.transpose(hs, [0, 2, 1, 3])
54 | memory = tf.transpose(memory, [1, 0, 2])
55 | memory = tf.reshape(memory, [batch_size, rows, cols, self.model_dim])
56 |
57 | return hs, memory
58 |
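# Example shapes when calling the Transformer on a projected backbone feature map
# (a rough sketch; MultiHeadAttention used by the layers below is defined further
# down in this file):
#
#   transformer    = Transformer(model_dim=256, num_heads=8, return_intermediate_dec=True)
#   source         = tf.random.uniform((1, 15, 20, 256))    # projected feature map
#   masks          = tf.zeros((1, 15, 20), tf.bool)         # padding mask
#   query_encoding = tf.random.uniform((100, 256))          # object queries
#   pos_encoding   = tf.random.uniform((1, 15, 20, 256))    # sine positional encoding
#   hs, memory     = transformer(source, masks, query_encoding, pos_encoding)
#   # hs:     (num_decoder_layers, 1, 100, 256)   one output per decoder layer
#   # memory: (1, 15, 20, 256)                    encoder output, back in image layout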
59 |
60 | class TransformerEncoder(tf.keras.Model):
61 | def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048,
62 | dropout=0.1, activation='relu', normalize_before=False, norm=None,
63 | num_encoder_layers=6, **kwargs):
64 | super().__init__(**kwargs)
65 |
66 | self.enc_layers = [EncoderLayer(model_dim, num_heads, dim_feedforward,
67 | dropout, activation, normalize_before,
68 | name='layer_%d'%i)
69 | for i in range(num_encoder_layers)]
70 |
71 | self.norm = norm
72 |
73 |
74 | def call(self, source, mask=None, source_key_padding_mask=None,
75 | pos_encoding=None, training=False):
76 | x = source
77 |
78 |
79 | for layer in self.enc_layers:
80 | x = layer(x, source_mask=mask, source_key_padding_mask=source_key_padding_mask,
81 | pos_encoding=pos_encoding, training=training)
82 |
83 | if self.norm:
84 | x = self.norm(x)
85 |
86 | return x
87 |
88 |
89 | class TransformerDecoder(tf.keras.Model):
90 | def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048,
91 | dropout=0.1, activation='relu', normalize_before=False, norm=None,
92 | num_decoder_layers=6, return_intermediate=False, **kwargs):
93 | super().__init__(**kwargs)
94 |
95 | self.dec_layers = [DecoderLayer(model_dim, num_heads, dim_feedforward,
96 | dropout, activation, normalize_before,
97 | name='layer_%d'%i)
98 | for i in range(num_decoder_layers)]
99 |
100 | self.norm = norm
101 | self.return_intermediate = return_intermediate
102 |
103 |
104 | def call(self, target, memory, target_mask=None, memory_mask=None,
105 | target_key_padding_mask=None, memory_key_padding_mask=None,
106 | pos_encoding=None, query_encoding=None, training=False):
107 |
108 | x = target
109 | intermediate = []
110 |
111 |
112 | for layer in self.dec_layers:
113 | x = layer(x, memory,
114 | target_mask=target_mask,
115 | memory_mask=memory_mask,
116 | target_key_padding_mask=target_key_padding_mask,
117 | memory_key_padding_mask=memory_key_padding_mask,
118 | pos_encoding=pos_encoding,
119 | query_encoding=query_encoding)
120 |
121 | if self.return_intermediate:
122 | if self.norm:
123 | intermediate.append(self.norm(x))
124 | else:
125 | intermediate.append(x)
126 |
127 | if self.return_intermediate:
128 | return tf.stack(intermediate, axis=0)
129 |
130 | if self.norm:
131 | x = self.norm(x)
132 |
133 | return x
134 |
135 |
136 | class EncoderLayer(tf.keras.layers.Layer):
137 | def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048,
138 | dropout=0.1, activation='relu', normalize_before=False,
139 | **kwargs):
140 | super().__init__(**kwargs)
141 |
142 | self.self_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout,
143 | name='self_attn')
144 |
145 | self.dropout = Dropout(dropout)
146 | self.activation = Activation(activation)
147 |
148 | self.linear1 = Linear(dim_feedforward, name='linear1')
149 | self.linear2 = Linear(model_dim, name='linear2')
150 |
151 | self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1')
152 | self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2')
153 |
154 | self.normalize_before = normalize_before
155 |
156 |
157 | def call(self, source, source_mask=None, source_key_padding_mask=None,
158 | pos_encoding=None, training=False):
159 |
160 |
161 | if pos_encoding is None:
162 | query = key = source
163 | else:
164 | query = key = source + pos_encoding
165 |
166 | attn_source = self.self_attn((query, key, source), attn_mask=source_mask,
167 | key_padding_mask=source_key_padding_mask,
168 | need_weights=False)
169 | source += self.dropout(attn_source, training=training)
170 | source = self.norm1(source)
171 |
172 | x = self.linear1(source)
173 | x = self.activation(x)
174 | x = self.dropout(x, training=training)
175 | x = self.linear2(x)
176 | source += self.dropout(x, training=training)
177 | source = self.norm2(source)
178 |
179 | return source
180 |
181 |
182 |
183 | class DecoderLayer(tf.keras.layers.Layer):
184 | def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048,
185 | dropout=0.1, activation='relu', normalize_before=False,
186 | **kwargs):
187 | super().__init__(**kwargs)
188 |
189 | self.self_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout,
190 | name='self_attn')
191 | self.multihead_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout,
192 | name='multihead_attn')
193 |
194 | self.dropout = Dropout(dropout)
195 | self.activation = Activation(activation)
196 |
197 | self.linear1 = Linear(dim_feedforward, name='linear1')
198 | self.linear2 = Linear(model_dim, name='linear2')
199 |
200 | self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1')
201 | self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2')
202 | self.norm3 = LayerNormalization(epsilon=1e-5, name='norm3')
203 |
204 | self.normalize_before = normalize_before
205 |
206 |
207 | def call(self, target, memory, target_mask=None, memory_mask=None,
208 | target_key_padding_mask=None, memory_key_padding_mask=None,
209 | pos_encoding=None, query_encoding=None, training=False):
210 |
211 | query_tgt = key_tgt = target + query_encoding
212 | attn_target = self.self_attn((query_tgt, key_tgt, target), attn_mask=target_mask,
213 | key_padding_mask=target_key_padding_mask,
214 | need_weights=False)
215 | target += self.dropout(attn_target, training=training)
216 | target = self.norm1(target)
217 |
218 | query_tgt = target + query_encoding
219 | key_mem = memory + pos_encoding
220 |
221 | attn_target2 = self.multihead_attn((query_tgt, key_mem, memory), attn_mask=memory_mask,
222 | key_padding_mask=memory_key_padding_mask,
223 | need_weights=False)
224 | target += self.dropout(attn_target2, training=training)
225 | target = self.norm2(target)
226 |
227 | x = self.linear1(target)
228 | x = self.activation(x)
229 | x = self.dropout(x, training=training)
230 | x = self.linear2(x)
231 | target += self.dropout(x, training=training)
232 | target = self.norm3(target)
233 |
234 | return target
235 |
236 |
237 | class MultiHeadAttention(tf.keras.layers.Layer):
238 | def __init__(self, model_dim, num_heads, dropout=0.0, **kwargs):
239 | super().__init__(**kwargs)
240 |
241 | self.model_dim = model_dim
242 | self.num_heads = num_heads
243 |
244 | assert model_dim % num_heads == 0
245 | self.head_dim = model_dim // num_heads
246 |
247 | self.dropout = Dropout(rate=dropout)
248 |
249 |
250 | def build(self, input_shapes):
251 | in_dim = sum([shape[-1] for shape in input_shapes[:3]])
252 |
253 | self.in_proj_weight = self.add_weight(
254 | name='in_proj_kernel', shape=(in_dim, self.model_dim),
255 | initializer=tf.keras.initializers.GlorotUniform(), dtype=tf.float32, trainable=True
256 | )
257 | self.in_proj_bias = self.add_weight(
258 | name='in_proj_bias', shape=(in_dim,),
259 | initializer=tf.keras.initializers.GlorotUniform(), dtype=tf.float32, trainable=True
260 | )
261 | self.out_proj_weight = self.add_weight(
262 | name='out_proj_kernel', shape=(self.model_dim, self.model_dim),
263 | initializer=tf.keras.initializers.GlorotUniform(), dtype=tf.float32, trainable=True
264 | )
265 | self.out_proj_bias = self.add_weight(
266 | name='out_proj_bias', shape=(self.model_dim,),
267 | initializer=tf.keras.initializers.GlorotUniform(), dtype=tf.float32, trainable=True
268 | )
269 |
270 |
271 |
272 |
273 | #self.in_proj_weight = tf.Variable(
274 | # tf.zeros((in_dim, self.model_dim), dtype=tf.float32), name='in_proj_kernel')
275 | #self.in_proj_bias = tf.Variable(tf.zeros((in_dim,), dtype=tf.float32),
276 | # name='in_proj_bias')
277 |
278 | #self.out_proj_weight = tf.Variable(
279 | # tf.zeros((self.model_dim, self.model_dim), dtype=tf.float32), name='out_proj_kernel')
280 | #self.out_proj_bias = tf.Variable(
281 | # tf.zeros((self.model_dim,), dtype=tf.float32), name='out_proj_bias')
282 |
283 |
284 |
285 | def call(self, inputs, attn_mask=None, key_padding_mask=None,
286 | need_weights=True, training=False):
287 |
288 | query, key, value = inputs
289 |
290 | batch_size = tf.shape(query)[1]
291 | target_len = tf.shape(query)[0]
292 | source_len = tf.shape(key)[0]
293 |
294 | W = self.in_proj_weight[:self.model_dim, :]
295 | b = self.in_proj_bias[:self.model_dim]
296 |
297 | WQ = tf.matmul(query, W, transpose_b=True) + b
298 |
299 | W = self.in_proj_weight[self.model_dim:2*self.model_dim, :]
300 | b = self.in_proj_bias[self.model_dim:2*self.model_dim]
301 | WK = tf.matmul(key, W, transpose_b=True) + b
302 |
303 | W = self.in_proj_weight[2*self.model_dim:, :]
304 | b = self.in_proj_bias[2*self.model_dim:]
305 | WV = tf.matmul(value, W, transpose_b=True) + b
306 |
307 | WQ *= float(self.head_dim) ** -0.5
308 | WQ = tf.reshape(WQ, [target_len, batch_size * self.num_heads, self.head_dim])
309 | WQ = tf.transpose(WQ, [1, 0, 2])
310 |
311 | WK = tf.reshape(WK, [source_len, batch_size * self.num_heads, self.head_dim])
312 | WK = tf.transpose(WK, [1, 0, 2])
313 |
314 | WV = tf.reshape(WV, [source_len, batch_size * self.num_heads, self.head_dim])
315 | WV = tf.transpose(WV, [1, 0, 2])
316 |
317 | attn_output_weights = tf.matmul(WQ, WK, transpose_b=True)
318 |
319 | if attn_mask is not None:
320 | attn_output_weights += attn_mask
321 |
322 | """
323 | if key_padding_mask is not None:
324 | attn_output_weights = tf.reshape(attn_output_weights,
325 | [batch_size, self.num_heads, target_len, source_len])
326 |
327 | key_padding_mask = tf.expand_dims(key_padding_mask, 1)
328 | key_padding_mask = tf.expand_dims(key_padding_mask, 2)
329 | key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1])
330 |
331 | #print("before attn_output_weights", attn_output_weights.shape)
332 | attn_output_weights = tf.where(key_padding_mask,
333 | tf.zeros_like(attn_output_weights) + float('-inf'),
334 | attn_output_weights)
335 | attn_output_weights = tf.reshape(attn_output_weights,
336 | [batch_size * self.num_heads, target_len, source_len])
337 | """
338 |
339 |
340 | attn_output_weights = tf.nn.softmax(attn_output_weights, axis=-1)
341 | attn_output_weights = self.dropout(attn_output_weights, training=training)
342 |
343 | attn_output = tf.matmul(attn_output_weights, WV)
344 | attn_output = tf.transpose(attn_output, [1, 0, 2])
345 | attn_output = tf.reshape(attn_output, [target_len, batch_size, self.model_dim])
346 | attn_output = tf.matmul(attn_output, self.out_proj_weight,
347 | transpose_b=True) + self.out_proj_bias
348 |
349 | if need_weights:
350 | attn_output_weights = tf.reshape(attn_output_weights,
351 | [batch_size, self.num_heads, target_len, source_len])
352 |             # Return the average attention weights over the heads
353 | avg_weights = tf.reduce_mean(attn_output_weights, axis=1)
354 | return attn_output, avg_weights
355 |
356 | return attn_output
357 |
--------------------------------------------------------------------------------
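A minimal usage sketch of the transformer defined above (the class name Transformer is assumed from context, the backbone features are assumed to be already projected to model_dim, and all shapes below are illustrative):

import tensorflow as tf

from detr_tf.networks.transformer import Transformer  # class name assumed

# Build the transformer with the DETR defaults and keep the intermediate decoder outputs.
transformer = Transformer(model_dim=256, num_heads=8,
                          num_encoder_layers=6, num_decoder_layers=6,
                          return_intermediate_dec=True)

batch_size, rows, cols, model_dim, num_queries = 2, 12, 21, 256, 100
source = tf.random.normal((batch_size, rows, cols, model_dim))        # projected backbone features
mask = tf.zeros((batch_size, rows, cols))                             # padding mask, flattened internally
pos_encoding = tf.random.normal((batch_size, rows, cols, model_dim))  # position embeddings
query_encoding = tf.random.normal((num_queries, model_dim))           # learned object queries

hs, memory = transformer(source, mask, query_encoding, pos_encoding, training=False)
# hs:     (num_decoder_layers, batch_size, num_queries, model_dim) since return_intermediate_dec=True
# memory: (batch_size, rows, cols, model_dim), the encoder output reshaped back to a feature map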
/detr_tf/networks/weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 |
4 |
5 | WEIGHT_NAME_TO_CKPT = {
6 | "detr": [
7 | "https://storage.googleapis.com/visualbehavior-publicweights/detr/checkpoint",
8 | "https://storage.googleapis.com/visualbehavior-publicweights/detr/detr.ckpt.data-00000-of-00001",
9 | "https://storage.googleapis.com/visualbehavior-publicweights/detr/detr.ckpt.index"
10 | ]
11 | }
12 |
13 | def load_weights(model, weights: str):
14 |     """ Load weights onto a given model.
15 |     Weights are expected to be stored in the weights folder at the root of the repository. If the weights
16 |     do not exist locally but are publicly known, they will be downloaded from Google Cloud Storage.
17 | """
18 | if not os.path.exists('weights'):
19 | os.makedirs('weights')
20 |
21 |     if "ckpt" in weights:
22 |         model.load_weights(weights)
23 | elif weights in WEIGHT_NAME_TO_CKPT:
24 | wdir = f"weights/{weights}"
25 | if not os.path.exists(wdir):
26 | os.makedirs(wdir)
27 | for f in WEIGHT_NAME_TO_CKPT[weights]:
28 | fname = f.split("/")[-1]
29 | if not os.path.exists(os.path.join(wdir, fname)):
30 | print("Download....", f)
31 | r = requests.get(f, allow_redirects=True)
32 | open(os.path.join(wdir, fname), 'wb').write(r.content)
33 | print("Load weights from", os.path.join(wdir, f"{weights}.ckpt"))
34 | l = model.load_weights(os.path.join(wdir, f"{weights}.ckpt"))
35 | l.expect_partial()
36 | else:
37 | raise Exception(f'Cant load the weights: {weights}')
--------------------------------------------------------------------------------
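A short usage sketch for load_weights (building the model with weights=None and then loading explicitly is an assumption for illustration; any tf.keras.Model whose variables match the public checkpoint works):

from detr_tf.training_config import TrainingConfig
from detr_tf.networks.detr import get_detr_model
from detr_tf.networks.weights import load_weights

config = TrainingConfig()
detr = get_detr_model(config, include_top=True, weights=None)

# "detr" is a key of WEIGHT_NAME_TO_CKPT: the three checkpoint files are downloaded
# into weights/detr/ on first use, then restored with model.load_weights(...).
load_weights(detr, "detr")

# A string containing "ckpt" is treated as a checkpoint path and restored directly:
# load_weights(detr, "weights/detr/detr.ckpt")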
/detr_tf/optimizers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def disable_batchnorm_training(model):
4 | for l in model.layers:
5 | if hasattr(l, "layers"):
6 | disable_batchnorm_training(l)
7 | elif isinstance(l, tf.keras.layers.BatchNormalization):
8 | l.trainable = False
9 |
10 | def get_transformers_trainable_variables(model, exclude=[]):
11 | transformers_variables = []
12 |
13 | # Transformers variables
14 | transformers_variables = model.get_layer("detr").get_layer("transformer").trainable_variables
15 |
16 | for layer in model.layers[2:]:
17 | if layer.name not in exclude:
18 | transformers_variables += layer.trainable_variables
19 | else:
20 | pass
21 |
22 | return transformers_variables
23 |
24 |
25 | def get_backbone_trainable_variables(model):
26 | backbone_variables = []
27 | # layer [1] is the detr model including the backbone and the transformers
28 |
29 | detr = model.get_layer("detr")
30 | tr_index = [l.name for l in detr.layers].index('transformer')
31 |
32 | for l, layer in enumerate(detr.layers):
33 | if l != tr_index:
34 | backbone_variables += layer.trainable_variables
35 |
36 | return backbone_variables
37 |
38 |
39 | def get_nlayers_trainables_variables(model, nlayers_names):
40 | nlayers_variables = []
41 | for nlayer_name in nlayers_names:
42 | nlayers_variables += model.get_layer(nlayer_name).trainable_variables
43 | return nlayers_variables
44 |
45 |
46 | def get_trainable_variables(model, config):
47 |
48 | disable_batchnorm_training(model)
49 |
50 | backbone_variables = []
51 | transformers_variables = []
52 | nlayers_variables = []
53 |
54 |
55 |     # Retrieve the gradient for each trainable variable
56 | #if config.train_backbone:
57 | backbone_variables = get_backbone_trainable_variables(model)
58 | #if config.train_transformers:
59 | transformers_variables = get_transformers_trainable_variables(model, exclude=config.nlayers)
60 | #if config.train_nlayers:
61 | nlayers_variables = get_nlayers_trainables_variables(model, config.nlayers)
62 |
63 |
64 | return backbone_variables, transformers_variables, nlayers_variables
65 |
66 |
67 | def setup_optimizers(model, config):
68 | """ Method call by the Scheduler to init user data
69 | """
70 | @tf.function
71 | def get_backbone_learning_rate():
72 | return config.backbone_lr
73 |
74 | @tf.function
75 | def get_transformers_learning_rate():
76 | return config.transformers_lr
77 |
78 | @tf.function
79 | def get_nlayers_learning_rate():
80 | return config.nlayers_lr
81 |
82 | # Disable batch norm on the backbone
83 | disable_batchnorm_training(model)
84 |
85 | # Optimizers
86 | backbone_optimizer = tf.keras.optimizers.Adam(learning_rate=get_backbone_learning_rate, clipnorm=config.gradient_norm_clipping)
87 | transformers_optimizer = tf.keras.optimizers.Adam(learning_rate=get_transformers_learning_rate, clipnorm=config.gradient_norm_clipping)
88 | nlayers_optimizer = tf.keras.optimizers.Adam(learning_rate=get_nlayers_learning_rate, clipnorm=config.gradient_norm_clipping)
89 |
90 | # Set trainable variables
91 |
92 | backbone_variables, transformers_variables, nlayers_variables = [], [], []
93 |
94 | backbone_variables = get_backbone_trainable_variables(model)
95 | transformers_variables = get_transformers_trainable_variables(model, exclude=config.nlayers)
96 | nlayers_variables = get_nlayers_trainables_variables(model, config.nlayers)
97 |
98 |
99 | return {
100 | "backbone_optimizer": backbone_optimizer,
101 | "transformers_optimizer": transformers_optimizer,
102 | "nlayers_optimizer": nlayers_optimizer,
103 |
104 | "backbone_variables": backbone_variables,
105 | "transformers_variables": transformers_variables,
106 | "nlayers_variables": nlayers_variables,
107 | }
108 |
109 |
110 | def gather_gradient(model, optimizers, total_loss, tape, config, log):
111 |
112 | backbone_variables, transformers_variables, nlayers_variables = get_trainable_variables(model, config)
113 | trainables_variables = backbone_variables + transformers_variables + nlayers_variables
114 |
115 | gradients = tape.gradient(total_loss, trainables_variables)
116 |
117 |     # Retrieve the gradients from the tape
118 | backbone_gradients = gradients[:len(optimizers["backbone_variables"])]
119 | transformers_gradients = gradients[len(optimizers["backbone_variables"]):len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"])]
120 | nlayers_gradients = gradients[len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"]):]
121 |
122 | gradient_steps = {}
123 |
124 | gradient_steps["backbone"] = {"gradients": backbone_gradients}
125 | gradient_steps["transformers"] = {"gradients": transformers_gradients}
126 | gradient_steps["nlayers"] = {"gradients": nlayers_gradients}
127 |
128 |
129 | log.update({"backbone_lr": optimizers["backbone_optimizer"]._serialize_hyperparameter("learning_rate")})
130 | log.update({"transformers_lr": optimizers["transformers_optimizer"]._serialize_hyperparameter("learning_rate")})
131 | log.update({"nlayers_lr": optimizers["nlayers_optimizer"]._serialize_hyperparameter("learning_rate")})
132 |
133 | return gradient_steps
134 |
135 |
136 |
137 | def aggregate_grad_and_apply(name, optimizers, gradients, step, config):
138 |
139 | gradient_aggregate = None
140 | if config.target_batch is not None:
141 | gradient_aggregate = int(config.target_batch // config.batch_size)
142 |
143 | gradient_name = "{}_gradients".format(name)
144 | optimizer_name = "{}_optimizer".format(name)
145 | variables_name = "{}_variables".format(name)
146 | train_part_name = "train_{}".format(name)
147 |
148 | if getattr(config, train_part_name):
149 |
150 | # Init the aggregate gradient
151 | if gradient_aggregate is not None and step % gradient_aggregate == 0:
152 | optimizers[gradient_name] = [tf.zeros_like(tv) for tv in optimizers[variables_name]]
153 |
154 |
155 | if gradient_aggregate is not None:
156 | # Aggregate the gradient
157 | optimizers[gradient_name] = [(gradient+n_gradient) if n_gradient is not None else None for gradient, n_gradient in zip(optimizers[gradient_name], gradients) ]
158 | else:
159 | optimizers[gradient_name] = gradients
160 |
161 |         # Apply the gradients if there is no aggregation, or once we have finished gathering gradients over the aggregation steps
162 | if gradient_aggregate is None or (step+1) % gradient_aggregate == 0:
163 | optimizers[optimizer_name].apply_gradients(zip(optimizers[gradient_name], optimizers[variables_name]))
--------------------------------------------------------------------------------
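A sketch of how these helpers fit together for one manual training step, mirroring the pattern used in detr_tf/training.py below (model, images, t_bbox, t_class and config are assumed to be set up elsewhere, as in the finetuning scripts):

import tensorflow as tf

from detr_tf.optimizers import setup_optimizers, gather_gradient, aggregate_grad_and_apply
from detr_tf.loss.loss import get_losses

# One-time setup: per-part optimizers and their trainable variable lists.
optimizers = setup_optimizers(model, config)

step = 0  # current step index within the epoch
with tf.GradientTape() as tape:
    m_outputs = model(images, training=True)
    total_loss, log = get_losses(m_outputs, t_bbox, t_class, config)

# Split the tape gradients into backbone / transformers / nlayers slices.
gradient_steps = gather_gradient(model, optimizers, total_loss, tape, config, log)

# Each slice is accumulated for target_batch // batch_size steps before being applied,
# gated by the config.train_backbone / train_transformers / train_nlayers flags.
for name in gradient_steps:
    aggregate_grad_and_apply(name, optimizers, gradient_steps[name]["gradients"], step, config)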
/detr_tf/training.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from .optimizers import gather_gradient, aggregate_grad_and_apply
4 | from .logger.training_logging import valid_log, train_log
5 | from .loss.loss import get_losses
6 | import time
7 | import wandb
8 |
9 | @tf.function
10 | def run_train_step(model, images, t_bbox, t_class, optimizers, config):
11 |
12 | if config.target_batch is not None:
13 | gradient_aggregate = int(config.target_batch // config.batch_size)
14 | else:
15 | gradient_aggregate = 1
16 |
17 | with tf.GradientTape() as tape:
18 | m_outputs = model(images, training=True)
19 | total_loss, log = get_losses(m_outputs, t_bbox, t_class, config)
20 | total_loss = total_loss / gradient_aggregate
21 |
22 | # Compute gradient for each part of the network
23 | gradient_steps = gather_gradient(model, optimizers, total_loss, tape, config, log)
24 |
25 | return m_outputs, total_loss, log, gradient_steps
26 |
27 |
28 | @tf.function
29 | def run_val_step(model, images, t_bbox, t_class, config):
30 | m_outputs = model(images, training=False)
31 | total_loss, log = get_losses(m_outputs, t_bbox, t_class, config)
32 | return m_outputs, total_loss, log
33 |
34 |
35 | def fit(model, train_dt, optimizers, config, epoch_nb, class_names):
36 | """ Train the model for one epoch
37 | """
38 |     # Aggregate the gradients to simulate a bigger batch size and improve convergence
39 | gradient_aggregate = None
40 | if config.target_batch is not None:
41 | gradient_aggregate = int(config.target_batch // config.batch_size)
42 | t = None
43 |     for epoch_step, (images, t_bbox, t_class) in enumerate(train_dt):
44 |
45 | # Run the prediction and retrieve the gradient step for each part of the network
46 | m_outputs, total_loss, log, gradient_steps = run_train_step(model, images, t_bbox, t_class, optimizers, config)
47 |
48 |         # Log the predictions
49 | if config.log:
50 | train_log(images, t_bbox, t_class, m_outputs, config, config.global_step, class_names, prefix="train/")
51 |
52 | # Aggregate and apply the gradient
53 | for name in gradient_steps:
54 | aggregate_grad_and_apply(name, optimizers, gradient_steps[name]["gradients"], epoch_step, config)
55 |
56 | # Log every 100 steps
57 | if epoch_step % 100 == 0:
58 | t = t if t is not None else time.time()
59 | elapsed = time.time() - t
60 | print(f"Epoch: [{epoch_nb}], \t Step: [{epoch_step}], \t ce: [{log['label_cost']:.2f}] \t giou : [{log['giou_loss']:.2f}] \t l1 : [{log['l1_loss']:.2f}] \t time : [{elapsed:.2f}]")
61 | if config.log:
62 | wandb.log({f"train/{k}":log[k] for k in log}, step=config.global_step)
63 | t = time.time()
64 |
65 | config.global_step += 1
66 |
67 |
68 | def eval(model, valid_dt, config, class_name, evaluation_step=200):
69 | """ Evaluate the model on the validation set
70 | """
71 | t = None
72 | for val_step, (images, t_bbox, t_class) in enumerate(valid_dt):
73 | # Run prediction
74 | m_outputs, total_loss, log = run_val_step(model, images, t_bbox, t_class, config)
75 | # Log the predictions
76 | if config.log:
77 | valid_log(images, t_bbox, t_class, m_outputs, config, val_step, config.global_step, class_name, evaluation_step=evaluation_step, prefix="train/")
78 | # Log the metrics
79 | if config.log and val_step == 0:
80 | wandb.log({f"val/{k}":log[k] for k in log}, step=config.global_step)
81 | # Log the progress
82 | if val_step % 10 == 0:
83 | t = t if t is not None else time.time()
84 | elapsed = time.time() - t
85 | print(f"Validation step: [{val_step}], \t ce: [{log['label_cost']:.2f}] \t giou : [{log['giou_loss']:.2f}] \t l1 : [{log['l1_loss']:.2f}] \t time : [{elapsed:.2f}]")
86 | if val_step+1 >= evaluation_step:
87 | break
88 |
--------------------------------------------------------------------------------
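fit() and eval() simply iterate the dataset and unpack each element as an (images, t_bbox, t_class) tuple; a dummy pipeline with that structure (the image size matches the default config, the padding length of 101 is an illustrative assumption) looks like:

import tensorflow as tf

images = tf.zeros((1, 376, 672, 3), dtype=tf.float32)   # normalized image batch
t_bbox = tf.zeros((1, 101, 4), dtype=tf.float32)        # padded target boxes (xc, yc, w, h)
t_class = tf.zeros((1, 101, 1), dtype=tf.int64)         # padded target classes
dummy_dt = tf.data.Dataset.from_tensors((images, t_bbox, t_class)).repeat(4)

for epoch_step, (images, t_bbox, t_class) in enumerate(dummy_dt):
    pass  # fit()/eval() consume the dataset exactly like this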
/detr_tf/training_config.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import argparse
3 | import os
4 |
5 |
6 | def training_config_parser():
7 | """ Training config class can be overide using the script arguments
8 | """
9 | parser = argparse.ArgumentParser()
10 |
11 | # Dataset info
12 | parser.add_argument("--data_dir", type=str, required=False, help="Path to the dataset directory")
13 | parser.add_argument("--img_dir", type=str, required=False, help="Image directory relative to data_dir")
14 | parser.add_argument("--ann_file", type=str, required=False, help="Annotation file relative to data_dir")
15 | parser.add_argument("--ann_dir", type=str, required=False, help="Annotation directory relative to data_dir")
16 |
17 | parser.add_argument("--background_class", type=int, required=False, default=0, help="Default background class")
18 |
19 | # What to train
20 | parser.add_argument("--train_backbone", action='store_true', required=False, default=False, help="Train backbone")
21 | parser.add_argument("--train_transformers", action='store_true', required=False, default=False, help="Train transformers")
22 | parser.add_argument("--train_nlayers", action='store_true', required=False, default=False, help="Train new layers")
23 |
24 | # How to train
25 | parser.add_argument("--finetuning", default=False, required=False, action='store_true', help="Load the model weight before to train")
26 | parser.add_argument("--batch_size", type=int, required=False, default=1, help="Batch size to use to train the model")
27 | parser.add_argument("--gradient_norm_clipping", type=float, required=False, default=0.1, help="Gradient norm clipping")
28 | parser.add_argument("--target_batch", type=int, required=False, default=None, help="When running on a single GPU, aggretate the gradient before to apply.")
29 |
30 | # Learning rate
31 | parser.add_argument("--backbone_lr", type=bool, required=False, default=1e-5, help="Train backbone")
32 | parser.add_argument("--transformers_lr", type=bool, required=False, default=1e-4, help="Train transformers")
33 | parser.add_argument("--nlayers_lr", type=bool, required=False, default=1e-4, help="Train new layers")
34 |
35 | # Logging
36 | parser.add_argument("--log", required=False, action="store_true", default=False, help="Log into wandb")
37 |
38 | return parser
39 |
40 |
41 | class TrainingConfig():
42 |
43 | def __init__(self):
44 |
45 | # Dataset info
46 | self.data_dir, self.img_dir, self.ann_dir, self.ann_file = None, None, None, None
47 | self.data = DataConfig(data_dir=None, img_dir=None, ann_file=None, ann_dir=None)
48 | self.background_class = 0
49 | self.image_size = 376, 672
50 |
51 | # What to train
52 | self.train_backbone = False
53 | self.train_transformers = False
54 | self.train_nlayers = False
55 |
56 | # How to train
57 | self.finetuning = False
58 | self.batch_size = 1
59 | self.gradient_norm_clipping = 0.1
60 |         # Target batch size to aggregate gradients to before backprop
61 | self.target_batch = 1
62 |
63 | # Learning rate
64 |         # Set as tf.Variable so that the value can be updated during training while
65 |         # keeping the same graph
66 | self.backbone_lr = tf.Variable(1e-5)
67 | self.transformers_lr = tf.Variable(1e-4)
68 | self.nlayers_lr = tf.Variable(1e-4)
69 | self.nlayers = []
70 |
71 | # Training progress
72 | self.global_step = 0
73 | self.log = False
74 |
75 | # Pipeline variables
76 | self.normalized_method = "torch_resnet"
77 |
78 |
79 | def add_nlayers(self, layers):
80 | """ Set the new layers to train on the training config
81 | """
82 | self.nlayers = [l.name for l in layers]
83 |
84 |
85 | def update_from_args(self, args):
86 | """ Update the training config from args
87 | """
88 | args = vars(args)
89 | for key in args:
90 | if isinstance(getattr(self, key), tf.Variable):
91 | getattr(self, key).assign(args[key])
92 | else:
93 | setattr(self, key, args[key])
94 |
95 | # Set the config on the data class
96 |
97 |
98 | self.data = DataConfig(
99 | data_dir=self.data_dir,
100 | img_dir=self.img_dir,
101 | ann_file=self.ann_file,
102 | ann_dir=self.ann_dir
103 | )
104 |
105 |
106 | class DataConfig():
107 |
108 | def __init__(self, data_dir=None, img_dir=None, ann_file=None, ann_dir=None):
109 | self.data_dir = data_dir
110 | self.img_dir = os.path.join(data_dir, img_dir) if data_dir is not None and img_dir is not None else None
111 | self.ann_file = os.path.join(self.data_dir, ann_file) if ann_file is not None else None
112 | self.ann_dir = os.path.join(self.data_dir, ann_dir) if ann_dir is not None else None
113 |
114 |
115 | if __name__ == "__main__":
116 |     args = training_config_parser().parse_args()
117 | config = TrainingConfig()
118 | config.update_from_args(args)
--------------------------------------------------------------------------------
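A sketch of how the config is typically built and overridden from script arguments (the dataset paths below are hypothetical placeholders):

from detr_tf.training_config import TrainingConfig, training_config_parser

config = TrainingConfig()
args = training_config_parser().parse_args([
    "--data_dir", "/path/to/dataset",     # placeholder path
    "--img_dir", "images",
    "--ann_file", "annotations.csv",
    "--train_transformers",
    "--batch_size", "2",
])
config.update_from_args(args)

# The learning rates are tf.Variable, so they can be updated in place during
# training without rebuilding the tf.function graphs used by the optimizers.
config.transformers_lr.assign(1e-5)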
/eval.py:
--------------------------------------------------------------------------------
1 | """ Eval a model on the coco dataset
2 | """
3 |
4 | import argparse
5 | import tensorflow as tf
6 | import os
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 |
10 | from detr_tf.inference import get_model_inference
11 | from detr_tf.data.coco import load_coco_dataset
12 | from detr_tf.loss.compute_map import cal_map, calc_map, APDataObject
13 | from detr_tf.networks.detr import get_detr_model
14 | from detr_tf.bbox import xcycwh_to_xy_min_xy_max, xcycwh_to_yx_min_yx_max
15 | from detr_tf.inference import numpy_bbox_to_image
16 | from detr_tf.training_config import TrainingConfig, training_config_parser
17 |
18 |
19 | def build_model(config):
20 | """ Build the model with the pretrained weights. In this example
21 | we do not add new layers since the pretrained model is already trained on coco.
22 |     See finetune_voc.py to add new layers.
23 | """
24 | # Load the pretrained model
25 | detr = get_detr_model(config, include_top=True, weights="detr")
26 | detr.summary()
27 | return detr
28 |
29 |
30 | def eval_model(model, config, class_names, valid_dt):
31 | """ Run evaluation
32 | """
33 |
34 | iou_thresholds = [x / 100. for x in range(50, 100, 5)]
35 | ap_data = {
36 | 'box' : [[APDataObject() for _ in class_names] for _ in iou_thresholds],
37 | 'mask': [[APDataObject() for _ in class_names] for _ in iou_thresholds]
38 | }
39 | it = 0
40 |
41 | for images, target_bbox, target_class in valid_dt:
42 | # Forward pass
43 | m_outputs = model(images)
44 | # Run predictions
45 | p_bbox, p_labels, p_scores = get_model_inference(m_outputs, config.background_class, bbox_format="yxyx")
46 | # Remove padding
47 | t_bbox, t_class = target_bbox[0], target_class[0]
48 | size = tf.cast(t_bbox[0][0], tf.int32)
49 | t_bbox = tf.slice(t_bbox, [1, 0], [size, 4])
50 | t_bbox = xcycwh_to_yx_min_yx_max(t_bbox)
51 | t_class = tf.slice(t_class, [1, 0], [size, -1])
52 | t_class = tf.squeeze(t_class, axis=-1)
53 | # Compute map
54 | cal_map(p_bbox, p_labels, p_scores, np.zeros((138, 138, len(p_bbox))), np.array(t_bbox), np.array(t_class), np.zeros((138, 138, len(t_bbox))), ap_data, iou_thresholds)
55 | print(f"Computing map.....{it}", end="\r")
56 | it += 1
57 | #if it > 10:
58 | # break
59 |
60 |     # Compute the mAP over all thresholds
61 | calc_map(ap_data, iou_thresholds, class_names, print_result=True)
62 |
63 | if __name__ == "__main__":
64 |
65 | physical_devices = tf.config.list_physical_devices('GPU')
66 | if len(physical_devices) == 1:
67 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
68 |
69 | config = TrainingConfig()
70 | args = training_config_parser().parse_args()
71 | config.update_from_args(args)
72 |
73 | # Load the model with the new layers to finetune
74 | detr = build_model(config)
75 |
76 | valid_dt, class_names = load_coco_dataset(config, 1, augmentation=None)
77 |
78 | # Run training
79 | eval_model(detr, config, class_names, valid_dt)
80 |
81 |
82 |
--------------------------------------------------------------------------------
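A small worked example of the target padding convention the evaluation loop above relies on: the first row of each padded target tensor stores the number of real objects (the values below are illustrative):

import tensorflow as tf

from detr_tf.bbox import xcycwh_to_yx_min_yx_max

# Two real boxes followed by one padding slot; the first row stores the count.
t_bbox = tf.constant([[2.0, 0.0, 0.0, 0.0],
                      [0.5, 0.5, 0.2, 0.2],
                      [0.3, 0.7, 0.1, 0.4],
                      [0.0, 0.0, 0.0, 0.0]])
t_class = tf.constant([[2], [1], [3], [0]], dtype=tf.int64)

size = tf.cast(t_bbox[0][0], tf.int32)
real_bbox = xcycwh_to_yx_min_yx_max(tf.slice(t_bbox, [1, 0], [size, 4]))
real_class = tf.squeeze(tf.slice(t_class, [1, 0], [size, -1]), axis=-1)
print(real_bbox.shape)      # (2, 4), boxes converted to (ymin, xmin, ymax, xmax)
print(real_class.numpy())   # [1 3]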
/finetune_coco.py:
--------------------------------------------------------------------------------
1 | """ Example on how to finetune on COCO dataset
2 | """
3 |
4 | import argparse
5 | import matplotlib.pyplot as plt
6 | import tensorflow as tf
7 | import numpy as np
8 | import time
9 | import os
10 |
11 | from detr_tf.data.coco import load_coco_dataset
12 | from detr_tf.networks.detr import get_detr_model
13 | from detr_tf.optimizers import setup_optimizers
14 | from detr_tf.optimizers import gather_gradient, aggregate_grad_and_apply
15 | from detr_tf.logger.training_logging import train_log, valid_log
16 | from detr_tf.loss.loss import get_losses
17 | from detr_tf.training_config import TrainingConfig, training_config_parser
18 | from detr_tf import training
19 |
20 | try:
21 | # Should be optional if --log is not set
22 | import wandb
23 | except:
24 | wandb = None
25 |
26 |
27 | import time
28 |
29 |
30 | def build_model(config):
31 | """ Build the model with the pretrained weights. In this example
32 | we do not add new layers since the pretrained model is already trained on coco.
33 |     See finetune_voc.py to add new layers.
34 | """
35 | # Load the pretrained model
36 | detr = get_detr_model(config, include_top=True, weights="detr")
37 | detr.summary()
38 | return detr
39 |
40 |
41 | def run_finetuning(config):
42 |
43 | # Load the model with the new layers to finetune
44 | detr = build_model(config)
45 |
46 | # Load the training and validation dataset
47 | train_dt, coco_class_names = load_coco_dataset("train", config.batch_size, config, augmentation=True)
48 | valid_dt, _ = load_coco_dataset("val", 1, config, augmentation=False)
49 |
50 | # Train/finetune the transformers only
51 | config.train_backbone = False
52 | config.train_transformers = True
53 |
54 |     # Set up the optimizers and the trainable variables
55 | optimzers = setup_optimizers(detr, config)
56 |
57 |     # Run the training for 100 epochs
58 | for epoch_nb in range(100):
59 | training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200)
60 | training.fit(detr, train_dt, optimzers, config, epoch_nb, coco_class_names)
61 |
62 |
63 | if __name__ == "__main__":
64 |
65 | physical_devices = tf.config.list_physical_devices('GPU')
66 | if len(physical_devices) == 1:
67 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
68 |
69 | config = TrainingConfig()
70 | args = training_config_parser().parse_args()
71 | config.update_from_args(args)
72 |
73 | if config.log:
74 | wandb.init(project="detr-tensorflow", reinit=True)
75 |
76 | # Run training
77 | run_finetuning(config)
78 |
79 |
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/finetune_hardhat.py:
--------------------------------------------------------------------------------
1 | """ Example on how to finetune on the HardHat dataset
2 | using custom layers. This script assume the dataset is already download
3 | on your computer in raw and Tensorflow Object detection csv format.
4 |
5 | For more information, please check out the following notebook:
6 | - DETR : How to setup a custom dataset
7 | """
8 |
9 | import argparse
10 | import matplotlib.pyplot as plt
11 | import tensorflow as tf
12 | import numpy as np
13 | import time
14 | import os
15 |
16 | from detr_tf.data import load_tfcsv_dataset
17 |
18 | from detr_tf.networks.detr import get_detr_model
19 | from detr_tf.optimizers import setup_optimizers
20 | from detr_tf.logger.training_logging import train_log, valid_log
21 | from detr_tf.loss.loss import get_losses
22 | from detr_tf.inference import numpy_bbox_to_image
23 | from detr_tf.training_config import TrainingConfig, training_config_parser
24 | from detr_tf import training
25 |
26 | try:
27 | # Should be optional if --log is not set
28 | import wandb
29 | except:
30 | wandb = None
31 |
32 | import time
33 |
34 |
35 | def build_model(config):
36 | """ Build the model with the pretrained weights
37 | and add new layers to finetune
38 | """
39 | # Load the pretrained model with new heads at the top
40 |     # 3 classes: background, head and helmet (we exclude the person class from the dataset)
41 | detr = get_detr_model(config, include_top=False, nb_class=3, weights="detr", num_decoder_layers=6, num_encoder_layers=6)
42 | detr.summary()
43 | return detr
44 |
45 |
46 | def run_finetuning(config):
47 |
48 | # Load the model with the new layers to finetune
49 | detr = build_model(config)
50 |
51 | # Load the training and validation dataset and exclude the person class
52 | train_dt, class_names = load_tfcsv_dataset(
53 | config, config.batch_size, augmentation=True, exclude=["person"], ann_file="train/_annotations.csv", img_dir="train")
54 | valid_dt, _ = load_tfcsv_dataset(
55 | config, 4, augmentation=False, exclude=["person"], ann_file="test/_annotations.csv", img_dir="test")
56 |
57 | # Train/finetune the transformers only
58 | config.train_backbone = tf.Variable(False)
59 | config.train_transformers = tf.Variable(False)
60 | config.train_nlayers = tf.Variable(True)
61 |     # Learning rates (NOTE: the transformers and the backbone are NOT trained here;
62 |     # they are simply excluded from training, and we set their LR to 0 just so that it is clear
63 |     # in the logs that neither is trained at the beginning)
64 | config.backbone_lr = tf.Variable(0.0)
65 | config.transformers_lr = tf.Variable(0.0)
66 | config.nlayers_lr = tf.Variable(1e-3)
67 |
68 |     # Set up the optimizers and the trainable variables
69 | optimzers = setup_optimizers(detr, config)
70 |
71 | # Run the training for 180 epochs
72 | for epoch_nb in range(180):
73 |
74 | if epoch_nb > 0:
75 | # After the first epoch, we finetune the transformers and the new layers
76 | config.train_transformers.assign(True)
77 | config.transformers_lr.assign(1e-4)
78 | config.nlayers_lr.assign(1e-3)
79 |
80 | training.eval(detr, valid_dt, config, class_names, evaluation_step=100)
81 | training.fit(detr, train_dt, optimzers, config, epoch_nb, class_names)
82 |
83 |
84 | if __name__ == "__main__":
85 |
86 | physical_devices = tf.config.list_physical_devices('GPU')
87 | if len(physical_devices) == 1:
88 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
89 |
90 | config = TrainingConfig()
91 | args = training_config_parser().parse_args()
92 | config.update_from_args(args)
93 |
94 | if config.log:
95 | wandb.init(project="detr-tensorflow", reinit=True)
96 |
97 | # Run training
98 | run_finetuning(config)
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/finetune_voc.py:
--------------------------------------------------------------------------------
1 | """ Example on how to finetune on the VOC dataset
2 | using custom layers.
3 | """
4 |
5 | import argparse
6 | import matplotlib.pyplot as plt
7 | import tensorflow as tf
8 | import numpy as np
9 | import time
10 |
11 | try:
12 | # Should be optional if --log is not set
13 | import wandb
14 | except:
15 | wandb = None
16 |
17 | import os
18 |
19 | from detr_tf.data import load_voc_dataset
20 | from detr_tf.networks.detr import get_detr_model
21 | from detr_tf.optimizers import setup_optimizers
22 | from detr_tf.training_config import TrainingConfig, training_config_parser
23 | from detr_tf import training
24 |
25 | VOC_CLASS_NAME = [
26 | 'aeroplane', 'bicycle', 'bird', 'boat',
27 | 'bottle', 'bus', 'car', 'cat', 'chair',
28 | 'cow', 'diningtable', 'dog', 'horse',
29 | 'motorbike', 'person', 'pottedplant',
30 | 'sheep', 'sofa', 'train', 'tvmonitor'
31 | ]
32 |
33 | def build_model(config):
34 | """ Build the model with the pretrained weights
35 | and add new layers to finetune
36 | """
37 | # Input
38 | image_input = tf.keras.Input((None, None, 3))
39 |
40 | # Load the pretrained model
41 | detr = get_detr_model(config, include_top=False, weights="detr", num_decoder_layers=6, num_encoder_layers=6)
42 |
43 | # Setup the new layers
44 | cls_layer = tf.keras.layers.Dense(len(VOC_CLASS_NAME) + 1, name="cls_layer")
45 | pos_layer = tf.keras.models.Sequential([
46 | tf.keras.layers.Dense(256, activation="relu"),
47 | tf.keras.layers.Dense(256, activation="relu"),
48 | tf.keras.layers.Dense(4, activation="sigmoid"),
49 | ], name="pos_layer")
50 | config.add_nlayers([cls_layer, pos_layer])
51 |
52 | transformer_output = detr(image_input)
53 | cls_preds = cls_layer(transformer_output)
54 | pos_preds = pos_layer(transformer_output)
55 |
56 |     # Define the main outputs along with the auxiliary loss
57 | outputs = {'pred_logits': cls_preds[-1], 'pred_boxes': pos_preds[-1]}
58 | outputs["aux"] = [ {"pred_logits": cls_preds[i], "pred_boxes": pos_preds[i]} for i in range(0, 5)]
59 |
60 | detr = tf.keras.Model(image_input, outputs, name="detr_finetuning")
61 | detr.summary()
62 | return detr
63 |
64 |
65 | def run_finetuning(config):
66 |
67 | # Load the model with the new layers to finetune
68 | detr = build_model(config)
69 |
70 |     # Load the training and validation dataset (for the purpose of this example we load the training set
71 |     # as the validation set, but in practice you should use separate loaders/folders for training and validation)
72 | train_dt, class_names = load_voc_dataset(config, config.batch_size, augmentation=True)
73 | valid_dt, _ = load_voc_dataset(config, 1, augmentation=False)
74 |
75 | # Train/finetune the transformers only
76 | config.train_backbone = tf.Variable(False)
77 | config.train_transformers = tf.Variable(False)
78 | config.train_nlayers = tf.Variable(True)
79 |     # Learning rates (NOTE: the transformers and the backbone are NOT trained here;
80 |     # they are simply excluded from training, and we set their LR to 0 just so that it is clear
81 |     # in the logs that neither is trained at the beginning)
82 | config.backbone_lr = tf.Variable(0.0)
83 | config.transformers_lr = tf.Variable(0.0)
84 | config.nlayers_lr = tf.Variable(1e-3)
85 |
86 |     # Set up the optimizers and the trainable variables
87 | optimzers = setup_optimizers(detr, config)
88 |
89 |     # Run the training for 10 epochs
90 | for epoch_nb in range(10):
91 |
92 | if epoch_nb > 0:
93 | # After the first epoch, we finetune the transformers and the new layers
94 | config.train_transformers.assign(True)
95 | config.transformers_lr.assign(1e-4)
96 | config.nlayers_lr.assign(1e-3)
97 |
98 | training.eval(detr, valid_dt, config, class_names, evaluation_step=200)
99 | training.fit(detr, train_dt, optimzers, config, epoch_nb, class_names)
100 |
101 |
102 | if __name__ == "__main__":
103 |
104 | physical_devices = tf.config.list_physical_devices('GPU')
105 | if len(physical_devices) == 1:
106 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
107 |
108 | config = TrainingConfig()
109 | args = training_config_parser().parse_args()
110 | config.update_from_args(args)
111 |
112 | if config.log:
113 | wandb.init(project="detr-tensorflow", reinit=True)
114 |
115 | # Run training
116 | run_finetuning(config)
117 |
118 |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/images/datasetsupport.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/datasetsupport.png
--------------------------------------------------------------------------------
/images/detr-figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/detr-figure.png
--------------------------------------------------------------------------------
/images/hardhatdataset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/hardhatdataset.jpg
--------------------------------------------------------------------------------
/images/tutorials/data-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/tutorials/data-pipeline.png
--------------------------------------------------------------------------------
/images/tutorials/download_hardhat_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/tutorials/download_hardhat_dataset.png
--------------------------------------------------------------------------------
/images/wandb_logging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/wandb_logging.png
--------------------------------------------------------------------------------
/images/webcam_detr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-Behavior/detr-tensorflow/78fb71bca7b2ebf90e73151a51d29e8dbb46cb38/images/webcam_detr.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | wandb
2 | matplotlib
3 | numpy
4 | pycocotools
5 | scikit-image
6 | imageio
7 | pandas
--------------------------------------------------------------------------------
/train_coco.py:
--------------------------------------------------------------------------------
1 | """ Example on how to train on COCO from scratch
2 | """
3 |
4 |
5 | import argparse
6 | import matplotlib.pyplot as plt
7 | import tensorflow as tf
8 | import numpy as np
9 | import time
10 | import os
11 |
12 | from detr_tf.data.coco import load_coco_dataset
13 | from detr_tf.networks.detr import get_detr_model
14 | from detr_tf.optimizers import setup_optimizers
15 | from detr_tf.optimizers import gather_gradient, aggregate_grad_and_apply
16 | from detr_tf.logger.training_logging import train_log, valid_log
17 | from detr_tf.loss.loss import get_losses
18 | from detr_tf.inference import numpy_bbox_to_image
19 | from detr_tf.training_config import TrainingConfig, training_config_parser
20 | from detr_tf import training
21 |
22 | try:
23 | # Should be optional if --log is not set
24 | import wandb
25 | except:
26 | wandb = None
27 |
28 |
29 | import time
30 |
31 |
32 | def build_model(config):
33 | """ Build the model with the pretrained weights. In this example
34 | we do not add new layers since the pretrained model is already trained on coco.
35 | See examples/finetuning_voc.py to add new layers.
36 | """
37 |     # Load the DETR model without pretrained weights.
38 | # Use the tensorflow backbone with the imagenet weights
39 | detr = get_detr_model(config, include_top=True, weights=None, tf_backbone=True)
40 | detr.summary()
41 | return detr
42 |
43 |
44 | def run_finetuning(config):
45 |
46 | # Load the model with the new layers to finetune
47 | detr = build_model(config)
48 |
49 | # Load the training and validation dataset
50 | train_dt, coco_class_names = load_coco_dataset(
51 | config, config.batch_size, augmentation=True, img_dir="train2017", ann_fil="annotations/instances_train2017.json")
52 | valid_dt, _ = load_coco_dataset(
53 | config, 1, augmentation=False, img_dir="val2017", ann_fil="annotations/instances_val2017.json")
54 |
55 | # Train the backbone and the transformers
56 | # Check the training_config file for the other hyperparameters
57 | config.train_backbone = True
58 | config.train_transformers = True
59 |
60 |     # Set up the optimizers and the trainable variables
61 | optimzers = setup_optimizers(detr, config)
62 |
63 | # Run the training for 100 epochs
64 | for epoch_nb in range(100):
65 | training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200)
66 | training.fit(detr, train_dt, optimzers, config, epoch_nb, coco_class_names)
67 |
68 |
69 | if __name__ == "__main__":
70 |
71 | physical_devices = tf.config.list_physical_devices('GPU')
72 | if len(physical_devices) == 1:
73 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
74 |
75 | config = TrainingConfig()
76 | args = training_config_parser().parse_args()
77 | config.update_from_args(args)
78 |
79 | if config.log:
80 | wandb.init(project="detr-tensorflow", reinit=True)
81 |
82 | # Run training
83 | run_finetuning(config)
84 |
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/webcam_inference.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2
4 |
5 | from detr_tf.training_config import TrainingConfig, training_config_parser
6 | from detr_tf.networks.detr import get_detr_model
7 | from detr_tf.data import processing
8 | from detr_tf.data.coco import COCO_CLASS_NAME
9 | from detr_tf.inference import get_model_inference, numpy_bbox_to_image
10 |
11 | @tf.function
12 | def run_inference(model, images, config):
13 | m_outputs = model(images, training=False)
14 | predicted_bbox, predicted_labels, predicted_scores = get_model_inference(m_outputs, config.background_class, bbox_format="xy_center")
15 | return predicted_bbox, predicted_labels, predicted_scores
16 |
17 |
18 | def run_webcam_inference(detr):
19 |
20 | cap = cv2.VideoCapture(0)
21 |
22 | while(True):
23 | ret, frame = cap.read()
24 |
25 | # Convert to RGB and process the input image
26 | model_input = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
27 | model_input = processing.normalized_images(model_input, config)
28 |
29 | # Run inference
30 | predicted_bbox, predicted_labels, predicted_scores = run_inference(detr, np.expand_dims(model_input, axis=0), config)
31 |
32 | frame = frame.astype(np.float32)
33 | frame = frame / 255
34 | frame = numpy_bbox_to_image(frame, predicted_bbox, labels=predicted_labels, scores=predicted_scores, class_name=COCO_CLASS_NAME)
35 |
36 | cv2.imshow('frame', frame)
37 | if cv2.waitKey(1) & 0xFF == ord('q'):
38 | break
39 |
40 |     # When everything is done, release the capture
41 | cap.release()
42 | cv2.destroyAllWindows()
43 |
44 | if __name__ == "__main__":
45 |
46 | physical_devices = tf.config.list_physical_devices('GPU')
47 | if len(physical_devices) == 1:
48 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
49 |
50 | config = TrainingConfig()
51 | args = training_config_parser().parse_args()
52 | config.update_from_args(args)
53 |
54 | # Load the model with the new layers to finetune
55 | detr = get_detr_model(config, include_top=True, weights="detr")
56 | config.background_class = 91
57 |
58 | # Run webcam inference
59 | run_webcam_inference(detr)
60 |
--------------------------------------------------------------------------------
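The same inference path can be exercised on a single RGB image array instead of the webcam stream; a minimal sketch (detr and config are assumed to be set up as in the __main__ block above, and image is any HxWx3 RGB uint8 numpy array):

import numpy as np

from detr_tf.data import processing
from detr_tf.data.coco import COCO_CLASS_NAME
from detr_tf.inference import get_model_inference, numpy_bbox_to_image

model_input = processing.normalized_images(image, config)
m_outputs = detr(np.expand_dims(model_input, axis=0), training=False)
p_bbox, p_labels, p_scores = get_model_inference(m_outputs, config.background_class, bbox_format="xy_center")

# Draw the predictions on a float copy of the image scaled to [0, 1].
result = numpy_bbox_to_image(image.astype(np.float32) / 255, p_bbox,
                             labels=p_labels, scores=p_scores, class_name=COCO_CLASS_NAME)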