├── arial.ttf
├── README.md
├── processor
│   ├── microsoft
│   │   ├── trocr-base-stage1
│   │   │   ├── preprocessor_config.json
│   │   │   ├── special_tokens_map.json
│   │   │   └── tokenizer_config.json
│   │   └── trocr-base-handwritten
│   │       ├── preprocessor_config.json
│   │       ├── special_tokens_map.json
│   │       └── tokenizer_config.json
│   └── facebook
│       └── detr-resnet-50
│           └── preprocessor_config.json
├── 1.下载文件.ipynb
└── 2.文字定位Loss函数.ipynb
/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lansinuote/Simple_OCR/HEAD/arial.ttf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | The training process of a simple OCR model.
2 |
3 | Environment:
4 |
5 | PyTorch==1.12.1 (CUDA build)
6 |
7 | transformers==4.36.2
8 |
9 | datasets==2.16.1
10 |
11 | Video course: https://www.bilibili.com/video/BV1nJ4m1t7KU
12 |
--------------------------------------------------------------------------------
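The pinned versions can be sanity-checked with a short script. This is a minimal sketch, not part of the repository; the expected values are taken from README.md above.

```python
# Minimal environment check against the versions pinned in README.md (sketch only).
import torch
import transformers
import datasets

print('torch        :', torch.__version__)         # README pins 1.12.1 (CUDA build)
print('transformers :', transformers.__version__)   # README pins 4.36.2
print('datasets     :', datasets.__version__)       # README pins 2.16.1
print('cuda available:', torch.cuda.is_available())
```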
/processor/microsoft/trocr-base-stage1/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "do_normalize": true,
3 | "do_rescale": true,
4 | "do_resize": true,
5 | "image_mean": [
6 | 0.5,
7 | 0.5,
8 | 0.5
9 | ],
10 | "image_processor_type": "ViTImageProcessor",
11 | "image_std": [
12 | 0.5,
13 | 0.5,
14 | 0.5
15 | ],
16 | "processor_class": "TrOCRProcessor",
17 | "resample": 2,
18 | "rescale_factor": 0.00392156862745098,
19 | "size": {
20 | "height": 384,
21 | "width": 384
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
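This config describes the ViT image preprocessing on the TrOCR encoder side: a bilinear resize (`"resample": 2`) to 384×384, rescaling by 1/255, and per-channel normalization with mean and std 0.5. Below is a hand-rolled sketch of the equivalent pipeline; `sample.png` is a placeholder path, not a file from this repository.

```python
# A hand-rolled equivalent of this preprocessor_config.json (sketch only);
# 'sample.png' is a placeholder path, not a file from this repository.
import numpy as np
from PIL import Image

image = Image.open('sample.png').convert('RGB')

# do_resize: "size" 384x384, "resample": 2 == PIL bilinear
image = image.resize((384, 384), resample=Image.BILINEAR)

# do_rescale: "rescale_factor" = 1/255, mapping uint8 pixels into [0, 1]
pixels = np.asarray(image).astype('float32') / 255.0

# do_normalize: (x - image_mean) / image_std with mean = std = 0.5
pixels = (pixels - 0.5) / 0.5

# ViTImageProcessor returns channels-first tensors of shape [3, 384, 384]
pixels = pixels.transpose(2, 0, 1)
print(pixels.shape, pixels.min(), pixels.max())  # values lie in [-1, 1]
```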
/processor/microsoft/trocr-base-handwritten/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "do_normalize": true,
3 | "do_rescale": true,
4 | "do_resize": true,
5 | "image_mean": [
6 | 0.5,
7 | 0.5,
8 | 0.5
9 | ],
10 | "image_processor_type": "ViTImageProcessor",
11 | "image_std": [
12 | 0.5,
13 | 0.5,
14 | 0.5
15 | ],
16 | "processor_class": "TrOCRProcessor",
17 | "resample": 2,
18 | "rescale_factor": 0.00392156862745098,
19 | "size": {
20 | "height": 384,
21 | "width": 384
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/processor/facebook/detr-resnet-50/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "do_normalize": true,
3 | "do_pad": true,
4 | "do_rescale": true,
5 | "do_resize": true,
6 | "format": "coco_detection",
7 | "image_mean": [
8 | 0.485,
9 | 0.456,
10 | 0.406
11 | ],
12 | "image_processor_type": "DetrImageProcessor",
13 | "image_std": [
14 | 0.229,
15 | 0.224,
16 | 0.225
17 | ],
18 | "resample": 2,
19 | "rescale_factor": 0.00392156862745098,
20 | "size": {
21 | "longest_edge": 1333,
22 | "shortest_edge": 800
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
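Unlike the TrOCR processor, the DETR processor keeps the aspect ratio: images are scaled so the short side reaches `shortest_edge` (800), unless that would push the long side past `longest_edge` (1333), and are then normalized with the ImageNet mean/std shown above. A small sketch of that sizing rule follows; rounding details may differ slightly from the library's implementation.

```python
# Sketch of the "shortest_edge"/"longest_edge" resize rule implied by this config:
# scale so the short side becomes 800, but cap the long side at 1333.
def detr_target_size(height, width, shortest_edge=800, longest_edge=1333):
    short, long = min(height, width), max(height, width)
    scale = shortest_edge / short
    if long * scale > longest_edge:
        scale = longest_edge / long
    return round(height * scale), round(width * scale)

print(detr_target_size(480, 640))   # (800, 1067): the short side reaches 800
print(detr_target_size(400, 1200))  # (444, 1333): capped by the long side
```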
/processor/microsoft/trocr-base-stage1/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": {
3 | "content": "",
4 | "lstrip": false,
5 | "normalized": true,
6 | "rstrip": false,
7 | "single_word": false
8 | },
9 | "cls_token": {
10 | "content": "",
11 | "lstrip": false,
12 | "normalized": true,
13 | "rstrip": false,
14 | "single_word": false
15 | },
16 | "eos_token": {
17 | "content": "",
18 | "lstrip": false,
19 | "normalized": true,
20 | "rstrip": false,
21 | "single_word": false
22 | },
23 | "mask_token": {
24 | "content": "",
25 | "lstrip": true,
26 | "normalized": true,
27 | "rstrip": false,
28 | "single_word": false
29 | },
30 | "pad_token": {
31 | "content": "",
32 | "lstrip": false,
33 | "normalized": true,
34 | "rstrip": false,
35 | "single_word": false
36 | },
37 | "sep_token": {
38 | "content": "",
39 | "lstrip": false,
40 | "normalized": true,
41 | "rstrip": false,
42 | "single_word": false
43 | },
44 | "unk_token": {
45 | "content": "",
46 | "lstrip": false,
47 | "normalized": true,
48 | "rstrip": false,
49 | "single_word": false
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/processor/microsoft/trocr-base-handwritten/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": {
3 | "content": "",
4 | "lstrip": false,
5 | "normalized": true,
6 | "rstrip": false,
7 | "single_word": false
8 | },
9 | "cls_token": {
10 | "content": "",
11 | "lstrip": false,
12 | "normalized": true,
13 | "rstrip": false,
14 | "single_word": false
15 | },
16 | "eos_token": {
17 | "content": "",
18 | "lstrip": false,
19 | "normalized": true,
20 | "rstrip": false,
21 | "single_word": false
22 | },
23 | "mask_token": {
24 | "content": "",
25 | "lstrip": true,
26 | "normalized": true,
27 | "rstrip": false,
28 | "single_word": false
29 | },
30 | "pad_token": {
31 | "content": "",
32 | "lstrip": false,
33 | "normalized": true,
34 | "rstrip": false,
35 | "single_word": false
36 | },
37 | "sep_token": {
38 | "content": "",
39 | "lstrip": false,
40 | "normalized": true,
41 | "rstrip": false,
42 | "single_word": false
43 | },
44 | "unk_token": {
45 | "content": "",
46 | "lstrip": false,
47 | "normalized": true,
48 | "rstrip": false,
49 | "single_word": false
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/processor/microsoft/trocr-base-stage1/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "added_tokens_decoder": {
4 | "0": {
5 | "content": "",
6 | "lstrip": false,
7 | "normalized": true,
8 | "rstrip": false,
9 | "single_word": false,
10 | "special": true
11 | },
12 | "1": {
13 | "content": "",
14 | "lstrip": false,
15 | "normalized": true,
16 | "rstrip": false,
17 | "single_word": false,
18 | "special": true
19 | },
20 | "2": {
21 | "content": "",
22 | "lstrip": false,
23 | "normalized": true,
24 | "rstrip": false,
25 | "single_word": false,
26 | "special": true
27 | },
28 | "3": {
29 | "content": "",
30 | "lstrip": false,
31 | "normalized": true,
32 | "rstrip": false,
33 | "single_word": false,
34 | "special": true
35 | },
36 | "50264": {
37 | "content": "",
38 | "lstrip": true,
39 | "normalized": true,
40 | "rstrip": false,
41 | "single_word": false,
42 | "special": true
43 | }
44 | },
45 | "bos_token": "",
46 | "clean_up_tokenization_spaces": true,
47 | "cls_token": "",
48 | "eos_token": "",
49 | "errors": "replace",
50 | "mask_token": "",
51 | "model_max_length": 512,
52 | "pad_token": "",
53 | "processor_class": "TrOCRProcessor",
54 | "sep_token": "",
55 | "tokenizer_class": "RobertaTokenizer",
56 | "trim_offsets": true,
57 | "unk_token": ""
58 | }
59 |
--------------------------------------------------------------------------------
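The decoder side of TrOCR uses a RoBERTa tokenizer, and the ids in `added_tokens_decoder` are the standard RoBERTa special tokens. A quick way to confirm the mapping is sketched below; it loads from the Hub name rather than the local directory, since this dump only contains the config files, not `vocab.json`/`merges.txt`.

```python
# Sketch only: print the special-token ids declared in added_tokens_decoder.
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('microsoft/trocr-base-stage1')
for token in ['<s>', '<pad>', '</s>', '<unk>', '<mask>']:
    print(token, tokenizer.convert_tokens_to_ids(token))
# expected: <s> 0, <pad> 1, </s> 2, <unk> 3, <mask> 50264
```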
/processor/microsoft/trocr-base-handwritten/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "added_tokens_decoder": {
4 | "0": {
5 | "content": "",
6 | "lstrip": false,
7 | "normalized": true,
8 | "rstrip": false,
9 | "single_word": false,
10 | "special": true
11 | },
12 | "1": {
13 | "content": "",
14 | "lstrip": false,
15 | "normalized": true,
16 | "rstrip": false,
17 | "single_word": false,
18 | "special": true
19 | },
20 | "2": {
21 | "content": "",
22 | "lstrip": false,
23 | "normalized": true,
24 | "rstrip": false,
25 | "single_word": false,
26 | "special": true
27 | },
28 | "3": {
29 | "content": "",
30 | "lstrip": false,
31 | "normalized": true,
32 | "rstrip": false,
33 | "single_word": false,
34 | "special": true
35 | },
36 | "50264": {
37 | "content": "",
38 | "lstrip": true,
39 | "normalized": true,
40 | "rstrip": false,
41 | "single_word": false,
42 | "special": true
43 | }
44 | },
45 | "bos_token": "",
46 | "clean_up_tokenization_spaces": true,
47 | "cls_token": "",
48 | "eos_token": "",
49 | "errors": "replace",
50 | "mask_token": "",
51 | "model_max_length": 512,
52 | "pad_token": "",
53 | "processor_class": "TrOCRProcessor",
54 | "sep_token": "",
55 | "tokenizer_class": "RobertaTokenizer",
56 | "trim_offsets": true,
57 | "unk_token": ""
58 | }
59 |
--------------------------------------------------------------------------------
/1.下载文件.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "c5e0ea0f",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# from transformers import AutoImageProcessor\n",
11 | "\n",
12 | "# name = 'facebook/detr-resnet-50'\n",
13 | "# AutoImageProcessor.from_pretrained(name).save_pretrained('processor/' + name)"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "id": "fe8dc7a2-b442-416d-ae17-17c1d0a8dcf0",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "# from transformers import TrOCRProcessor\n",
24 | "\n",
25 | "# name = 'microsoft/trocr-base-handwritten'\n",
26 | "# TrOCRProcessor.from_pretrained(name).save_pretrained('processor/' + name)"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "37bfb12a",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "from datasets import load_dataset\n",
37 | "\n",
38 | "name = 'lansinuote/ocr_id_card_small'\n",
39 | "load_dataset(name).save_to_disk('dataset/' + name)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "c3ac989a",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "from transformers import AutoModelForObjectDetection\n",
50 | "\n",
51 | "name = 'facebook/detr-resnet-50'\n",
52 | "AutoModelForObjectDetection.from_pretrained(name).save_pretrained(\n",
53 | " 'model/' + name, ignore_mismatched_sizes=True)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "id": "e8e71ac5-4c9d-463a-adea-5d6f94744da2",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "from transformers import VisionEncoderDecoderModel\n",
64 | "\n",
65 | "name = 'microsoft/trocr-base-stage1'\n",
66 | "VisionEncoderDecoderModel.from_pretrained(name).save_pretrained('model/' +\n",
67 | " name)"
68 | ]
69 | }
70 | ],
71 | "metadata": {
72 | "kernelspec": {
73 | "display_name": "Python [conda env:pt]",
74 | "language": "python",
75 | "name": "conda-env-pt-py"
76 | },
77 | "language_info": {
78 | "codemirror_mode": {
79 | "name": "ipython",
80 | "version": 3
81 | },
82 | "file_extension": ".py",
83 | "mimetype": "text/x-python",
84 | "name": "python",
85 | "nbconvert_exporter": "python",
86 | "pygments_lexer": "ipython3",
87 | "version": "3.10.13"
88 | }
89 | },
90 | "nbformat": 4,
91 | "nbformat_minor": 5
92 | }
93 |
--------------------------------------------------------------------------------
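The later notebooks in the series presumably load these local copies back. Below is a hedged, illustrative sketch of what that would look like, mirroring the `save_pretrained`/`save_to_disk` paths used above; the TrOCR processor load assumes the full `save_pretrained` output (including the vocabulary files) is present on disk.

```python
# Illustrative only: reload the artifacts written by this notebook.
from datasets import load_from_disk
from transformers import (AutoImageProcessor, AutoModelForObjectDetection,
                          TrOCRProcessor, VisionEncoderDecoderModel)

dataset = load_from_disk('dataset/lansinuote/ocr_id_card_small')

# text-detection half: DETR processor + model
detr_processor = AutoImageProcessor.from_pretrained('processor/facebook/detr-resnet-50')
detr_model = AutoModelForObjectDetection.from_pretrained('model/facebook/detr-resnet-50')

# text-recognition half: TrOCR processor + encoder-decoder model
trocr_processor = TrOCRProcessor.from_pretrained('processor/microsoft/trocr-base-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('model/microsoft/trocr-base-stage1')
```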
/2.文字定位Loss函数.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "700aec2d",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/root/miniconda3/envs/pt/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14 | " from .autonotebook import tqdm as notebook_tqdm\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "import torch\n",
20 | "\n",
21 | "#测试数据\n",
22 | "targets = [{\n",
23 | " 'class_labels': torch.randint(low=0, high=4, size=[5]),\n",
24 | " 'boxes': torch.rand(5, 4),\n",
25 | "} for _ in range(8)]\n",
26 | "\n",
27 | "outputs = {\n",
28 | " 'logits': torch.randn(8, 100, 92),\n",
29 | " 'pred_boxes': torch.rand(8, 100, 4)\n",
30 | "}"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "id": "309a460f",
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "data": {
41 | "text/plain": [
42 | "[(tensor([ 7, 8, 35, 92, 97]), tensor([0, 1, 4, 2, 3])),\n",
43 | " (tensor([35, 36, 44, 59, 61]), tensor([2, 1, 0, 4, 3])),\n",
44 | " (tensor([10, 31, 37, 54, 63]), tensor([4, 0, 2, 3, 1])),\n",
45 | " (tensor([ 2, 21, 39, 55, 84]), tensor([3, 2, 1, 4, 0])),\n",
46 | " (tensor([12, 15, 19, 20, 23]), tensor([2, 0, 3, 1, 4])),\n",
47 | " (tensor([ 7, 48, 58, 63, 87]), tensor([0, 4, 2, 1, 3])),\n",
48 | " (tensor([13, 40, 60, 85, 90]), tensor([0, 4, 1, 2, 3])),\n",
49 | " (tensor([15, 23, 35, 68, 83]), tensor([3, 1, 2, 0, 4]))]"
50 | ]
51 | },
52 | "execution_count": 2,
53 | "metadata": {},
54 | "output_type": "execute_result"
55 | }
56 | ],
57 | "source": [
58 | "from scipy.optimize._lsap import linear_sum_assignment\n",
59 | "\n",
60 | "\n",
61 | "class DetrHungarianMatcher:\n",
62 | "\n",
63 | " def box_iou(self, boxes1, boxes2):\n",
64 | " inter = []\n",
65 | " union = []\n",
66 | " area = []\n",
67 | " for box1 in boxes1:\n",
68 | " for box2 in boxes2:\n",
69 | " #求交集面积\n",
70 | " x1 = torch.max(box1[0], box2[0])\n",
71 | " y1 = torch.max(box1[1], box2[1])\n",
72 | " x2 = torch.min(box1[2], box2[2])\n",
73 | " y2 = torch.min(box1[3], box2[3])\n",
74 | "\n",
75 | " w = (x2 - x1).clamp(min=0)\n",
76 | " h = (y2 - y1).clamp(min=0)\n",
77 | "\n",
78 | " inter.append(w * h)\n",
79 | "\n",
80 | " #求并集面积\n",
81 | " w1 = box1[2] - box1[0]\n",
82 | " h1 = box1[3] - box1[1]\n",
83 | " w2 = box2[2] - box2[0]\n",
84 | " h2 = box2[3] - box2[1]\n",
85 | "\n",
86 | " s1 = w1 * h1\n",
87 | " s2 = w2 * h2\n",
88 | "\n",
89 | " union.append(s1 + s2 - (w * h))\n",
90 | "\n",
91 | " #求扩展面积\n",
92 | " x1 = torch.min(box1[0], box2[0])\n",
93 | " y1 = torch.min(box1[1], box2[1])\n",
94 | " x2 = torch.max(box1[2], box2[2])\n",
95 | " y2 = torch.max(box1[3], box2[3])\n",
96 | "\n",
97 | " w = (x2 - x1).clamp(min=0)\n",
98 | " h = (y2 - y1).clamp(min=0)\n",
99 | "\n",
100 | " area.append(w * h)\n",
101 | "\n",
102 | " inter = torch.stack(inter).reshape(len(boxes1), len(boxes2))\n",
103 | " union = torch.stack(union).reshape(len(boxes1), len(boxes2))\n",
104 | " area = torch.stack(area).reshape(len(boxes1), len(boxes2))\n",
105 | "\n",
106 | " #前面的数是iou,值域0-1,衡量了两个框重合的程度,这个数越大越好\n",
107 | "\n",
108 | " #后面的数是个分数,分别来看分子和分母\n",
109 | " #分子是扩展面积-并集面积,这个数显然是越小越好.\n",
110 | " #分母是扩展面积,显然是起归一化作用,所以这个分数的值域是0-1\n",
111 | "\n",
112 | " #综合以上,这个数总的来说还是iou,只是额外考虑了扩展面积的情况\n",
113 | " return (inter / union) - (area - union) / area\n",
114 | "\n",
115 | " #等价写法,上面的写法效率低\n",
116 | " def box_iou(self, boxes1, boxes2):\n",
117 | " area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])\n",
118 | " area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])\n",
119 | "\n",
120 | " p1 = torch.max(boxes1[:, :2].unsqueeze(1), boxes2[:, :2])\n",
121 | " p2 = torch.min(boxes1[:, 2:].unsqueeze(1), boxes2[:, 2:])\n",
122 | " wh = (p2 - p1).clamp(min=0)\n",
123 | " inter = wh[:, :, 0] * wh[:, :, 1]\n",
124 | "\n",
125 | " union = area1.unsqueeze(1) + area2 - inter\n",
126 | "\n",
127 | " p1 = torch.min(boxes1[:, :2].unsqueeze(1), boxes2[:, :2])\n",
128 | " p2 = torch.max(boxes1[:, 2:].unsqueeze(1), boxes2[:, 2:])\n",
129 | " wh = (p2 - p1).clamp(min=0)\n",
130 | " area = wh[:, :, 0] * wh[:, :, 1]\n",
131 | "\n",
132 | " return (inter / union) - (area - union) / area\n",
133 | "\n",
134 | " def xywh_to_x1y1x2y2(self, boxes):\n",
135 | " x = boxes[:, 0]\n",
136 | " y = boxes[:, 1]\n",
137 | " w = boxes[:, 2]\n",
138 | " h = boxes[:, 3]\n",
139 | "\n",
140 | " x1 = x - 0.5 * w\n",
141 | " y1 = y - 0.5 * h\n",
142 | " x2 = x + 0.5 * w\n",
143 | " y2 = y + 0.5 * h\n",
144 | "\n",
145 | " return torch.stack([x1, y1, x2, y2], dim=-1)\n",
146 | "\n",
147 | " @torch.no_grad()\n",
148 | " def __call__(self, outputs, targets):\n",
149 | " #取所有框和预测结果\n",
150 | " #[8, 100, 92] -> [800, 92]\n",
151 | " logits = outputs['logits'].flatten(0, 1).softmax(1)\n",
152 | "\n",
153 | " #[8, 100, 4] -> [800, 4]\n",
154 | " pred_boxes = outputs['pred_boxes'].flatten(0, 1)\n",
155 | "\n",
156 | " #取目标\n",
157 | " #[52]\n",
158 | " class_labels = torch.cat([i['class_labels'] for i in targets])\n",
159 | "\n",
160 | " #[52, 4]\n",
161 | " target_boxes = torch.cat([i['boxes'] for i in targets])\n",
162 | "\n",
163 | " #label的loss,简单的预测概率取反\n",
164 | " #[800, 92] -> [800, 52]\n",
165 | " class_cost = -logits[:, class_labels]\n",
166 | "\n",
167 | " #框的loss,4个点距离的和作为loss\n",
168 | " #[800, 4],[52, 4] -> [800, 52]\n",
169 | " bbox_cost = []\n",
170 | " for box1 in pred_boxes:\n",
171 | " cost = [(box1 - box2).abs().sum() for box2 in target_boxes]\n",
172 | " bbox_cost.append(torch.stack(cost))\n",
173 | " bbox_cost = torch.stack(bbox_cost)\n",
174 | "\n",
175 | " #等价写法\n",
176 | " #bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)\n",
177 | "\n",
178 | " #[800, 52]\n",
179 | " giou_cost = -self.box_iou(self.xywh_to_x1y1x2y2(pred_boxes),\n",
180 | " self.xywh_to_x1y1x2y2(target_boxes))\n",
181 | "\n",
182 | " #[800, 52] -> [8, 100, 52]\n",
183 | " cost = 5 * bbox_cost + 1 * class_cost + 2 * giou_cost\n",
184 | " cost = cost.view(8, 100, -1).cpu()\n",
185 | "\n",
186 | " indices = []\n",
187 | " sum_s = 0\n",
188 | " for c, t in zip(cost, targets):\n",
189 | " #取目标框的数量\n",
190 | " s = len(t['boxes'])\n",
191 | "\n",
192 | " #取这些框的计算结果\n",
193 | " #[100, lens]\n",
194 | " c = c[:, sum_s:sum_s + s]\n",
195 | "\n",
196 | " #累计索引\n",
197 | " sum_s = sum_s + s\n",
198 | "\n",
199 | " #c这个矩阵记录的是loss的情况\n",
200 | " #求最小loss的分配方式\n",
201 | " index_row, index_col = linear_sum_assignment(c)\n",
202 | " index_row = torch.LongTensor(index_row)\n",
203 | " index_col = torch.LongTensor(index_col)\n",
204 | " indices.append((index_row, index_col))\n",
205 | "\n",
206 | " return indices\n",
207 | "\n",
208 | "\n",
209 | "matcher = DetrHungarianMatcher()\n",
210 | "\n",
211 | "matcher(outputs, targets)"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 3,
217 | "id": "5eceb918",
218 | "metadata": {},
219 | "outputs": [
220 | {
221 | "data": {
222 | "text/plain": [
223 | "{'loss_ce': tensor(4.9638),\n",
224 | " 'loss_bbox': tensor(0.4050),\n",
225 | " 'loss_giou': tensor(0.7186)}"
226 | ]
227 | },
228 | "execution_count": 3,
229 | "metadata": {},
230 | "output_type": "execute_result"
231 | }
232 | ],
233 | "source": [
234 | "class DetrLoss(torch.nn.Module):\n",
235 | "\n",
236 | " def __init__(self):\n",
237 | " super().__init__()\n",
238 | " empty_weight = torch.ones(92)\n",
239 | " empty_weight[-1] = 0.1\n",
240 | " self.register_buffer('empty_weight', empty_weight)\n",
241 | "\n",
242 | " def loss_labels(self, outputs, targets, indices):\n",
243 | " # 默认都是背景\n",
244 | " #[8, 100]\n",
245 | " target_classes = torch.full([8, 100],\n",
246 | " 91,\n",
247 | " dtype=torch.int64,\n",
248 | " device=outputs['logits'].device)\n",
249 | "\n",
250 | " #遍历8条数据\n",
251 | " for i in range(8):\n",
252 | " #遍历每一个分配结果(最小cost方式分配)\n",
253 | " for io, it in zip(*indices[i]):\n",
254 | " #按照最小cost的方式分配每个预测结果的目标\n",
255 | " target_classes[i, io.item()] = targets[i]['class_labels'][\n",
256 | " it.item()]\n",
257 | "\n",
258 | " #[8, 100, 92] -> [8, 92, 100]\n",
259 | " logits = outputs['logits'].transpose(1, 2)\n",
260 | "\n",
261 | " return torch.nn.functional.cross_entropy(logits, target_classes,\n",
262 | " self.empty_weight)\n",
263 | "\n",
264 | " def loss_boxes(self, outputs, targets, indices):\n",
265 | " boxes_output = []\n",
266 | " boxes_target = []\n",
267 | " # 遍历8条数据\n",
268 | " for i in range(8):\n",
269 | " # 遍历每一个分配结果(最小cost方式分配)\n",
270 | " for io, it in zip(*indices[i]):\n",
271 | " # 按照最小cost的方式取每一对框\n",
272 | " boxes_output.append(outputs['pred_boxes'][i, io.item()])\n",
273 | " boxes_target.append(targets[i]['boxes'][it.item()])\n",
274 | "\n",
275 | " boxes_output = torch.stack(boxes_output)\n",
276 | " boxes_target = torch.stack(boxes_target)\n",
277 | "\n",
278 | " num_boxes = sum(len(i['class_labels']) for i in targets)\n",
279 | " if num_boxes < 1:\n",
280 | " num_boxes = 1\n",
281 | "\n",
282 | " #没对框之间求L1距离作为loss\n",
283 | " loss_bbox = torch.nn.functional.l1_loss(boxes_output,\n",
284 | " boxes_target,\n",
285 | " reduction='none')\n",
286 | " loss_bbox = loss_bbox.sum() / num_boxes\n",
287 | "\n",
288 | " #取iou作为第二部分的loss\n",
289 | " #只需要考虑成对的框的iou,所以取对角线元素计算即可\n",
290 | " #iou是越大越好,优化方向取反,所以loss=1-iou\n",
291 | " giou = matcher.box_iou(matcher.xywh_to_x1y1x2y2(boxes_output),\n",
292 | " matcher.xywh_to_x1y1x2y2(boxes_target))\n",
293 | " loss_giou = 1 - giou.diag()\n",
294 | " loss_giou = loss_giou.sum() / num_boxes\n",
295 | "\n",
296 | " return loss_bbox, loss_giou\n",
297 | "\n",
298 | " def forward(self, outputs, targets):\n",
299 | " indices = matcher(outputs, targets)\n",
300 | "\n",
301 | " losses = {}\n",
302 | " losses['loss_ce'] = self.loss_labels(outputs, targets, indices)\n",
303 | "\n",
304 | " loss_bbox, loss_giou = self.loss_boxes(outputs, targets, indices)\n",
305 | " losses['loss_bbox'] = loss_bbox\n",
306 | " losses['loss_giou'] = loss_giou\n",
307 | "\n",
308 | " return losses\n",
309 | "\n",
310 | "\n",
311 | "criterion = DetrLoss()\n",
312 | "\n",
313 | "criterion(outputs, targets)"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 4,
319 | "id": "7820caf3",
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "data": {
324 | "text/plain": [
325 | "{'loss_ce': tensor(4.9638),\n",
326 | " 'loss_bbox': tensor(0.4050),\n",
327 | " 'loss_giou': tensor(0.7186),\n",
328 | " 'cardinality_error': tensor(94.1250)}"
329 | ]
330 | },
331 | "execution_count": 4,
332 | "metadata": {},
333 | "output_type": "execute_result"
334 | }
335 | ],
336 | "source": [
337 | "def test():\n",
338 | " from transformers.models.detr.modeling_detr import DetrLoss, DetrHungarianMatcher\n",
339 | "\n",
340 | " matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)\n",
341 | "\n",
342 | " criterion = DetrLoss(matcher=matcher,\n",
343 | " num_classes=91,\n",
344 | " eos_coef=0.1,\n",
345 | " losses=['labels', 'boxes', 'cardinality'])\n",
346 | "\n",
347 | " return criterion(outputs, targets)\n",
348 | "\n",
349 | "\n",
350 | "test()"
351 | ]
352 | }
353 | ],
354 | "metadata": {
355 | "kernelspec": {
356 | "display_name": "Python [conda env:pt]",
357 | "language": "python",
358 | "name": "conda-env-pt-py"
359 | },
360 | "language_info": {
361 | "codemirror_mode": {
362 | "name": "ipython",
363 | "version": 3
364 | },
365 | "file_extension": ".py",
366 | "mimetype": "text/x-python",
367 | "name": "python",
368 | "nbconvert_exporter": "python",
369 | "pygments_lexer": "ipython3",
370 | "version": "3.10.13"
371 | }
372 | },
373 | "nbformat": 4,
374 | "nbformat_minor": 5
375 | }
376 |
--------------------------------------------------------------------------------
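For what it's worth, the hand-written `box_iou` in the notebook above actually computes the generalized IoU, so it can be cross-checked against `torchvision.ops.generalized_box_iou`. The sketch below is not part of the notebook and assumes the `DetrHungarianMatcher` class defined above is in scope.

```python
# Cross-check (sketch): the class's box_iou should agree with torchvision's
# generalized_box_iou on boxes given as (x1, y1, x2, y2) corners.
import torch
from torchvision.ops import generalized_box_iou

def to_corners(boxes):
    # (cx, cy, w, h) -> (x1, y1, x2, y2), same as xywh_to_x1y1x2y2 above
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                        cx + 0.5 * w, cy + 0.5 * h], dim=-1)

boxes1 = to_corners(torch.rand(6, 4))
boxes2 = to_corners(torch.rand(9, 4))

ours = DetrHungarianMatcher().box_iou(boxes1, boxes2)
reference = generalized_box_iou(boxes1, boxes2)
print(torch.allclose(ours, reference, atol=1e-6))  # expected: True
```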