├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── detr_tensorflow
│   ├── datasets.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── custom_layers.py
│   │   ├── default.py
│   │   ├── detr.py
│   │   ├── position_embeddings.py
│   │   └── transformer.py
│   └── utils.py
├── env.yml
├── eval.py
├── samples
│   ├── sample_1.jpg
│   └── sample_1_boxes.png
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/*
2 | detr_tensorflow.egg-info/
3 | env/
4 | *~
5 | *.pth
6 | *.h5
7 | *.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Leonardo Blanger
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DETR Tensorflow
2 | 
3 | This project is my attempt at a Tensorflow implementation of the DETR architecture for Object Detection, from the paper *End-to-End Object Detection with Transformers* [(Carion *et al.*)](https://ai.facebook.com/research/publications/end-to-end-object-detection-with-transformers).
4 | 
5 | **Attention:** This is a work in progress. It does not yet offer all the functionality of the original implementation. If you only want to perform detection using DETR in Tensorflow, that is already possible. If you want to perform Panoptic Segmentation, fully replicate the paper's experiments, or train on your own dataset, that is not possible yet.
6 | 
7 | ## Overview
8 | 
9 | DETR, which stands for **DE**tection **TR**ansformer, was proposed by a team from the Facebook AI group, and it is, as of today, a radical shift from the established approaches to Deep Learning based Object Detection.
10 | 
11 | Instead of filtering and refining a set of object proposals, as done by two-stage techniques like Faster-RCNN and its adaptations, or generating dense detection grids, as done by single-stage techniques like SSD and YOLO, DETR frames the detection problem as an image-to-set mapping. With this formulation, both the architecture and the training process become significantly simpler. There is no need for hand-designed anchor matching schemes or post-processing steps like Non Max Suppression to discard redundant detections.
12 | 
13 | DETR uses a CNN backbone to extract a higher level feature representation of the image, which is then fed into a Transformer model. The Transformer Encoder is responsible for processing this image representation, while the Decoder maps a fixed set of learned object queries to detections, performing attention over the Encoder's output.
14 | 
15 | DETR is trained with a set-based global loss that finds a bipartite matching between the set of detections and the ground-truth objects (non-matched detections are assigned to a special _no object_ class), which in turn forces unique detections.
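
To make the matching step concrete, here is a minimal sketch of how such a bipartite matching can be computed with the Hungarian algorithm (`scipy.optimize.linear_sum_assignment`, which is also what the official implementation relies on). The cost function below is a simplified stand-in for the paper's combination of class-probability and box losses, so treat it as an illustration rather than a reference implementation.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_predictions(pred_probs, pred_boxes, gt_labels, gt_boxes):
    # pred_probs: (num_queries, num_classes), pred_boxes: (num_queries, 4)
    # gt_labels: (num_objects,), gt_boxes: (num_objects, 4)
    class_cost = -pred_probs[:, gt_labels]  # (num_queries, num_objects)
    box_cost = np.abs(pred_boxes[:, None] - gt_boxes[None]).sum(-1)
    # One-to-one assignment: each ground-truth object gets exactly one
    # query, which is what enforces unique detections.
    query_idx, gt_idx = linear_sum_assignment(class_cost + box_cost)
    return query_idx, gt_idx
```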
16 | 
17 | For more details on the technique, please refer to their [paper](https://ai.facebook.com/research/publications/end-to-end-object-detection-with-transformers) (Carion *et al.*) and [blog post](https://ai.facebook.com/blog/end-to-end-object-detection-with-transformers). Both are very well written.
18 | 
19 | ## About this Implementation
20 | 
21 | After spending some time working with Object Detection for my Master's degree, and wanting to learn more about this apparently useful thing called Transformers that everybody keeps talking about, I came across this very cool idea that proposes a completely different way of doing Object Detection. So I decided to make it accessible to the Tensorflow community as well. The main purpose of this implementation was to help me understand the technique in more depth, while also serving as an exercise in the Tensorflow framework.
22 | 
23 | I tried my best to replicate the precise behavior of the original Pytorch implementation, accounting for small details such as the difference in how convolutions use padding in the two frameworks. This way, we can convert the existing Pytorch weights to the Tensorflow/Keras format and load them in this implementation. This turned out to be a great exercise for understanding not only the DETR architecture, but also how both frameworks work at a greater level of detail.
24 | 
25 | Currently, I have not implemented any training related code, so the only way to use this implementation is by loading the converted Pytorch weights. I also have not implemented the Panoptic Segmentation part yet. The Object Detection part is already working.
26 | 
27 | ## Evaluation Results
28 | 
29 | Below are the results on the COCO val2017 dataset, as reported by the official Pytorch version and as achieved by this implementation using the converted weights. The small deviations are probably mostly due to differences in how the two frameworks and implementations perform image loading and resizing, as well as floating point errors from differences in how they perform certain low level operations.
30 | 
31 | **name** | **backbone** | **box AP (official)** | **box AP (ours)**
32 | -------- | ------------ | --------------------- | -----------------
33 | DETR | R50 | 42.0 | 41.9
34 | DETR-DC5 | R50 | 43.3 | 43.2
35 | DETR | R101 | 43.5 | 43.4
36 | DETR-DC5 | R101 | 44.9 | 44.8
37 | 
38 | ## Requirements
39 | 
40 | The code was tested with `python 3.8.10` and `tensorflow-gpu 2.4.1`. For running the evaluation, we used the `pycocotools 2.0.2` library. You can create a local environment with `conda` and install the requirements with:
41 | 
42 | ```bash
43 | # inside the repo's root directory
44 | conda env create --file=env.yml --prefix=./env
45 | conda activate ./env
46 | ```
47 | 
48 | ## How to Use
49 | 
50 | You can install it as a package as follows. If you are using a local environment (see above), make sure it is active.
51 | 
52 | ```bash
53 | # inside the repo's root directory
54 | python -m pip install .
55 | ```
56 | 
57 | In order to use the same models as in the official version, download the converted Pytorch weights, in the TF/Keras `.h5` file format, for the model version you want to use from [here](https://drive.google.com/drive/folders/1OMzJNxsx-D5lyLgrQokLvbpvrZ5rM9rW?usp=sharing).
58 | 
59 | You can use one of the pre-built loading functions from the `models.default` module to instantiate one of the four versions that are equivalent to the ones provided by the original implementation.
60 | 
61 | ```python
62 | from detr_tensorflow.models.default import build_detr_resnet50
63 | detr = build_detr_resnet50(num_classes=91)  # 91 classes for the COCO dataset
64 | detr.build()
65 | detr.load_weights("detr-r50-e632da11.h5")
66 | ```
67 | 
68 | Or directly instantiate the `models.DETR` class to create your own custom combination of backbone CNN, transformer architecture, and positional encoding scheme, as sketched below. Please check the files `models/default.py` and `models/detr.py` for more details.
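
For instance, a custom combination might look like the following. The constructor arguments mirror the ones in `models/detr.py` and `models/default.py`, but the specific settings here are only illustrative. Note that `DETR.call` indexes the decoder output per layer, so the transformer should keep `return_intermediate_dec=True`, and the positional encoding uses `model_dim // 2` features per axis.

```python
from detr_tensorflow.models import DETR
from detr_tensorflow.models.backbone import ResNet101Backbone
from detr_tensorflow.models.position_embeddings import PositionEmbeddingSine
from detr_tensorflow.models.transformer import Transformer

detr = DETR(num_classes=91,
            num_queries=100,
            backbone=ResNet101Backbone(name='backbone'),
            transformer=Transformer(model_dim=256, num_heads=8,
                                    return_intermediate_dec=True,
                                    name='transformer'),
            pos_encoder=PositionEmbeddingSine(num_pos_features=128,
                                              normalize=True))
detr.build()
```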
69 | 
70 | The `detr_tensorflow.utils.preprocess_image` function is designed to perform all the preprocessing required before running the model, including data normalization, resizing following the scheme used for training, and generating the image masks. It is implemented entirely with Tensorflow operations, so you can use it in combination with the `map` functionality from `tf.data.Dataset`.
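
For example, a minimal input pipeline could look like this (mirroring what `eval.py` does; the list of paths is just a placeholder):

```python
import tensorflow as tf

from detr_tensorflow.utils import preprocess_image, read_jpeg_image

img_paths = ['samples/sample_1.jpg']  # placeholder list of image paths

dataset = tf.data.Dataset.from_tensor_slices(img_paths)
dataset = dataset.map(read_jpeg_image)
dataset = dataset.map(preprocess_image)  # yields (image, mask) pairs

# Images in a batch may have different sizes after resizing, so pad them
# to a common shape; the mask marks padded pixels as True so the model
# can ignore them.
dataset = dataset.padded_batch(
    batch_size=2,
    padded_shapes=((None, None, 3), (None, None)),
    padding_values=(tf.constant(0.0), tf.constant(True)))
```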
71 | 
72 | Finally, to get the final detections, call the model on your data with the `post_process` flag. This way, it returns softmax scores instead of the pre-activation logits, and also discards the `no-object` dimension from the output. It doesn't discard low-scored detections though, so as to give you more flexibility in how to use the detections, but the output from DETR is simple enough that this isn't hard to do.
73 | 
74 | ```python
75 | from detr_tensorflow.utils import preprocess_image, absolute2relative
76 | 
77 | inp_image, mask = preprocess_image(image)
78 | inp_image = tf.expand_dims(inp_image, axis=0)
79 | mask = tf.expand_dims(mask, axis=0)
80 | 
81 | outputs = detr((inp_image, mask), post_process=True)
82 | labels, scores, boxes = [outputs[k][0].numpy() for k in ['labels', 'scores', 'boxes']]
83 | 
84 | keep = scores > 0.7
85 | labels = labels[keep]
86 | scores = scores[keep]
87 | boxes = boxes[keep]
88 | boxes = absolute2relative(boxes, (image.shape[1], image.shape[0])).numpy()
89 | ```
90 | 
91 | (so much easier than anchor decoding + Non Max Suppression)
92 | 
93 | 
94 | ### Demo
95 | 
96 | A short demo script summarizes the above instructions.
97 | 
98 | ```bash
99 | python demo.py
100 | ```
101 | 
102 | ### Running Evaluation
103 | 
104 | I provide an `eval.py` script that evaluates the model on the COCO val2017 dataset, in the same setting as reported in the paper. Note that you don't need to download the whole COCO dataset for this, only the val2017 partition (~1GB) and the annotations (~241MB), from [here](https://cocodataset.org/#download).
105 | 
106 | ```bash
107 | python eval.py --coco_path=/path/to/coco \
108 |                --backbone=resnet50-dc5 \
109 |                --frozen_weights=detr-r50-dc5-f0fb7ef5.h5 \
110 |                --results_file=resnet50_dc5_results.json \
111 |                --batch_size=1
112 | ```
113 | 
114 | It will save the detections into the `resnet50_dc5_results.json` file, in the COCO dictionary format, so you can run the evaluation again with the `--from_file` flag, and it won't need to perform image inference this time.
115 | 
116 | 
117 | ## Detection Samples
118 | 
119 | ![sample](/samples/sample_1_boxes.png)
120 | 
121 | 
122 | ## TODOs
123 | 
124 | - [x] Provide pretrained weights already in `hdf5` format.
125 | - [ ] Implement the training related code.
126 | - [ ] Repeat the paper's experiments.
127 | 
128 | 
129 | ## References
130 | 
131 | * **The DETR paper:** Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko, *End-to-End Object Detection with Transformers*, 2020, from the Facebook AI group. [link to paper](https://arxiv.org/abs/2005.12872)
132 | 
133 | * **The official Pytorch implementation:** https://github.com/facebookresearch/detr
134 | 
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from os import path
3 | import tensorflow as tf
4 | 
5 | from detr_tensorflow.models.default import build_detr_resnet50
6 | from detr_tensorflow.utils import (read_jpeg_image,
7 |                                    preprocess_image,
8 |                                    absolute2relative)
9 | 
10 | detr = build_detr_resnet50()
11 | detr.build()
12 | detr.load_weights('detr-r50-e632da11.h5')
13 | 
14 | 
15 | CLASSES = [
16 |     'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
17 |     'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
18 |     'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
19 |     'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
20 |     'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
21 |     'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
22 |     'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
23 |     'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
24 |     'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
25 |     'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
26 |     'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
27 |     'cellphone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
28 |     'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
29 |     'toothbrush',
30 | ]
31 | 
32 | # colors for visualization
33 | COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
34 |           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
35 | 
36 | 
37 | image = read_jpeg_image(path.join('samples', 'sample_1.jpg'))
38 | 
39 | inp_image, mask = preprocess_image(image)
40 | inp_image = tf.expand_dims(inp_image, axis=0)
41 | mask = tf.expand_dims(mask, axis=0)
42 | outputs = detr((inp_image, mask), post_process=True)
43 | 
44 | labels, scores, boxes = [outputs[k][0].numpy()
45 |                          for k in ['labels', 'scores', 'boxes']]
46 | 
47 | keep = scores > 0.7
48 | labels = labels[keep]
49 | scores = scores[keep]
50 | boxes = boxes[keep]
51 | boxes = absolute2relative(boxes, (image.shape[1], image.shape[0])).numpy()
52 | 
53 | 
54 | def plot_results(img, labels, probs, boxes):
55 | 
plt.figure(figsize=(14, 8)) 56 | plt.imshow(img) 57 | ax = plt.gca() 58 | for cl, p, (xmin, ymin, xmax, ymax), c in zip( 59 | labels, probs, boxes.tolist(), COLORS * 100): 60 | ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, 61 | fill=False, color=c, linewidth=3)) 62 | text = f'{CLASSES[cl]}: {p:0.2f}' 63 | ax.text(xmin, ymin, text, fontsize=15, 64 | bbox=dict(facecolor='yellow', alpha=0.5)) 65 | plt.axis('off') 66 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0) 67 | plt.show() 68 | 69 | 70 | plot_results(image.numpy(), labels, scores, boxes) 71 | -------------------------------------------------------------------------------- /detr_tensorflow/datasets.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | import numpy as np 3 | from os import path 4 | import tensorflow as tf 5 | 6 | 7 | class COCODatasetBBoxes(tf.keras.utils.Sequence): 8 | def __init__(self, cocopath, partition='val2017', return_boxes=True, 9 | ignore_crowded=True, **kwargs): 10 | super().__init__(**kwargs) 11 | self.cocopath = cocopath 12 | self.partition = partition 13 | self.return_boxes = return_boxes 14 | self.ignore_crowded = ignore_crowded 15 | 16 | self.coco = COCO(path.join( 17 | cocopath, 'annotations', f'instances_{partition}.json')) 18 | self.img_ids = sorted(self.coco.getImgIds()) 19 | 20 | def __len__(self): 21 | return len(self.img_ids) 22 | 23 | def __getitem__(self, idx): 24 | img_info = self.coco.loadImgs(self.img_ids[idx])[0] 25 | img_path = path.join(self.cocopath, self.partition, 26 | img_info['file_name']) 27 | if not self.return_boxes: 28 | return self.img_ids[idx], img_path 29 | ann_ids = self.coco.getAnnIds(self.img_ids[idx]) 30 | boxes = self.parse_annotations(ann_ids) 31 | return self.img_ids[idx], img_path, boxes 32 | 33 | def parse_annotations(self, ann_ids): 34 | boxes = [] 35 | for ann in self.coco.loadAnns(ann_ids): 36 | if 'iscrowd' in ann and ann['iscrowd'] > 0 and self.ignore_crowded: 37 | continue 38 | box = ann['bbox'] + [ann['category_id']] 39 | box = np.array(box, dtype=np.float32) 40 | box[2:4] += box[0:2] 41 | boxes.append(box) 42 | return boxes 43 | -------------------------------------------------------------------------------- /detr_tensorflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import default 2 | from .detr import DETR 3 | 4 | __all__ = ['default', 'DETR'] 5 | -------------------------------------------------------------------------------- /detr_tensorflow/models/backbone.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import ZeroPadding2D, Conv2D, ReLU, MaxPool2D 3 | 4 | from .custom_layers import FrozenBatchNorm2D 5 | 6 | 7 | class ResNetBase(tf.keras.Model): 8 | def __init__(self, **kwargs): 9 | super().__init__(**kwargs) 10 | 11 | self.pad1 = ZeroPadding2D(3, name='pad1') 12 | self.conv1 = Conv2D(64, kernel_size=7, strides=2, padding='valid', 13 | use_bias=False, name='conv1') 14 | self.bn1 = FrozenBatchNorm2D(name='bn1') 15 | self.relu = ReLU(name='relu') 16 | self.pad2 = ZeroPadding2D(1, name='pad2') 17 | self.maxpool = MaxPool2D(pool_size=3, strides=2, padding='valid') 18 | 19 | def call(self, x): 20 | x = self.pad1(x) 21 | x = self.conv1(x) 22 | x = self.bn1(x) 23 | x = self.relu(x) 24 | x = self.pad2(x) 25 | x = self.maxpool(x) 26 | 27 | x = self.layer1(x) 28 | x = self.layer2(x) 29 | x = self.layer3(x) 30 | x = self.layer4(x) 31 | return x 32 | 33 | 34 | class ResNet50Backbone(ResNetBase): 35 | def __init__(self, 36 | replace_stride_with_dilation=[False, False, False], 37 | **kwargs): 38 | super().__init__(**kwargs) 39 | 40 | self.layer1 = ResidualBlock( 41 | num_bottlenecks=3, dim1=64, dim2=256, strides=1, 42 | replace_stride_with_dilation=False, name='layer1') 43 | 44 | self.layer2 = ResidualBlock( 45 | num_bottlenecks=4, dim1=128, dim2=512, strides=2, 46 | replace_stride_with_dilation=replace_stride_with_dilation[0], 47 | name='layer2') 48 | 49 | self.layer3 = ResidualBlock( 50 | num_bottlenecks=6, dim1=256, dim2=1024, strides=2, 51 | replace_stride_with_dilation=replace_stride_with_dilation[1], 52 | name='layer3') 53 | 54 | self.layer4 = ResidualBlock( 55 | num_bottlenecks=3, dim1=512, dim2=2048, strides=2, 56 | replace_stride_with_dilation=replace_stride_with_dilation[2], 57 | name='layer4') 58 | 59 | 60 | class ResNet101Backbone(ResNetBase): 61 | def __init__(self, 62 | replace_stride_with_dilation=[False, False, False], 63 | **kwargs): 64 | super().__init__(**kwargs) 65 | 66 | self.layer1 = ResidualBlock( 67 | num_bottlenecks=3, dim1=64, dim2=256, strides=1, 68 | replace_stride_with_dilation=False, name='layer1') 69 | 70 | self.layer2 = ResidualBlock( 71 | num_bottlenecks=4, dim1=128, dim2=512, strides=2, 72 | replace_stride_with_dilation=replace_stride_with_dilation[0], 73 | name='layer2') 74 | 75 | self.layer3 = ResidualBlock( 76 | num_bottlenecks=23, dim1=256, dim2=1024, strides=2, 77 | replace_stride_with_dilation=replace_stride_with_dilation[1], 78 | name='layer3') 79 | 80 | self.layer4 = ResidualBlock( 81 | num_bottlenecks=3, dim1=512, dim2=2048, strides=2, 82 | replace_stride_with_dilation=replace_stride_with_dilation[2], 83 | name='layer4') 84 | 85 | 86 | class ResidualBlock(tf.keras.Model): 87 | def __init__(self, num_bottlenecks, dim1, dim2, strides=1, 88 | replace_stride_with_dilation=False, **kwargs): 89 | super().__init__(**kwargs) 90 | 91 | if replace_stride_with_dilation: 92 | strides = 1 93 | dilation = 2 94 | else: 95 | dilation = 1 96 | 97 | self.bottlenecks = [BottleNeck(dim1, dim2, strides=strides, 98 | downsample=True, name='0')] 99 | 100 | for idx in range(1, num_bottlenecks): 101 | self.bottlenecks.append(BottleNeck(dim1, dim2, name=str(idx), 102 | dilation=dilation)) 103 | 104 | def call(self, x): 105 | for btn in 
self.bottlenecks: 106 | x = btn(x) 107 | return x 108 | 109 | 110 | class BottleNeck(tf.keras.Model): 111 | def __init__(self, dim1, dim2, strides=1, 112 | dilation=1, downsample=False, **kwargs): 113 | super().__init__(**kwargs) 114 | self.downsample = downsample 115 | self.pad = ZeroPadding2D(dilation) 116 | self.relu = ReLU(name='relu') 117 | 118 | self.conv1 = Conv2D(dim1, kernel_size=1, use_bias=False, name='conv1') 119 | self.bn1 = FrozenBatchNorm2D(name='bn1') 120 | 121 | self.conv2 = Conv2D(dim1, kernel_size=3, strides=strides, 122 | dilation_rate=dilation, 123 | use_bias=False, name='conv2') 124 | self.bn2 = FrozenBatchNorm2D(name='bn2') 125 | 126 | self.conv3 = Conv2D(dim2, kernel_size=1, use_bias=False, name='conv3') 127 | self.bn3 = FrozenBatchNorm2D(name='bn3') 128 | 129 | self.downsample_conv = Conv2D(dim2, kernel_size=1, strides=strides, 130 | use_bias=False, name='downsample_0') 131 | self.downsample_bn = FrozenBatchNorm2D(name='downsample_1') 132 | 133 | def call(self, x): 134 | identity = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | out = self.relu(out) 139 | 140 | out = self.pad(out) 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample: 149 | identity = self.downsample_bn(self.downsample_conv(x)) 150 | 151 | out += identity 152 | out = self.relu(out) 153 | 154 | return out 155 | -------------------------------------------------------------------------------- /detr_tensorflow/models/custom_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class FrozenBatchNorm2D(tf.keras.layers.Layer): 5 | def __init__(self, eps=1e-5, **kwargs): 6 | super().__init__(**kwargs) 7 | self.eps = eps 8 | 9 | def build(self, input_shape): 10 | self.weight = self.add_weight(name='weight', shape=[input_shape[-1]], 11 | initializer='zeros', trainable=False) 12 | self.bias = self.add_weight(name='bias', shape=[input_shape[-1]], 13 | initializer='zeros', trainable=False) 14 | self.running_mean = self.add_weight(name='running_mean', 15 | shape=[input_shape[-1]], 16 | initializer='zeros', 17 | trainable=False) 18 | self.running_var = self.add_weight(name='running_var', 19 | shape=[input_shape[-1]], 20 | initializer='ones', 21 | trainable=False) 22 | 23 | def call(self, x): 24 | scale = self.weight * tf.math.rsqrt(self.running_var + self.eps) 25 | shift = self.bias - self.running_mean * scale 26 | return x * scale + shift 27 | 28 | def compute_output_shape(self, input_shape): 29 | return input_shape 30 | 31 | 32 | class Linear(tf.keras.layers.Layer): 33 | ''' 34 | Use this custom layer instead of tf.keras.layers.Dense 35 | to allow loading converted PyTorch Dense weights 36 | that have shape (output_dim, input_dim) 37 | ''' 38 | def __init__(self, output_dim, **kwargs): 39 | super().__init__(**kwargs) 40 | self.output_dim = output_dim 41 | 42 | def build(self, input_shape): 43 | self.kernel = self.add_weight(name='kernel', 44 | shape=[self.output_dim, input_shape[-1]], 45 | initializer='zeros', trainable=True) 46 | self.bias = self.add_weight(name='bias', 47 | shape=[self.output_dim], 48 | initializer='zeros', trainable=True) 49 | 50 | def call(self, x): 51 | return tf.matmul(x, self.kernel, transpose_b=True) + self.bias 52 | 53 | def compute_output_shape(self, input_shape): 54 | return input_shape.as_list()[:-1] + [self.output_dim] 55 | 56 | 57 | class 
FixedEmbedding(tf.keras.layers.Layer): 58 | def __init__(self, embed_shape, **kwargs): 59 | super().__init__(**kwargs) 60 | self.embed_shape = embed_shape 61 | 62 | def build(self, input_shape): 63 | self.w = self.add_weight(name='kernel', shape=self.embed_shape, 64 | initializer='zeros', trainable=True) 65 | 66 | def call(self, x=None): 67 | return self.w 68 | -------------------------------------------------------------------------------- /detr_tensorflow/models/default.py: -------------------------------------------------------------------------------- 1 | from .detr import DETR 2 | 3 | 4 | def build_detr_resnet50(num_classes=91, num_queries=100): 5 | from .backbone import ResNet50Backbone 6 | return DETR(num_classes=num_classes, 7 | num_queries=num_queries, 8 | backbone=ResNet50Backbone(name='backbone')) 9 | 10 | 11 | def build_detr_resnet50_dc5(num_classes=91, num_queries=100): 12 | from .backbone import ResNet50Backbone 13 | return DETR(num_classes=num_classes, 14 | num_queries=num_queries, 15 | backbone=ResNet50Backbone( 16 | replace_stride_with_dilation=[False, False, True], 17 | name='backbone')) 18 | 19 | 20 | def build_detr_resnet101(num_classes=91, num_queries=100): 21 | from .backbone import ResNet101Backbone 22 | return DETR(num_classes=num_classes, 23 | num_queries=num_queries, 24 | backbone=ResNet101Backbone(name='backbone')) 25 | 26 | 27 | def build_detr_resnet101_dc5(num_classes=91, num_queries=100): 28 | from .backbone import ResNet101Backbone 29 | return DETR(num_classes=num_classes, 30 | num_queries=num_queries, 31 | backbone=ResNet101Backbone( 32 | replace_stride_with_dilation=[False, False, True], 33 | name='backbone')) 34 | -------------------------------------------------------------------------------- /detr_tensorflow/models/detr.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Conv2D, ReLU 3 | 4 | from .backbone import ResNet50Backbone 5 | from .custom_layers import Linear, FixedEmbedding 6 | from .position_embeddings import PositionEmbeddingSine 7 | from .transformer import Transformer 8 | from ..utils import cxcywh2xyxy 9 | 10 | 11 | class DETR(tf.keras.Model): 12 | def __init__(self, num_classes=91, num_queries=100, 13 | backbone=None, 14 | pos_encoder=None, 15 | transformer=None, 16 | **kwargs): 17 | super().__init__(**kwargs) 18 | self.num_queries = num_queries 19 | 20 | self.backbone = backbone or ResNet50Backbone(name='backbone') 21 | self.transformer = transformer or Transformer( 22 | return_intermediate_dec=True, name='transformer') 23 | self.model_dim = self.transformer.model_dim 24 | 25 | self.pos_encoder = pos_encoder or PositionEmbeddingSine( 26 | num_pos_features=self.model_dim // 2, normalize=True) 27 | 28 | self.input_proj = Conv2D( 29 | self.model_dim, kernel_size=1, name='input_proj') 30 | 31 | self.query_embed = FixedEmbedding((num_queries, self.model_dim), 32 | name='query_embed') 33 | 34 | self.class_embed = Linear(num_classes + 1, name='class_embed') 35 | 36 | self.bbox_embed_linear1 = Linear(self.model_dim, name='bbox_embed_0') 37 | self.bbox_embed_linear2 = Linear(self.model_dim, name='bbox_embed_1') 38 | self.bbox_embed_linear3 = Linear(4, name='bbox_embed_2') 39 | self.activation = ReLU() 40 | 41 | def call(self, inp, training=False, post_process=False): 42 | x, masks = inp 43 | x = self.backbone(x, training=training) 44 | masks = self.downsample_masks(masks, x) 45 | pos_encoding = self.pos_encoder(masks) 46 | 47 | hs = 
self.transformer(self.input_proj(x), masks, 48 | self.query_embed(None), 49 | pos_encoding, 50 | training=training)[0] 51 | 52 | outputs_class = self.class_embed(hs) 53 | 54 | box_ftmps = self.activation(self.bbox_embed_linear1(hs)) 55 | box_ftmps = self.activation(self.bbox_embed_linear2(box_ftmps)) 56 | outputs_coord = tf.sigmoid(self.bbox_embed_linear3(box_ftmps)) 57 | 58 | output = {'pred_logits': outputs_class[-1], 59 | 'pred_boxes': outputs_coord[-1]} 60 | 61 | if post_process: 62 | output = self.post_process(output) 63 | return output 64 | 65 | def build(self, input_shape=None, **kwargs): 66 | if input_shape is None: 67 | input_shape = [(None, None, None, 3), (None, None, None)] 68 | super().build(input_shape, **kwargs) 69 | 70 | def downsample_masks(self, masks, x): 71 | masks = tf.cast(masks, tf.int32) 72 | masks = tf.expand_dims(masks, -1) 73 | # The existing tf.image.resize with method='nearest' 74 | # does not expose the half_pixel_centers option in TF 2.2.0 75 | # The original Pytorch F.interpolate uses it like this 76 | masks = tf.compat.v1.image.resize_nearest_neighbor( 77 | masks, tf.shape(x)[1:3], align_corners=False, 78 | half_pixel_centers=False) 79 | masks = tf.squeeze(masks, -1) 80 | masks = tf.cast(masks, tf.bool) 81 | return masks 82 | 83 | def post_process(self, output): 84 | logits, boxes = [output[k] for k in ['pred_logits', 'pred_boxes']] 85 | 86 | probs = tf.nn.softmax(logits, axis=-1)[..., :-1] 87 | scores = tf.reduce_max(probs, axis=-1) 88 | labels = tf.argmax(probs, axis=-1) 89 | boxes = cxcywh2xyxy(boxes) 90 | 91 | output = {'scores': scores, 92 | 'labels': labels, 93 | 'boxes': boxes} 94 | return output 95 | -------------------------------------------------------------------------------- /detr_tensorflow/models/position_embeddings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | class PositionEmbeddingSine(tf.keras.Model): 6 | def __init__(self, num_pos_features=64, temperature=10000, 7 | normalize=False, scale=None, eps=1e-6, **kwargs): 8 | # These are the default parameters used in the original project 9 | super().__init__(**kwargs) 10 | 11 | self.num_pos_features = num_pos_features 12 | self.temperature = temperature 13 | self.normalize = normalize 14 | if scale is not None and normalize is False: 15 | raise ValueError('normalize should be True if scale is passed') 16 | if scale is None: 17 | scale = 2 * np.pi 18 | self.scale = scale 19 | self.eps = eps 20 | 21 | def call(self, mask): 22 | not_mask = tf.cast(~mask, tf.float32) 23 | y_embed = tf.math.cumsum(not_mask, axis=1) 24 | x_embed = tf.math.cumsum(not_mask, axis=2) 25 | 26 | if self.normalize: 27 | y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale 28 | x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale 29 | 30 | dim_t = tf.range(self.num_pos_features, dtype=tf.float32) 31 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_features) 32 | 33 | pos_x = x_embed[..., tf.newaxis] / dim_t 34 | pos_y = y_embed[..., tf.newaxis] / dim_t 35 | 36 | pos_x = tf.stack([tf.math.sin(pos_x[..., 0::2]), 37 | tf.math.cos(pos_x[..., 1::2])], axis=4) 38 | 39 | pos_y = tf.stack([tf.math.sin(pos_y[..., 0::2]), 40 | tf.math.cos(pos_y[..., 1::2])], axis=4) 41 | 42 | shape = [tf.shape(pos_x)[i] for i in range(3)] + [-1] 43 | pos_x = tf.reshape(pos_x, shape) 44 | pos_y = tf.reshape(pos_y, shape) 45 | 46 | pos_emb = tf.concat([pos_y, pos_x], axis=3) 47 | return pos_emb 48 | 
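49 | 
50 | # Usage sketch (illustrative, not part of the original module): given a
51 | # boolean padding mask of shape (batch, height, width), the layer returns
52 | # a positional encoding of shape (batch, height, width, 2 * num_pos_features),
53 | # which is how DETR.call feeds it to the transformer:
54 | #
55 | #     mask = tf.zeros((1, 32, 32), dtype=tf.bool)
56 | #     pos = PositionEmbeddingSine(num_pos_features=128, normalize=True)(mask)
57 | #     # pos.shape == (1, 32, 32, 256)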
-------------------------------------------------------------------------------- /detr_tensorflow/models/transformer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Dropout, Activation, LayerNormalization 3 | 4 | from .custom_layers import Linear 5 | 6 | 7 | class Transformer(tf.keras.Model): 8 | def __init__(self, model_dim=256, num_heads=8, num_encoder_layers=6, 9 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 10 | activation='relu', normalize_before=False, 11 | return_intermediate_dec=False, **kwargs): 12 | super().__init__(**kwargs) 13 | 14 | self.model_dim = model_dim 15 | self.num_heads = num_heads 16 | 17 | enc_norm = None 18 | if normalize_before: 19 | enc_norm = LayerNormalization(epsilon=1e-5, name='norm_pre') 20 | dec_norm = LayerNormalization(epsilon=1e-5, name='norm') 21 | 22 | self.encoder = TransformerEncoder( 23 | model_dim, num_heads, dim_feedforward, dropout, activation, 24 | normalize_before, enc_norm, num_encoder_layers, name='encoder') 25 | 26 | self.decoder = TransformerDecoder( 27 | model_dim, num_heads, dim_feedforward, dropout, activation, 28 | normalize_before, dec_norm, num_decoder_layers, name='decoder', 29 | return_intermediate=return_intermediate_dec) 30 | 31 | def call(self, source, mask, query_encoding, pos_encoding, training=False): 32 | batch_size, rows, cols = [tf.shape(source)[i] for i in range(3)] 33 | 34 | source = tf.reshape(source, [batch_size, -1, self.model_dim]) 35 | source = tf.transpose(source, [1, 0, 2]) 36 | 37 | pos_encoding = tf.reshape(pos_encoding, 38 | [batch_size, -1, self.model_dim]) 39 | pos_encoding = tf.transpose(pos_encoding, [1, 0, 2]) 40 | 41 | query_encoding = tf.expand_dims(query_encoding, axis=1) 42 | query_encoding = tf.tile(query_encoding, [1, batch_size, 1]) 43 | 44 | mask = tf.reshape(mask, [batch_size, -1]) 45 | 46 | target = tf.zeros_like(query_encoding) 47 | 48 | memory = self.encoder(source, source_key_padding_mask=mask, 49 | pos_encoding=pos_encoding, training=training) 50 | hs = self.decoder(target, memory, 51 | memory_key_padding_mask=mask, 52 | pos_encoding=pos_encoding, 53 | query_encoding=query_encoding, 54 | training=training) 55 | 56 | hs = tf.transpose(hs, [0, 2, 1, 3]) 57 | memory = tf.transpose(memory, [1, 0, 2]) 58 | memory = tf.reshape(memory, [batch_size, rows, cols, self.model_dim]) 59 | 60 | return hs, memory 61 | 62 | 63 | class TransformerEncoder(tf.keras.Model): 64 | def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048, 65 | dropout=0.1, activation='relu', normalize_before=False, 66 | norm=None, num_encoder_layers=6, **kwargs): 67 | super().__init__(**kwargs) 68 | 69 | self.enc_layers = [EncoderLayer(model_dim, num_heads, dim_feedforward, 70 | dropout, activation, normalize_before, 71 | name=f'layer_{i}') 72 | for i in range(num_encoder_layers)] 73 | 74 | self.norm = norm 75 | 76 | def call(self, source, mask=None, source_key_padding_mask=None, 77 | pos_encoding=None, training=False): 78 | x = source 79 | 80 | for layer in self.enc_layers: 81 | x = layer(x, source_mask=mask, 82 | source_key_padding_mask=source_key_padding_mask, 83 | pos_encoding=pos_encoding, 84 | training=training) 85 | 86 | if self.norm: 87 | x = self.norm(x) 88 | 89 | return x 90 | 91 | 92 | class TransformerDecoder(tf.keras.Model): 93 | def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048, 94 | dropout=0.1, activation='relu', normalize_before=False, 95 | norm=None, num_decoder_layers=6, 
return_intermediate=False,
96 |                  **kwargs):
97 |         super().__init__(**kwargs)
98 | 
99 |         self.dec_layers = [DecoderLayer(model_dim, num_heads, dim_feedforward,
100 |                                         dropout, activation, normalize_before,
101 |                                         name=f'layer_{i}')
102 |                            for i in range(num_decoder_layers)]
103 | 
104 |         self.norm = norm
105 |         self.return_intermediate = return_intermediate
106 | 
107 |     def call(self, target, memory, target_mask=None, memory_mask=None,
108 |              target_key_padding_mask=None, memory_key_padding_mask=None,
109 |              pos_encoding=None, query_encoding=None, training=False):
110 | 
111 |         x = target
112 |         intermediate = []
113 | 
114 |         for layer in self.dec_layers:
115 |             x = layer(x, memory,
116 |                       target_mask=target_mask,
117 |                       memory_mask=memory_mask,
118 |                       target_key_padding_mask=target_key_padding_mask,
119 |                       memory_key_padding_mask=memory_key_padding_mask,
120 |                       pos_encoding=pos_encoding,
121 |                       query_encoding=query_encoding, training=training)
122 | 
123 |             if self.return_intermediate:
124 |                 if self.norm:
125 |                     intermediate.append(self.norm(x))
126 |                 else:
127 |                     intermediate.append(x)
128 | 
129 |         if self.return_intermediate:
130 |             return tf.stack(intermediate, axis=0)
131 | 
132 |         if self.norm:
133 |             x = self.norm(x)
134 | 
135 |         return x
136 | 
137 | 
138 | class EncoderLayer(tf.keras.layers.Layer):
139 |     def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048,
140 |                  dropout=0.1, activation='relu', normalize_before=False,
141 |                  **kwargs):
142 |         super().__init__(**kwargs)
143 | 
144 |         self.self_attn = MultiHeadAttention(model_dim, num_heads,
145 |                                             dropout=dropout,
146 |                                             name='self_attn')
147 | 
148 |         self.dropout = Dropout(dropout)
149 |         self.activation = Activation(activation)
150 | 
151 |         self.linear1 = Linear(dim_feedforward, name='linear1')
152 |         self.linear2 = Linear(model_dim, name='linear2')
153 | 
154 |         self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1')
155 |         self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2')
156 | 
157 |         self.normalize_before = normalize_before
158 | 
159 |     def call(self, source, source_mask=None, source_key_padding_mask=None,
160 |              pos_encoding=None, training=False):
161 |         if self.normalize_before:
162 |             return self.pre_norm_call(source, source_mask,
163 |                                       source_key_padding_mask,
164 |                                       pos_encoding, training)
165 |         return self.post_norm_call(source, source_mask,
166 |                                    source_key_padding_mask,
167 |                                    pos_encoding, training)
168 | 
169 |     def pre_norm_call(self, source, source_mask=None,
170 |                       source_key_padding_mask=None,
171 |                       pos_encoding=None, training=False):
172 |         raise Exception('pre_norm_call not implemented yet')
173 | 
174 |     def post_norm_call(self, source, source_mask=None,
175 |                        source_key_padding_mask=None,
176 |                        pos_encoding=None, training=False):
177 |         if pos_encoding is None:
178 |             query = key = source
179 |         else:
180 |             query = key = source + pos_encoding
181 | 
182 |         attn_source = self.self_attn((query, key, source),
183 |                                      attn_mask=source_mask,
184 |                                      key_padding_mask=source_key_padding_mask,
185 |                                      need_weights=False)
186 |         source += self.dropout(attn_source, training=training)
187 |         source = self.norm1(source)
188 | 
189 |         x = self.linear1(source)
190 |         x = self.activation(x)
191 |         x = self.dropout(x, training=training)
192 |         x = self.linear2(x)
193 |         source += self.dropout(x, training=training)
194 |         source = self.norm2(source)
195 | 
196 |         return source
197 | 
198 | 
199 | class DecoderLayer(tf.keras.layers.Layer):
200 |     def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048,
201 |                  dropout=0.1, activation='relu', normalize_before=False,
202 |                  **kwargs):
203 | 
super().__init__(**kwargs) 204 | 205 | self.self_attn = MultiHeadAttention(model_dim, num_heads, 206 | dropout=dropout, 207 | name='self_attn') 208 | self.multihead_attn = MultiHeadAttention(model_dim, num_heads, 209 | dropout=dropout, 210 | name='multihead_attn') 211 | 212 | self.dropout = Dropout(dropout) 213 | self.activation = Activation(activation) 214 | 215 | self.linear1 = Linear(dim_feedforward, name='linear1') 216 | self.linear2 = Linear(model_dim, name='linear2') 217 | 218 | self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1') 219 | self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2') 220 | self.norm3 = LayerNormalization(epsilon=1e-5, name='norm3') 221 | 222 | self.normalize_before = normalize_before 223 | 224 | def call(self, target, memory, target_mask=None, memory_mask=None, 225 | target_key_padding_mask=None, memory_key_padding_mask=None, 226 | pos_encoding=None, query_encoding=None, training=False): 227 | if self.normalize_before: 228 | return self.pre_norm_call(target, memory, target_mask, memory_mask, 229 | target_key_padding_mask, 230 | memory_key_padding_mask, 231 | pos_encoding, query_encoding, 232 | training=training) 233 | return self.post_norm_call(target, memory, target_mask, memory_mask, 234 | target_key_padding_mask, 235 | memory_key_padding_mask, 236 | pos_encoding, query_encoding, 237 | training=training) 238 | 239 | def pre_norm_call(self, target, memory, target_mask=None, 240 | memory_mask=None, target_key_padding_mask=None, 241 | memory_key_padding_mask=None, pos_encoding=None, 242 | query_encoding=None, training=False): 243 | raise Exception('pre_norm_call not implemented yet') 244 | 245 | def post_norm_call(self, target, memory, target_mask=None, 246 | memory_mask=None, target_key_padding_mask=None, 247 | memory_key_padding_mask=None, pos_encoding=None, 248 | query_encoding=None, training=False): 249 | 250 | query_tgt = key_tgt = target + query_encoding 251 | attn_target = self.self_attn((query_tgt, key_tgt, target), 252 | attn_mask=target_mask, 253 | key_padding_mask=target_key_padding_mask, 254 | need_weights=False) 255 | target += self.dropout(attn_target, training=training) 256 | target = self.norm1(target) 257 | 258 | query_tgt = target + query_encoding 259 | key_mem = memory + pos_encoding 260 | 261 | attn_target2 = self.multihead_attn( 262 | (query_tgt, key_mem, memory), attn_mask=memory_mask, 263 | key_padding_mask=memory_key_padding_mask, need_weights=False) 264 | target += self.dropout(attn_target2, training=training) 265 | target = self.norm2(target) 266 | 267 | x = self.linear1(target) 268 | x = self.activation(x) 269 | x = self.dropout(x, training=training) 270 | x = self.linear2(x) 271 | target += self.dropout(x, training=training) 272 | target = self.norm3(target) 273 | 274 | return target 275 | 276 | 277 | class MultiHeadAttention(tf.keras.layers.Layer): 278 | def __init__(self, model_dim, num_heads, dropout=0.0, **kwargs): 279 | super().__init__(**kwargs) 280 | 281 | self.model_dim = model_dim 282 | self.num_heads = num_heads 283 | 284 | assert model_dim % num_heads == 0 285 | self.head_dim = model_dim // num_heads 286 | 287 | self.dropout = Dropout(rate=dropout) 288 | 289 | def build(self, input_shapes): 290 | in_dim = sum([shape[-1] for shape in input_shapes[:3]]) 291 | self.in_proj_weight = tf.Variable( 292 | tf.zeros((in_dim, self.model_dim), dtype=tf.float32), 293 | name='in_proj_kernel') 294 | self.in_proj_bias = tf.Variable(tf.zeros((in_dim,), dtype=tf.float32), 295 | name='in_proj_bias') 296 | 297 | 
self.out_proj_weight = tf.Variable(
298 |             tf.zeros((self.model_dim, self.model_dim), dtype=tf.float32),
299 |             name='out_proj_kernel')
300 |         self.out_proj_bias = tf.Variable(
301 |             tf.zeros((self.model_dim,), dtype=tf.float32),
302 |             name='out_proj_bias')
303 | 
304 |     def call(self, inputs, attn_mask=None, key_padding_mask=None,
305 |              need_weights=True, training=False):
306 |         query, key, value = inputs
307 |         batch_size = tf.shape(query)[1]
308 |         target_len = tf.shape(query)[0]
309 |         source_len = tf.shape(key)[0]
310 | 
311 |         W = self.in_proj_weight[:self.model_dim, :]
312 |         b = self.in_proj_bias[:self.model_dim]
313 |         WQ = tf.matmul(query, W, transpose_b=True) + b
314 | 
315 |         W = self.in_proj_weight[self.model_dim:2*self.model_dim, :]
316 |         b = self.in_proj_bias[self.model_dim:2*self.model_dim]
317 |         WK = tf.matmul(key, W, transpose_b=True) + b
318 | 
319 |         W = self.in_proj_weight[2*self.model_dim:, :]
320 |         b = self.in_proj_bias[2*self.model_dim:]
321 |         WV = tf.matmul(value, W, transpose_b=True) + b
322 | 
323 |         WQ *= float(self.head_dim) ** -0.5
324 |         WQ = tf.reshape(
325 |             WQ, [target_len, batch_size * self.num_heads, self.head_dim])
326 |         WQ = tf.transpose(WQ, [1, 0, 2])
327 | 
328 |         WK = tf.reshape(
329 |             WK, [source_len, batch_size * self.num_heads, self.head_dim])
330 |         WK = tf.transpose(WK, [1, 0, 2])
331 | 
332 |         WV = tf.reshape(
333 |             WV, [source_len, batch_size * self.num_heads, self.head_dim])
334 |         WV = tf.transpose(WV, [1, 0, 2])
335 | 
336 |         attn_output_weights = tf.matmul(WQ, WK, transpose_b=True)
337 | 
338 |         if attn_mask is not None:
339 |             attn_output_weights += attn_mask
340 | 
341 |         if key_padding_mask is not None:
342 |             attn_output_weights = tf.reshape(
343 |                 attn_output_weights,
344 |                 [batch_size, self.num_heads, target_len, source_len])
345 | 
346 |             key_padding_mask = tf.expand_dims(key_padding_mask, 1)
347 |             key_padding_mask = tf.expand_dims(key_padding_mask, 2)
348 |             key_padding_mask = tf.tile(
349 |                 key_padding_mask, [1, self.num_heads, target_len, 1])
350 | 
351 |             attn_output_weights = tf.where(
352 |                 key_padding_mask,
353 |                 tf.zeros_like(attn_output_weights) + float('-inf'),
354 |                 attn_output_weights)
355 |             attn_output_weights = tf.reshape(
356 |                 attn_output_weights,
357 |                 [batch_size * self.num_heads, target_len, source_len])
358 | 
359 |         attn_output_weights = tf.nn.softmax(attn_output_weights, axis=-1)
360 |         attn_output_weights = self.dropout(attn_output_weights,
361 |                                            training=training)
362 | 
363 |         attn_output = tf.matmul(attn_output_weights, WV)
364 |         attn_output = tf.transpose(attn_output, [1, 0, 2])
365 |         attn_output = tf.reshape(attn_output,
366 |                                  [target_len, batch_size, self.model_dim])
367 |         attn_output = tf.matmul(attn_output, self.out_proj_weight,
368 |                                 transpose_b=True) + self.out_proj_bias
369 | 
370 |         if need_weights:
371 |             attn_output_weights = tf.reshape(
372 |                 attn_output_weights,
373 |                 [batch_size, self.num_heads, target_len, source_len])
374 |             # Return the average weight over the heads
375 |             avg_weights = tf.reduce_mean(attn_output_weights, axis=1)
376 |             return attn_output, avg_weights
377 | 
378 |         return attn_output
379 | 
--------------------------------------------------------------------------------
/detr_tensorflow/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def read_jpeg_image(img_path):
5 |     image = tf.io.read_file(img_path)
6 |     image = tf.image.decode_jpeg(image, channels=3)
7 |     return image
8 | 
9 | 
10 | def resize(image, min_side=800.0, max_side=1333.0):
11 |     h = 
tf.cast(tf.shape(image)[0], tf.float32) 12 | w = tf.cast(tf.shape(image)[1], tf.float32) 13 | cur_min_side = tf.minimum(w, h) 14 | cur_max_side = tf.maximum(w, h) 15 | 16 | scale = tf.minimum(max_side / cur_max_side, 17 | min_side / cur_min_side) 18 | nh = tf.cast(scale * h, tf.int32) 19 | nw = tf.cast(scale * w, tf.int32) 20 | 21 | image = tf.image.resize(image, (nh, nw)) 22 | return image 23 | 24 | 25 | def build_mask(image): 26 | return tf.zeros(tf.shape(image)[:2], dtype=tf.bool) 27 | 28 | 29 | def cxcywh2xyxy(boxes): 30 | cx, cy, w, h = [boxes[..., i] for i in range(4)] 31 | 32 | xmin, ymin = cx - w*0.5, cy - h*0.5 33 | xmax, ymax = cx + w*0.5, cy + h*0.5 34 | 35 | boxes = tf.stack([xmin, ymin, xmax, ymax], axis=-1) 36 | return boxes 37 | 38 | 39 | def absolute2relative(boxes, img_size): 40 | width, height = img_size 41 | scale = tf.constant([width, height, width, height], dtype=tf.float32) 42 | boxes *= scale 43 | return boxes 44 | 45 | 46 | def xyxy2xywh(boxes): 47 | xmin, ymin, xmax, ymax = [boxes[..., i] for i in range(4)] 48 | return tf.stack([xmin, ymin, xmax - xmin, ymax - ymin], axis=-1) 49 | 50 | 51 | def preprocess_image(image): 52 | image = resize(image, min_side=800.0, max_side=1333.0) 53 | 54 | channel_avg = tf.constant([0.485, 0.456, 0.406]) 55 | channel_std = tf.constant([0.229, 0.224, 0.225]) 56 | image = (image / 255.0 - channel_avg) / channel_std 57 | 58 | return image, build_mask(image) 59 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: detr_tensorflow 2 | channels: 3 | - conda-forge 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _tflow_select=2.1.0=gpu 9 | - absl-py=0.12.0=py38h06a4308_0 10 | - aiohttp=3.6.3=py38h7b6447c_0 11 | - astunparse=1.6.3=py_0 12 | - async-timeout=3.0.1=py38_0 13 | - attrs=20.2.0=py_0 14 | - blas=1.0=mkl 15 | - blinker=1.4=py38_0 16 | - brotlipy=0.7.0=py38h7b6447c_1000 17 | - c-ares=1.17.1=h27cfd23_0 18 | - ca-certificates=2021.4.13=h06a4308_1 19 | - cachetools=4.1.1=py_0 20 | - certifi=2020.12.5=py38h06a4308_0 21 | - cffi=1.14.3=py38he30daa8_0 22 | - chardet=3.0.4=py38_1003 23 | - click=7.1.2=py_0 24 | - coverage=5.3=py38h7b6447c_0 25 | - cryptography=3.1.1=py38h1ba5d50_0 26 | - cudatoolkit=10.1.243=h6bb024c_0 27 | - cudnn=7.6.5=cuda10.1_0 28 | - cupti=10.1.168=0 29 | - cycler=0.10.0=py_2 30 | - cython=0.29.21=py38he6710b0_0 31 | - freetype=2.10.4=h7ca028e_0 32 | - gast=0.4.0=py_0 33 | - google-auth=1.22.1=py_0 34 | - google-auth-oauthlib=0.4.1=py_2 35 | - google-pasta=0.2.0=py_0 36 | - grpcio=1.36.1=py38h2157cd5_1 37 | - h5py=2.10.0=py38hd6299e0_1 38 | - hdf5=1.10.6=hb1b8bf9_0 39 | - icu=67.1=he1b5a44_0 40 | - idna=2.10=py_0 41 | - importlib-metadata=2.0.0=py_1 42 | - intel-openmp=2021.2.0=h06a4308_610 43 | - jpeg=9b=h024ee3a_2 44 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 45 | - kiwisolver=1.3.1=py38h82cb98a_0 46 | - lcms2=2.12=h3be6417_0 47 | - ld_impl_linux-64=2.33.1=h53a641e_7 48 | - libffi=3.3=he6710b0_2 49 | - libgcc-ng=9.1.0=hdf63c60_0 50 | - libgfortran-ng=7.3.0=hdf63c60_0 51 | - libpng=1.6.37=h21135ba_2 52 | - libprotobuf=3.14.0=h8c45485_0 53 | - libstdcxx-ng=9.1.0=hdf63c60_0 54 | - libtiff=4.1.0=h2733197_1 55 | - lz4-c=1.9.3=h2531618_0 56 | - markdown=3.3.2=py38_0 57 | - matplotlib-base=3.2.2=py38h5d868c9_1 58 | - mkl=2021.2.0=h06a4308_296 59 | - mkl-service=2.3.0=py38h27cfd23_1 60 | - mkl_fft=1.3.0=py38h42c9631_2 61 | - mkl_random=1.2.1=py38ha9443f7_2 62 | - 
multidict=4.7.6=py38h7b6447c_1
63 |   - ncurses=6.2=he6710b0_1
64 |   - numpy=1.20.2=py38h2d18471_0
65 |   - numpy-base=1.20.2=py38hfae3a4d_0
66 |   - oauthlib=3.1.0=py_0
67 |   - olefile=0.46=py_0
68 |   - openssl=1.1.1k=h27cfd23_0
69 |   - opt_einsum=3.1.0=py_0
70 |   - pillow=8.2.0=py38he98fc37_0
71 |   - pip=21.1.1=py38h06a4308_0
72 |   - protobuf=3.14.0=py38h2531618_1
73 |   - pyasn1=0.4.8=py_0
74 |   - pyasn1-modules=0.2.8=py_0
75 |   - pycocotools=2.0.2=py38h1e0a361_1
76 |   - pycparser=2.20=py_2
77 |   - pyjwt=1.7.1=py38_0
78 |   - pyopenssl=19.1.0=py_1
79 |   - pyparsing=2.4.7=pyh9f0ad1d_0
80 |   - pysocks=1.7.1=py38_0
81 |   - python=3.8.10=hdb3f193_7
82 |   - python-dateutil=2.8.1=py_0
83 |   - python-flatbuffers=1.12=pyhd3eb1b0_0
84 |   - python_abi=3.8=1_cp38
85 |   - readline=8.1=h27cfd23_0
86 |   - requests=2.24.0=py_0
87 |   - requests-oauthlib=1.3.0=py_0
88 |   - rsa=4.6=py_0
89 |   - scipy=1.6.2=py38had2a1c9_1
90 |   - setuptools=52.0.0=py38h06a4308_0
91 |   - six=1.15.0=py_0
92 |   - sqlite=3.35.4=hdfb4753_0
93 |   - tensorboard=2.4.0=pyhc547734_0
94 |   - tensorboard-plugin-wit=1.6.0=py_0
95 |   - tensorflow=2.4.1=gpu_py38h8a7d6ce_0
96 |   - tensorflow-base=2.4.1=gpu_py38h29c2da4_0
97 |   - tensorflow-estimator=2.4.1=pyheb71bc4_0
98 |   - tensorflow-gpu=2.4.1=h30adc30_0
99 |   - termcolor=1.1.0=py38_1
100 |   - tk=8.6.10=hbc83047_0
101 |   - tornado=6.1=py38h25fe258_0
102 |   - tqdm=4.59.0=pyhd3eb1b0_1
103 |   - urllib3=1.25.11=py_0
104 |   - werkzeug=1.0.1=py_0
105 |   - wheel=0.36.2=pyhd3eb1b0_0
106 |   - wrapt=1.12.1=py38h7b6447c_1
107 |   - xz=5.2.5=h7b6447c_0
108 |   - yarl=1.6.2=py38h7b6447c_0
109 |   - zipp=3.3.1=py_0
110 |   - zlib=1.2.11=h7b6447c_3
111 |   - zstd=1.4.9=haebb681_0
112 | prefix: ./detr_tensorflow/env
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pycocotools.coco import COCO
4 | from pycocotools.cocoeval import COCOeval
5 | import tensorflow as tf
6 | from tqdm import tqdm
7 | 
8 | from detr_tensorflow.datasets import COCODatasetBBoxes
9 | from detr_tensorflow import models
10 | from detr_tensorflow.utils import (preprocess_image, read_jpeg_image,
11 |                                    absolute2relative, xyxy2xywh)
12 | 
13 | 
14 | parser = argparse.ArgumentParser(
15 |     description='DETR evaluation script for the COCO dataset.')
16 | 
17 | parser.add_argument('--coco_path', type=str,
18 |                     help='Path to the COCO dataset root directory. '
19 |                          'For evaluation, only the '
20 |                          'validation data needs to be downloaded.')
21 | parser.add_argument('--backbone', type=str, default=None,
22 |                     choices=('resnet50', 'resnet50-dc5',
23 |                              'resnet101', 'resnet101-dc5'),
24 |                     help='Choice of backbone CNN for the model.')
25 | parser.add_argument('--frozen_weights', type=str, default=None,
26 |                     help='Path to the pretrained weights file. 
' 27 | 'Please check the repository for links to download ' 28 | 'tensorflow ports of the official ones.') 29 | parser.add_argument('--batch_size', type=int, default=2) 30 | parser.add_argument('--results_file', type=str, default='results.json', 31 | help='.json file to save the results in the COCO format.') 32 | parser.add_argument('--from_file', action='store_true', 33 | help='If specified, will compute the results using ' 34 | 'the predictions in the --results_file, instead of ' 35 | 'performing inference on the whole validation set again.') 36 | 37 | args = parser.parse_args() 38 | 39 | 40 | coco_data = COCODatasetBBoxes( 41 | args.coco_path, partition='val2017', return_boxes=False) 42 | 43 | 44 | def evaluate(results): 45 | coco_dt = COCO.loadRes(coco_data.coco, args.results_file) 46 | cocoEval = COCOeval(coco_data.coco, coco_dt, iouType='bbox') 47 | cocoEval.evaluate() 48 | cocoEval.accumulate() 49 | cocoEval.summarize() 50 | 51 | 52 | if args.from_file: 53 | evaluate(args.results_file) 54 | exit() 55 | 56 | if args.backbone is None or args.frozen_weights is None: 57 | raise Exception('If --from_file is not provided, ' 58 | 'both --backbone and --frozen_weights ' 59 | 'must be provided.') 60 | 61 | model_fns = { 62 | 'resnet50': models.default.build_detr_resnet50, 63 | 'resnet50-dc5': models.default.build_detr_resnet50_dc5, 64 | 'resnet101': models.default.build_detr_resnet101, 65 | 'resnet101-dc5': models.default.build_detr_resnet101_dc5 66 | } 67 | 68 | detr = model_fns[args.backbone](num_classes=91) 69 | detr.build() 70 | detr.load_weights(args.frozen_weights) 71 | 72 | 73 | dataset = tf.data.Dataset.from_generator( 74 | lambda: coco_data, (tf.int32, tf.string)) 75 | dataset = dataset.map( 76 | lambda img_id, img_path: (img_id, read_jpeg_image(img_path))) 77 | dataset = dataset.map( 78 | lambda img_id, image: (img_id, *preprocess_image(image))) 79 | 80 | dataset = dataset.padded_batch( 81 | batch_size=args.batch_size, 82 | padded_shapes=((), (None, None, 3), (None, None)), 83 | padding_values=(None, tf.constant(0.0), tf.constant(True))) 84 | 85 | results = [] 86 | 87 | with tqdm(total=len(coco_data)) as pbar: 88 | for img_ids, images, masks in dataset: 89 | outputs = detr((images, masks), post_process=True) 90 | 91 | for img_id, scores, labels, boxes in zip( 92 | img_ids, outputs['scores'], 93 | outputs['labels'], outputs['boxes']): 94 | img_id = img_id.numpy() 95 | 96 | img_info = coco_data.coco.loadImgs([img_id])[0] 97 | img_height = img_info['height'] 98 | img_width = img_info['width'] 99 | 100 | for score, label, box in zip(scores, labels, boxes): 101 | score = score.numpy() 102 | label = label.numpy() 103 | box = absolute2relative(box, (img_width, img_height)) 104 | box = xyxy2xywh(box).numpy() 105 | 106 | results.append({ 107 | "image_id": int(img_id), 108 | "category_id": int(label), 109 | "bbox": box.tolist(), 110 | "score": float(score) 111 | }) 112 | 113 | pbar.update(int(len(images))) 114 | 115 | json_object = json.dumps(results, indent=2) 116 | with open(args.results_file, 'w') as f: 117 | f.write(json_object) 118 | 119 | evaluate(args.results_file) 120 | -------------------------------------------------------------------------------- /samples/sample_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Leonardo-Blanger/detr_tensorflow/38fc3c586b6767deed09bd7ec6c2a2fd7002346e/samples/sample_1.jpg -------------------------------------------------------------------------------- 
/samples/sample_1_boxes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Leonardo-Blanger/detr_tensorflow/38fc3c586b6767deed09bd7ec6c2a2fd7002346e/samples/sample_1_boxes.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='detr_tensorflow', 5 | packages=find_packages(include=['detr_tensorflow']), 6 | version='0.1.0', 7 | description='DETR Object Detection architecture in Tensorflow', 8 | author='Leonardo Blanger', 9 | ) 10 | --------------------------------------------------------------------------------