├── arial.ttf ├── README.md ├── processor ├── microsoft │ ├── trocr-base-stage1 │ │ ├── preprocessor_config.json │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json │ └── trocr-base-handwritten │ │ ├── preprocessor_config.json │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json └── facebook │ └── detr-resnet-50 │ └── preprocessor_config.json ├── 1.下载文件.ipynb └── 2.文字定位Loss函数.ipynb /arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_OCR/HEAD/arial.ttf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 一个简单的OCR模型的训练过程 2 | 3 | 环境信息: 4 | 5 | PyTorch==1.12.1+cuda 6 | 7 | transformers==4.36.2 8 | 9 | datasets==2.16.1 10 | 11 | 视频课程:https://www.bilibili.com/video/BV1nJ4m1t7KU 12 | -------------------------------------------------------------------------------- /processor/microsoft/trocr-base-stage1/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_normalize": true, 3 | "do_rescale": true, 4 | "do_resize": true, 5 | "image_mean": [ 6 | 0.5, 7 | 0.5, 8 | 0.5 9 | ], 10 | "image_processor_type": "ViTImageProcessor", 11 | "image_std": [ 12 | 0.5, 13 | 0.5, 14 | 0.5 15 | ], 16 | "processor_class": "TrOCRProcessor", 17 | "resample": 2, 18 | "rescale_factor": 0.00392156862745098, 19 | "size": { 20 | "height": 384, 21 | "width": 384 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /processor/microsoft/trocr-base-handwritten/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_normalize": true, 3 | "do_rescale": true, 4 | "do_resize": true, 5 | "image_mean": [ 6 | 0.5, 7 | 0.5, 8 | 0.5 9 | ], 10 | "image_processor_type": "ViTImageProcessor", 11 | "image_std": 
[ 12 | 0.5, 13 | 0.5, 14 | 0.5 15 | ], 16 | "processor_class": "TrOCRProcessor", 17 | "resample": 2, 18 | "rescale_factor": 0.00392156862745098, 19 | "size": { 20 | "height": 384, 21 | "width": 384 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /processor/facebook/detr-resnet-50/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_normalize": true, 3 | "do_pad": true, 4 | "do_rescale": true, 5 | "do_resize": true, 6 | "format": "coco_detection", 7 | "image_mean": [ 8 | 0.485, 9 | 0.456, 10 | 0.406 11 | ], 12 | "image_processor_type": "DetrImageProcessor", 13 | "image_std": [ 14 | 0.229, 15 | 0.224, 16 | 0.225 17 | ], 18 | "resample": 2, 19 | "rescale_factor": 0.00392156862745098, 20 | "size": { 21 | "longest_edge": 1333, 22 | "shortest_edge": 800 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /processor/microsoft/trocr-base-stage1/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<s>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "cls_token": { 10 | "content": "<s>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "eos_token": { 17 | "content": "</s>", 18 | "lstrip": false, 19 | "normalized": true, 20 | "rstrip": false, 21 | "single_word": false 22 | }, 23 | "mask_token": { 24 | "content": "<mask>", 25 | "lstrip": true, 26 | "normalized": true, 27 | "rstrip": false, 28 | "single_word": false 29 | }, 30 | "pad_token": { 31 | "content": "<pad>", 32 | "lstrip": false, 33 | "normalized": true, 34 | "rstrip": false, 35 | "single_word": false 36 | }, 37 | "sep_token": { 38 | "content": "</s>", 39 | "lstrip": false, 40 | "normalized": true, 41 | "rstrip": false, 42 | "single_word": false 43 | }, 44 | "unk_token": {
45 | "content": "<unk>", 46 | "lstrip": false, 47 | "normalized": true, 48 | "rstrip": false, 49 | "single_word": false 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /processor/microsoft/trocr-base-handwritten/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<s>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "cls_token": { 10 | "content": "<s>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "eos_token": { 17 | "content": "</s>", 18 | "lstrip": false, 19 | "normalized": true, 20 | "rstrip": false, 21 | "single_word": false 22 | }, 23 | "mask_token": { 24 | "content": "<mask>", 25 | "lstrip": true, 26 | "normalized": true, 27 | "rstrip": false, 28 | "single_word": false 29 | }, 30 | "pad_token": { 31 | "content": "<pad>", 32 | "lstrip": false, 33 | "normalized": true, 34 | "rstrip": false, 35 | "single_word": false 36 | }, 37 | "sep_token": { 38 | "content": "</s>", 39 | "lstrip": false, 40 | "normalized": true, 41 | "rstrip": false, 42 | "single_word": false 43 | }, 44 | "unk_token": { 45 | "content": "<unk>", 46 | "lstrip": false, 47 | "normalized": true, 48 | "rstrip": false, 49 | "single_word": false 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /processor/microsoft/trocr-base-stage1/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "<s>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "1": { 13 | "content": "<pad>", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | }, 20 | "2": { 21 |
"content": "", 22 | "lstrip": false, 23 | "normalized": true, 24 | "rstrip": false, 25 | "single_word": false, 26 | "special": true 27 | }, 28 | "3": { 29 | "content": "", 30 | "lstrip": false, 31 | "normalized": true, 32 | "rstrip": false, 33 | "single_word": false, 34 | "special": true 35 | }, 36 | "50264": { 37 | "content": "", 38 | "lstrip": true, 39 | "normalized": true, 40 | "rstrip": false, 41 | "single_word": false, 42 | "special": true 43 | } 44 | }, 45 | "bos_token": "", 46 | "clean_up_tokenization_spaces": true, 47 | "cls_token": "", 48 | "eos_token": "", 49 | "errors": "replace", 50 | "mask_token": "", 51 | "model_max_length": 512, 52 | "pad_token": "", 53 | "processor_class": "TrOCRProcessor", 54 | "sep_token": "", 55 | "tokenizer_class": "RobertaTokenizer", 56 | "trim_offsets": true, 57 | "unk_token": "" 58 | } 59 | -------------------------------------------------------------------------------- /processor/microsoft/trocr-base-handwritten/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "1": { 13 | "content": "", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | }, 20 | "2": { 21 | "content": "", 22 | "lstrip": false, 23 | "normalized": true, 24 | "rstrip": false, 25 | "single_word": false, 26 | "special": true 27 | }, 28 | "3": { 29 | "content": "", 30 | "lstrip": false, 31 | "normalized": true, 32 | "rstrip": false, 33 | "single_word": false, 34 | "special": true 35 | }, 36 | "50264": { 37 | "content": "", 38 | "lstrip": true, 39 | "normalized": true, 40 | "rstrip": false, 41 | "single_word": false, 42 | "special": true 43 | } 44 | }, 45 | "bos_token": "", 46 | "clean_up_tokenization_spaces": true, 47 | 
"cls_token": "", 48 | "eos_token": "", 49 | "errors": "replace", 50 | "mask_token": "", 51 | "model_max_length": 512, 52 | "pad_token": "", 53 | "processor_class": "TrOCRProcessor", 54 | "sep_token": "", 55 | "tokenizer_class": "RobertaTokenizer", 56 | "trim_offsets": true, 57 | "unk_token": "" 58 | } 59 | -------------------------------------------------------------------------------- /1.下载文件.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "c5e0ea0f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# from transformers import AutoImageProcessor\n", 11 | "\n", 12 | "# name = 'facebook/detr-resnet-50'\n", 13 | "# AutoImageProcessor.from_pretrained(name).save_pretrained('processor/' + name)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "fe8dc7a2-b442-416d-ae17-17c1d0a8dcf0", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# from transformers import TrOCRProcessor\n", 24 | "\n", 25 | "# name = 'microsoft/trocr-base-handwritten'\n", 26 | "# TrOCRProcessor.from_pretrained(name).save_pretrained('processor/' + name)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "37bfb12a", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "from datasets import load_dataset\n", 37 | "\n", 38 | "name = 'lansinuote/ocr_id_card_small'\n", 39 | "load_dataset(name).save_to_disk('dataset/' + name)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "c3ac989a", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "from transformers import AutoModelForObjectDetection\n", 50 | "\n", 51 | "name = 'facebook/detr-resnet-50'\n", 52 | "AutoModelForObjectDetection.from_pretrained(name).save_pretrained(\n", 53 | " 'model/' + name, ignore_mismatched_sizes=True)" 54 | ] 55 | }, 56 | { 57 | "cell_type": 
"code", 58 | "execution_count": null, 59 | "id": "e8e71ac5-4c9d-463a-adea-5d6f94744da2", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "from transformers import VisionEncoderDecoderModel\n", 64 | "\n", 65 | "name = 'microsoft/trocr-base-stage1'\n", 66 | "VisionEncoderDecoderModel.from_pretrained(name).save_pretrained('model/' +\n", 67 | " name)" 68 | ] 69 | } 70 | ], 71 | "metadata": { 72 | "kernelspec": { 73 | "display_name": "Python [conda env:pt]", 74 | "language": "python", 75 | "name": "conda-env-pt-py" 76 | }, 77 | "language_info": { 78 | "codemirror_mode": { 79 | "name": "ipython", 80 | "version": 3 81 | }, 82 | "file_extension": ".py", 83 | "mimetype": "text/x-python", 84 | "name": "python", 85 | "nbconvert_exporter": "python", 86 | "pygments_lexer": "ipython3", 87 | "version": "3.10.13" 88 | } 89 | }, 90 | "nbformat": 4, 91 | "nbformat_minor": 5 92 | } 93 | -------------------------------------------------------------------------------- /2.文字定位Loss函数.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "700aec2d", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/root/miniconda3/envs/pt/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import torch\n", 20 | "\n", 21 | "#测试数据\n", 22 | "targets = [{\n", 23 | " 'class_labels': torch.randint(low=0, high=4, size=[5]),\n", 24 | " 'boxes': torch.rand(5, 4),\n", 25 | "} for _ in range(8)]\n", 26 | "\n", 27 | "outputs = {\n", 28 | " 'logits': torch.randn(8, 100, 92),\n", 29 | " 'pred_boxes': torch.rand(8, 100, 4)\n", 30 | "}" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "309a460f", 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "[(tensor([ 7, 8, 35, 92, 97]), tensor([0, 1, 4, 2, 3])),\n", 43 | " (tensor([35, 36, 44, 59, 61]), tensor([2, 1, 0, 4, 3])),\n", 44 | " (tensor([10, 31, 37, 54, 63]), tensor([4, 0, 2, 3, 1])),\n", 45 | " (tensor([ 2, 21, 39, 55, 84]), tensor([3, 2, 1, 4, 0])),\n", 46 | " (tensor([12, 15, 19, 20, 23]), tensor([2, 0, 3, 1, 4])),\n", 47 | " (tensor([ 7, 48, 58, 63, 87]), tensor([0, 4, 2, 1, 3])),\n", 48 | " (tensor([13, 40, 60, 85, 90]), tensor([0, 4, 1, 2, 3])),\n", 49 | " (tensor([15, 23, 35, 68, 83]), tensor([3, 1, 2, 0, 4]))]" 50 | ] 51 | }, 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "from scipy.optimize import linear_sum_assignment\n", 59 | "\n", 60 | "\n", 61 | "class DetrHungarianMatcher:\n", 62 | "\n", 63 | " def box_iou(self, boxes1, boxes2):\n", 64 | " inter = []\n", 65 | " union = []\n", 66 | " area = []\n", 67 | " for box1 in boxes1:\n", 68 | " for box2 in boxes2:\n", 69 | " #求交集面积\n", 70 | " x1 = torch.max(box1[0], box2[0])\n", 71 | " y1 = torch.max(box1[1], box2[1])\n", 72 | " x2 = torch.min(box1[2], box2[2])\n", 73 | " y2 = torch.min(box1[3], box2[3])\n", 74 | "\n", 75 | " w = (x2 - x1).clamp(min=0)\n", 76 | " h = (y2 - y1).clamp(min=0)\n", 77 | "\n", 78 | " inter.append(w * h)\n", 79 |
"\n", 80 | " #求并集面积\n", 81 | " w1 = box1[2] - box1[0]\n", 82 | " h1 = box1[3] - box1[1]\n", 83 | " w2 = box2[2] - box2[0]\n", 84 | " h2 = box2[3] - box2[1]\n", 85 | "\n", 86 | " s1 = w1 * h1\n", 87 | " s2 = w2 * h2\n", 88 | "\n", 89 | " union.append(s1 + s2 - (w * h))\n", 90 | "\n", 91 | " #求扩展面积\n", 92 | " x1 = torch.min(box1[0], box2[0])\n", 93 | " y1 = torch.min(box1[1], box2[1])\n", 94 | " x2 = torch.max(box1[2], box2[2])\n", 95 | " y2 = torch.max(box1[3], box2[3])\n", 96 | "\n", 97 | " w = (x2 - x1).clamp(min=0)\n", 98 | " h = (y2 - y1).clamp(min=0)\n", 99 | "\n", 100 | " area.append(w * h)\n", 101 | "\n", 102 | " inter = torch.stack(inter).reshape(len(boxes1), len(boxes2))\n", 103 | " union = torch.stack(union).reshape(len(boxes1), len(boxes2))\n", 104 | " area = torch.stack(area).reshape(len(boxes1), len(boxes2))\n", 105 | "\n", 106 | " #前面的数是iou,值域0-1,衡量了两个框重合的程度,这个数越大越好\n", 107 | "\n", 108 | " #后面的数是个分数,分别来看分子和分母\n", 109 | " #分子是扩展面积-并集面积,这个数显然是越小越好.\n", 110 | " #分母是扩展面积,显然是起归一化作用,所以这个分数的值域是0-1\n", 111 | "\n", 112 | " #综合以上,这个数总的来说还是iou,只是额外考虑了扩展面积的情况\n", 113 | " return (inter / union) - (area - union) / area\n", 114 | "\n", 115 | " #等价写法,上面的写法效率低\n", 116 | " def box_iou(self, boxes1, boxes2):\n", 117 | " area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])\n", 118 | " area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])\n", 119 | "\n", 120 | " p1 = torch.max(boxes1[:, :2].unsqueeze(1), boxes2[:, :2])\n", 121 | " p2 = torch.min(boxes1[:, 2:].unsqueeze(1), boxes2[:, 2:])\n", 122 | " wh = (p2 - p1).clamp(min=0)\n", 123 | " inter = wh[:, :, 0] * wh[:, :, 1]\n", 124 | "\n", 125 | " union = area1.unsqueeze(1) + area2 - inter\n", 126 | "\n", 127 | " p1 = torch.min(boxes1[:, :2].unsqueeze(1), boxes2[:, :2])\n", 128 | " p2 = torch.max(boxes1[:, 2:].unsqueeze(1), boxes2[:, 2:])\n", 129 | " wh = (p2 - p1).clamp(min=0)\n", 130 | " area = wh[:, :, 0] * wh[:, :, 1]\n", 131 | "\n", 132 | " return (inter / union) - (area - union) / 
area\n", 133 | "\n", 134 | " def xywh_to_x1y1x2y2(self, boxes):\n", 135 | " x = boxes[:, 0]\n", 136 | " y = boxes[:, 1]\n", 137 | " w = boxes[:, 2]\n", 138 | " h = boxes[:, 3]\n", 139 | "\n", 140 | " x1 = x - 0.5 * w\n", 141 | " y1 = y - 0.5 * h\n", 142 | " x2 = x + 0.5 * w\n", 143 | " y2 = y + 0.5 * h\n", 144 | "\n", 145 | " return torch.stack([x1, y1, x2, y2], dim=-1)\n", 146 | "\n", 147 | " @torch.no_grad()\n", 148 | " def __call__(self, outputs, targets):\n", 149 | " #取所有框和预测结果\n", 150 | " #[8, 100, 92] -> [800, 92]\n", 151 | " logits = outputs['logits'].flatten(0, 1).softmax(1)\n", 152 | "\n", 153 | " #[8, 100, 4] -> [800, 4]\n", 154 | " pred_boxes = outputs['pred_boxes'].flatten(0, 1)\n", 155 | "\n", 156 | " #取目标\n", 157 | " #[52]\n", 158 | " class_labels = torch.cat([i['class_labels'] for i in targets])\n", 159 | "\n", 160 | " #[52, 4]\n", 161 | " target_boxes = torch.cat([i['boxes'] for i in targets])\n", 162 | "\n", 163 | " #label的loss,简单的预测概率取反\n", 164 | " #[800, 92] -> [800, 52]\n", 165 | " class_cost = -logits[:, class_labels]\n", 166 | "\n", 167 | " #框的loss,4个点距离的和作为loss\n", 168 | " #[800, 4],[52, 4] -> [800, 52]\n", 169 | " bbox_cost = []\n", 170 | " for box1 in pred_boxes:\n", 171 | " cost = [(box1 - box2).abs().sum() for box2 in target_boxes]\n", 172 | " bbox_cost.append(torch.stack(cost))\n", 173 | " bbox_cost = torch.stack(bbox_cost)\n", 174 | "\n", 175 | " #等价写法\n", 176 | " #bbox_cost = torch.cdist(pred_boxes, target_boxes, p=1)\n", 177 | "\n", 178 | " #[800, 52]\n", 179 | " giou_cost = -self.box_iou(self.xywh_to_x1y1x2y2(pred_boxes),\n", 180 | " self.xywh_to_x1y1x2y2(target_boxes))\n", 181 | "\n", 182 | " #[800, 52] -> [8, 100, 52]\n", 183 | " cost = 5 * bbox_cost + 1 * class_cost + 2 * giou_cost\n", 184 | " cost = cost.view(*outputs['logits'].shape[:2], -1).cpu()\n", 185 | "\n", 186 | " indices = []\n", 187 | " sum_s = 0\n", 188 | " for c, t in zip(cost, targets):\n", 189 | " #取目标框的数量\n", 190 | " s = len(t['boxes'])\n", 191 | "\n", 192 | " #取这些框的计算结果\n", 193 | " #[100,
lens]\n", 194 | " c = c[:, sum_s:sum_s + s]\n", 195 | "\n", 196 | " #累计索引\n", 197 | " sum_s = sum_s + s\n", 198 | "\n", 199 | " #c这个矩阵记录的是loss的情况\n", 200 | " #求最小loss的分配方式\n", 201 | " index_row, index_col = linear_sum_assignment(c)\n", 202 | " index_row = torch.LongTensor(index_row)\n", 203 | " index_col = torch.LongTensor(index_col)\n", 204 | " indices.append((index_row, index_col))\n", 205 | "\n", 206 | " return indices\n", 207 | "\n", 208 | "\n", 209 | "matcher = DetrHungarianMatcher()\n", 210 | "\n", 211 | "matcher(outputs, targets)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 3, 217 | "id": "5eceb918", 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "data": { 222 | "text/plain": [ 223 | "{'loss_ce': tensor(4.9638),\n", 224 | " 'loss_bbox': tensor(0.4050),\n", 225 | " 'loss_giou': tensor(0.7186)}" 226 | ] 227 | }, 228 | "execution_count": 3, 229 | "metadata": {}, 230 | "output_type": "execute_result" 231 | } 232 | ], 233 | "source": [ 234 | "class DetrLoss(torch.nn.Module):\n", 235 | "\n", 236 | " def __init__(self, num_classes=91):\n", 237 | " super().__init__()\n", 238 | " empty_weight = torch.ones(num_classes + 1)\n", 239 | " empty_weight[-1] = 0.1\n", 240 | " self.register_buffer('empty_weight', empty_weight)\n", 241 | "\n", 242 | " def loss_labels(self, outputs, targets, indices):\n", 243 | " # 默认都是背景\n", 244 | " #[8, 100]\n", 245 | " target_classes = torch.full(outputs['logits'].shape[:2],\n", 246 | " outputs['logits'].shape[-1] - 1,\n", 247 | " dtype=torch.int64,\n", 248 | " device=outputs['logits'].device)\n", 249 | "\n", 250 | " #遍历每条数据\n", 251 | " for i in range(len(targets)):\n", 252 | " #遍历每一个分配结果(最小cost方式分配)\n", 253 | " for io, it in zip(*indices[i]):\n", 254 | " #按照最小cost的方式分配每个预测结果的目标\n", 255 | " target_classes[i, io.item()] = targets[i]['class_labels'][\n", 256 | " it.item()]\n", 257 | "\n", 258 | " #[8, 100, 92] -> [8, 92, 100]\n", 259 | " logits = outputs['logits'].transpose(1, 2)\n", 260 | "\n", 261 | " return 
torch.nn.functional.cross_entropy(logits, target_classes,\n", 262 | "
self.empty_weight)\n", 263 | "\n", 264 | " def loss_boxes(self, outputs, targets, indices):\n", 265 | " boxes_output = []\n", 266 | " boxes_target = []\n", 267 | " # 遍历每条数据\n", 268 | " for i in range(len(targets)):\n", 269 | " # 遍历每一个分配结果(最小cost方式分配)\n", 270 | " for io, it in zip(*indices[i]):\n", 271 | " # 按照最小cost的方式取每一对框\n", 272 | " boxes_output.append(outputs['pred_boxes'][i, io.item()])\n", 273 | " boxes_target.append(targets[i]['boxes'][it.item()])\n", 274 | "\n", 275 | " boxes_output = torch.stack(boxes_output)\n", 276 | " boxes_target = torch.stack(boxes_target)\n", 277 | "\n", 278 | " num_boxes = sum(len(i['class_labels']) for i in targets)\n", 279 | " if num_boxes < 1:\n", 280 | " num_boxes = 1\n", 281 | "\n", 282 | " #每对框之间求L1距离作为loss\n", 283 | " loss_bbox = torch.nn.functional.l1_loss(boxes_output,\n", 284 | " boxes_target,\n", 285 | " reduction='none')\n", 286 | " loss_bbox = loss_bbox.sum() / num_boxes\n", 287 | "\n", 288 | " #取iou作为第二部分的loss\n", 289 | " #只需要考虑成对的框的iou,所以取对角线元素计算即可\n", 290 | " #iou是越大越好,优化方向取反,所以loss=1-iou\n", 291 | " giou = matcher.box_iou(matcher.xywh_to_x1y1x2y2(boxes_output),\n", 292 | " matcher.xywh_to_x1y1x2y2(boxes_target))\n", 293 | " loss_giou = 1 - giou.diag()\n", 294 | " loss_giou = loss_giou.sum() / num_boxes\n", 295 | "\n", 296 | " return loss_bbox, loss_giou\n", 297 | "\n", 298 | " def forward(self, outputs, targets):\n", 299 | " indices = matcher(outputs, targets)\n", 300 | "\n", 301 | " losses = {}\n", 302 | " losses['loss_ce'] = self.loss_labels(outputs, targets, indices)\n", 303 | "\n", 304 | " loss_bbox, loss_giou = self.loss_boxes(outputs, targets, indices)\n", 305 | " losses['loss_bbox'] = loss_bbox\n", 306 | " losses['loss_giou'] = loss_giou\n", 307 | "\n", 308 | " return losses\n", 309 | "\n", 310 | "\n", 311 | "criterion = DetrLoss()\n", 312 | "\n", 313 | "criterion(outputs, targets)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 4, 319 | "id": "7820caf3", 320 | "metadata": {}, 321 |
"outputs": [ 322 | { 323 | "data": { 324 | "text/plain": [ 325 | "{'loss_ce': tensor(4.9638),\n", 326 | " 'loss_bbox': tensor(0.4050),\n", 327 | " 'loss_giou': tensor(0.7186),\n", 328 | " 'cardinality_error': tensor(94.1250)}" 329 | ] 330 | }, 331 | "execution_count": 4, 332 | "metadata": {}, 333 | "output_type": "execute_result" 334 | } 335 | ], 336 | "source": [ 337 | "def test():\n", 338 | " from transformers.models.detr.modeling_detr import DetrLoss, DetrHungarianMatcher\n", 339 | "\n", 340 | " matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)\n", 341 | "\n", 342 | " criterion = DetrLoss(matcher=matcher,\n", 343 | " num_classes=91,\n", 344 | " eos_coef=0.1,\n", 345 | " losses=['labels', 'boxes', 'cardinality'])\n", 346 | "\n", 347 | " return criterion(outputs, targets)\n", 348 | "\n", 349 | "\n", 350 | "test()" 351 | ] 352 | } 353 | ], 354 | "metadata": { 355 | "kernelspec": { 356 | "display_name": "Python [conda env:pt]", 357 | "language": "python", 358 | "name": "conda-env-pt-py" 359 | }, 360 | "language_info": { 361 | "codemirror_mode": { 362 | "name": "ipython", 363 | "version": 3 364 | }, 365 | "file_extension": ".py", 366 | "mimetype": "text/x-python", 367 | "name": "python", 368 | "nbconvert_exporter": "python", 369 | "pygments_lexer": "ipython3", 370 | "version": "3.10.13" 371 | } 372 | }, 373 | "nbformat": 4, 374 | "nbformat_minor": 5 375 | } 376 | --------------------------------------------------------------------------------