├── .gitignore ├── README.md ├── datasets.py ├── model.py ├── modules │   ├── nms.py │   └── parse_polys.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | .idea/ 3 | .vscode/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [FOTS: Fast Oriented Text Spotting with a Unified Network](https://arxiv.org/abs/1801.01671) text detection branch reimplementation ([PyTorch](https://pytorch.org/)) 2 | 3 | ## Train 4 | 1. Train with SynthText for 9 epochs 5 | ```sh 6 | time python3 train.py --train-folder SynthText/ --batch-size 21 --batches-before-train 2 7 | ``` 8 | At this point the result was `Epoch 8: 100%|█████████████| 390/390 [08:28<00:00, 1.00it/s, Mean loss=0.98050]`. 9 | 2. Train with ICDAR15 10 | 11 | Replace the data set in `data_set = datasets.SynthText(args.train_folder, datasets.transform)` with `datasets.ICDAR2015` in [`train.py`](./train.py) and run 12 | ```sh 13 | time python3 train.py --train-folder icdar15/ --continue-training --batch-size 21 --batches-before-train 2 14 | ``` 15 | It is expected that the provided `--train-folder` contains unzipped `ch4_training_images` and `ch4_training_localization_transcription_gt`. To avoid saving the model at each epoch, the line `if True:` in [`train.py`](./train.py) can be replaced with `if epoch > 60 and epoch % 6 == 0:` 16 | 17 | The result was `Epoch 582: 100%|█████████████| 48/48 [01:05<00:00, 1.04s/it, Mean loss=0.11290]`. 18 | 19 | ### Learning rate schedule: 20 | Epoch 175: reducing learning rate of group 0 to 5.0000e-04. 21 | 22 | Epoch 264: reducing learning rate of group 0 to 2.5000e-04. 23 | 24 | Epoch 347: reducing learning rate of group 0 to 1.2500e-04.
25 | 26 | Epoch 412: reducing learning rate of group 0 to 6.2500e-05. 27 | 28 | Epoch 469: reducing learning rate of group 0 to 3.1250e-05. 29 | 30 | Epoch 525: reducing learning rate of group 0 to 1.5625e-05. 31 | 32 | Epoch 581: reducing learning rate of group 0 to 7.8125e-06. 33 | 34 | ## Test 35 | ```sh 36 | python3 test.py --images-folder ch4_test_images/ --output-folder res/ --checkpoint epoch_582_checkpoint.pt && zip -jmq runs/u.zip res/* && python2 script.py -g=gt.zip -s=runs/u.zip 37 | ``` 38 | `ch4_training_images` and `ch4_training_localization_transcription_gt` are available in [Task 4.4: End to End (2015 edition)](http://rrc.cvc.uab.es/?ch=4&com=downloads). `script.py` and `ch4_test_images` can be found in [My Methods](https://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1) (`Script: IoU` and `test set samples`). 39 | 40 | It gives `Calculated!{"precision": 0.8694968553459119, "recall": 0.7987481945113144, "hmean": 0.8326223337515684, "AP": 0}`. 41 | 42 | The pretrained models are here: https://drive.google.com/open?id=1xaVshLRrMEkb9LA46IJAZhlapQr3vyY2 43 | 44 | [`test.py`](./test.py) contains commented-out code to visualize results. 45 | 46 | ## Differences from the paper 47 | 1. The model differs from what the paper describes. An explanation is in [`model.py`](./model.py). 48 | 2. The authors of FOTS could not train on clipped words because they also have a recognition branch: the whole word has to be present in an image to be recognized correctly. This reimplementation has only the detection branch, which makes it possible to train on crops of the words. 49 | 3. The paper suggests using some other data sets in addition. Training on SynthText is simplified in this reimplementation. -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import re 5 | 6 | import cv2 7 | import numpy as np 8 | import scipy.io 9 | import torch 10 | import torch.utils.data 11 | import torchvision 12 | from shapely.geometry import Polygon, box 13 | import shapely 14 | 15 | 16 | def point_dist_to_line(p1, p2, p3): 17 | """Compute the distance from p3 to the line through p1 and p2.""" 18 | if not np.array_equal(p1, p2): 19 | return np.abs(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1) 20 | else: 21 | return np.linalg.norm(p3 - p1) 22 | 23 | 24 | IN_OUT_RATIO = 4 25 | IN_SIDE = 640 26 | OUT_SIDE = IN_SIDE // IN_OUT_RATIO 27 | 28 | 29 | def transform(im, quads, texts, normalizer, data_set): 30 | # upscale 31 | scale = 2560 / np.maximum(im.shape[0], im.shape[1]) 32 | upscaled = cv2.resize(im, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) 33 | quads = quads * scale 34 | # rotate 35 | # grab the dimensions of the image and then determine the 36 | # center 37 | (h, w) = upscaled.shape[:2] 38 | (cX, cY) = (w / 2, h / 2) 39 | 40 | # grab the rotation matrix (applying the negative of the 41 | # angle to rotate clockwise), then grab the sine and cosine 42 | # (i.e., the rotation components of the matrix) 43 | angle = torch.empty(1).uniform_(-10, 10).item() 44 | M = cv2.getRotationMatrix2D((cX, cY), angle=angle, scale=1.0) 45 | cos = np.abs(M[0, 0]) 46 | sin = np.abs(M[0, 1]) 47 | 48 | # compute the new bounding dimensions of the image 49 | nW = int((h * sin) + (w * cos)) # TODO replace with round and do it later 50 | nH = int((h * cos) + (w * sin)) 51 | 52 | # adjust the rotation matrix to take into account translation 53 | M[0, 2]
+= (nW / 2) - cX 54 | M[1, 2] += (nH / 2) - cY 55 | 56 | # perform the actual rotation and return the image 57 | rotated = cv2.warpAffine(upscaled, M, (nW, nH)) 58 | quads = cv2.transform(quads, M) 59 | # stretch 60 | stretch_k = torch.empty(1).uniform_(0.8, 1.2).item() 61 | stretched = cv2.resize(rotated, None, fx=1, fy=stretch_k, interpolation=cv2.INTER_CUBIC) 62 | quads[:, :, 1] = quads[:, :, 1] * stretch_k 63 | 64 | quads /= IN_OUT_RATIO 65 | 66 | training_mask = np.ones((OUT_SIDE, OUT_SIDE), dtype=float) 67 | classification = np.zeros((OUT_SIDE, OUT_SIDE), dtype=float) 68 | regression = np.zeros((4,) + classification.shape, dtype=float) 69 | tmp_cls = np.empty(classification.shape, dtype=float) 70 | thetas = np.zeros(classification.shape, dtype=float) 71 | 72 | # crop 73 | crop_max_y = stretched.shape[0] // IN_OUT_RATIO - OUT_SIDE # Synth has some short images, so the y coord of the crop may be forced to be zero 74 | if 0 != crop_max_y: 75 | crop_point = (torch.randint(low=0, high=stretched.shape[1] // IN_OUT_RATIO - OUT_SIDE, size=(1,), dtype=torch.int16).item(), 76 | torch.randint(low=0, high=stretched.shape[0] // IN_OUT_RATIO - OUT_SIDE, size=(1,), dtype=torch.int16).item()) 77 | else: 78 | crop_point = (torch.randint(low=0, high=stretched.shape[1] // IN_OUT_RATIO - OUT_SIDE, size=(1,), dtype=torch.int16).item(), 79 | 0) 80 | crop_box = box(crop_point[0], crop_point[1], crop_point[0] + OUT_SIDE, crop_point[1] + OUT_SIDE) 81 | 82 | for quad_id, quad in enumerate(quads): 83 | polygon = Polygon(quad) 84 | try: 85 | intersected_polygon = polygon.intersection(crop_box) 86 | except shapely.errors.TopologicalError: # some points of quads in Synth can be in the wrong order 87 | quad[1], quad[2] = quad[2], quad[1] 88 | polygon = Polygon(quad) 89 | intersected_polygon = polygon.intersection(crop_box) 90 | if type(intersected_polygon) is Polygon: 91 | intersected_quad = np.array(intersected_polygon.exterior.coords[:-1]) 92 | intersected_quad -= crop_point 93 | intersected_minAreaRect = cv2.minAreaRect(intersected_quad.astype(np.float32)) 94 | intersected_minAreaRect_boxPoints = cv2.boxPoints(intersected_minAreaRect) 95 | cv2.fillConvexPoly(training_mask, intersected_minAreaRect_boxPoints.round().astype(int), 0) 96 | minAreaRect = cv2.minAreaRect(quad.astype(np.float32)) 97 | shrinkage = min(minAreaRect[1][0], minAreaRect[1][1]) * 0.6 98 | shrunk_width_and_height = (intersected_minAreaRect[1][0] - shrinkage, intersected_minAreaRect[1][1] - shrinkage) 99 | if shrunk_width_and_height[0] >= 0 and shrunk_width_and_height[1] >= 0 and texts[quad_id]: 100 | shrunk_minAreaRect = intersected_minAreaRect[0], shrunk_width_and_height, intersected_minAreaRect[2] 101 | 102 | poly = intersected_minAreaRect_boxPoints 103 | if intersected_minAreaRect[2] >= -45: 104 | poly = np.array([poly[1], poly[2], poly[3], poly[0]]) 105 | else: 106 | poly = np.array([poly[2], poly[3], poly[0], poly[1]]) 107 | angle_cos = (poly[2, 0] - poly[3, 0]) / np.sqrt( 108 | (poly[2, 0] - poly[3, 0]) ** 2 + (poly[2, 1] - poly[3, 1]) ** 2 + 1e-5) # TODO tg or ctg 109 | angle = np.arccos(angle_cos) 110 | if poly[2, 1] > poly[3, 1]: 111 | angle *= -1 112 | angle += 45 * np.pi / 180 # [0, pi/2] for learning, actually [-pi/4, pi/4] 113 | 114 | tmp_cls.fill(0) 115 | round_shrink_minAreaRect_boxPoints = cv2.boxPoints(shrunk_minAreaRect) 116 | cv2.fillConvexPoly(tmp_cls, round_shrink_minAreaRect_boxPoints.round(out=round_shrink_minAreaRect_boxPoints).astype(int), 1) 117 | cv2.rectangle(tmp_cls, (0, 0), (tmp_cls.shape[1] - 1, tmp_cls.shape[0] -
1), 0, thickness=int(round(shrinkage * 2))) 118 | 119 | classification += tmp_cls 120 | training_mask += tmp_cls 121 | thetas += tmp_cls * angle 122 | 123 | points = np.nonzero(tmp_cls) 124 | pointsT = np.transpose(points) 125 | for point in pointsT: 126 | for plane in range(3): # TODO width - one dist, height - the other dist, and a more precise dist 127 | regression[(plane,) + tuple(point)] = point_dist_to_line(poly[plane], poly[plane + 1], np.array((point[1], point[0]))) * IN_OUT_RATIO 128 | regression[(3,) + tuple(point)] = point_dist_to_line(poly[3], poly[0], np.array((point[1], point[0]))) * IN_OUT_RATIO 129 | if 0 == np.count_nonzero(classification) and 0.1 < torch.rand(1).item(): 130 | return data_set[torch.randint(low=0, high=len(data_set), size=(1,), dtype=torch.int16).item()] 131 | # avoiding training on black corners decreases hmean, see d9c727a8defbd1c8022478ae798c907ccd2fa0b2. This may happen 132 | # because of OHEM: it already guides the training and it won't select black corner pixels if the net is good at 133 | # classifying them. It can be easily verified by removing OHEM, but I didn't test it 134 | cropped = stretched[crop_point[1] * IN_OUT_RATIO:crop_point[1] * IN_OUT_RATIO + IN_SIDE, crop_point[0] * IN_OUT_RATIO:crop_point[0] * IN_OUT_RATIO + IN_SIDE] 135 | cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB).astype(np.float64) / 255 136 | permuted = np.transpose(cropped, (2, 0, 1)) 137 | permuted = torch.from_numpy(permuted).float() 138 | permuted = normalizer(permuted) 139 | return permuted, torch.from_numpy(classification).float(), torch.from_numpy(regression).float(), torch.from_numpy( 140 | thetas).float(), torch.from_numpy(training_mask).float() 141 | 142 | 143 | class ICDAR2015(torch.utils.data.Dataset): 144 | def __init__(self, root, transform): 145 | self.transform = transform 146 | self.root = root 147 | self.img_dir = 'ch4_training_images' 148 | self.labels_dir = 'ch4_training_localization_transcription_gt' 149 | self.image_prefix = [] 150 | self.pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$') 151 | self.normalizer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 152 | std=[0.229, 0.224, 0.225]) 153 | for dirEntry in os.scandir(os.path.join(root, 'ch4_training_images')): 154 | self.image_prefix.append(dirEntry.name[:-4]) 155 | 156 | def __len__(self): 157 | return len(self.image_prefix) 158 | 159 | def __getitem__(self, idx): 160 | img = cv2.imread(os.path.join(os.path.join(self.root, self.img_dir), self.image_prefix[idx] + '.jpg'), cv2.IMREAD_COLOR).astype(np.float32) 161 | quads = [] 162 | texts = [] 163 | lines = [line.rstrip('\n') for line in open(os.path.join(os.path.join(self.root, self.labels_dir), 'gt_' + self.image_prefix[idx] + '.txt'), 164 | encoding='utf-8-sig')] 165 | for line in lines: 166 | matches = self.pattern.findall(line)[0] 167 | numbers = np.array(matches[:8], dtype=float) 168 | quads.append(numbers.reshape((4, 2))) 169 | texts.append('###' != matches[8]) 170 | return transform(img, np.stack(quads), texts, self.normalizer, self) 171 | 172 | 173 | class SynthText(torch.utils.data.Dataset): 174 | def __init__(self, root, transform): 175 | self.transform = transform 176 | self.root = root 177 | self.labels = scipy.io.loadmat(os.path.join(root, 'gt.mat')) 178 | self.broken_image_ids = set() 179 | #sample_path = labels['imnames'][0, 1][0] 180 | #sample_boxes = np.transpose(labels['wordBB'][0, 1], (2, 1, 0)) 181 | self.pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$') 182 | self.normalizer = torchvision.transforms.Normalize(mean=[0.485,
0.456, 0.406], 183 | std=[0.229, 0.224, 0.225]) 184 | 185 | def __len__(self): 186 | return self.labels['imnames'].shape[1] // 105 # there are more than 105 text images for each source image 187 | 188 | def __getitem__(self, idx): 189 | idx = (idx * 105) + random.randint(0, 104) # compensate for the dataset size while maintaining diversity 190 | if idx in self.broken_image_ids: 191 | return self[torch.randint(low=0, high=len(self), size=(1,), dtype=torch.int16).item()] 192 | img = cv2.imread(os.path.join(self.root, self.labels['imnames'][0, idx][0]), cv2.IMREAD_COLOR).astype(np.float32) 193 | if 190 >= img.shape[0]: # the image is too short; it will not be possible to crop 640x640 after transformations 194 | self.broken_image_ids.add(idx) 195 | return self[torch.randint(low=0, high=len(self), size=(1,), dtype=torch.int16).item()] 196 | coordinates = self.labels['wordBB'][0, idx] 197 | if len(coordinates.shape) == 2: 198 | coordinates = np.expand_dims(coordinates, axis=2) 199 | transposed = np.transpose(coordinates, (2, 1, 0)) 200 | if (transposed > 0).all() and (transposed[:, :, 0] < img.shape[1]).all() and (transposed[:, :, 1] < img.shape[0]).all(): # x within width, y within height 201 | if ((transposed[:, 0] != transposed[:, 1]).all() and 202 | (transposed[:, 0] != transposed[:, 2]).all() and 203 | (transposed[:, 0] != transposed[:, 3]).all() and 204 | (transposed[:, 1] != transposed[:, 2]).all() and 205 | (transposed[:, 1] != transposed[:, 3]).all() and 206 | (transposed[:, 2] != transposed[:, 3]).all()): # boxes can be in the form [p1, p1, p2, p2], while we need [p1, p2, p3, p4] 207 | return transform(img, transposed, (True, ) * len(transposed), self.normalizer, self) 208 | self.broken_image_ids.add(idx) 209 | return self[torch.randint(low=0, high=len(self), size=(1,), dtype=torch.int16).item()] 210 | 211 | 212 | if '__main__' == __name__: 213 | icdar = ICDAR2015('C:\\Users\\vzlobin\\Documents\\repo\\FOTS.PyTorch\\data\\icdar\\icdar2015\\4.4\\training', transform) 214 | # dl = torch.utils.data.DataLoader(icdar, batch_size=4, shuffle=False, sampler=None, batch_sampler=None, num_workers=4, pin_memory = False, drop_last = False, timeout = 0, worker_init_fn = None) 215 | for image_i in range(len(icdar)): 216 | normalized, classification, regression, thetas, training_mask = icdar[image_i] 217 | permuted = normalized * torch.tensor([0.229, 0.224, 0.225])[:, None, None] + torch.tensor([0.485, 0.456, 0.406])[:, None, None] 218 | cropped = permuted.permute(1, 2, 0).numpy() 219 | cv2.imshow('orig', cv2.resize(cropped[:, :, ::-1], (640, 640))) 220 | cropped = cv2.resize(cropped, (160, 160)) 221 | cv2.imshow('img', cv2.resize(cropped[:, :, ::-1] * training_mask.numpy()[:, :, None], (640, 640))) 222 | cv2.imshow('training_mask', cv2.resize(training_mask.numpy() * 255, (640, 640))) 223 | cv2.imshow('classification', cv2.resize(classification.numpy() * 255, (640, 640))) 224 | regression = regression.numpy() 225 | for i in range(4): 226 | m = np.amax(regression[i]) 227 | if 0 != m: 228 | cv2.imshow(str(i), cv2.resize(regression[i, :, :] / m, (640, 640))) 229 | else: 230 | cv2.imshow(str(i), cv2.resize(regression[i, :, :], (640, 640))) 231 | thetas = thetas.numpy() 232 | minim = np.amin(thetas) 233 | m = np.amax(thetas) 234 | print(m * 180 / np.pi) 235 | cv2.imshow('angle', cv2.resize(np.array(np.around(thetas * 255 / m * 180 / np.pi), dtype=np.uint8), (640, 640))) 236 | cv2.waitKey(0) 237 | -------------------------------------------------------------------------------- /model.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | 8 | def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True): 9 | modules = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)] 10 | if bn: 11 | modules.append(nn.BatchNorm2d(out_channels)) 12 | if relu: 13 | modules.append(nn.ReLU(inplace=True)) 14 | return nn.Sequential(*modules) 15 | 16 | 17 | class Decoder(nn.Module): 18 | def __init__(self, in_channels, squeeze_channels): 19 | super().__init__() 20 | self.squeeze = conv(in_channels, squeeze_channels) 21 | 22 | def forward(self, x, encoder_features): 23 | x = self.squeeze(x) 24 | x = F.interpolate(x, size=(encoder_features.shape[2], encoder_features.shape[3]), 25 | mode='bilinear', align_corners=True) 26 | up = torch.cat([encoder_features, x], 1) 27 | return up 28 | 29 | 30 | class FOTSModel(nn.Module): 31 | def __init__(self, crop_height=640): 32 | super().__init__() 33 | self.crop_height = crop_height 34 | self.resnet = torchvision.models.resnet34(pretrained=True) 35 | self.conv1 = nn.Sequential( 36 | self.resnet.conv1, 37 | self.resnet.bn1, 38 | self.resnet.relu, 39 | ) # 64 40 | self.encoder1 = self.resnet.layer1 # 64 41 | self.encoder2 = self.resnet.layer2 # 128 42 | self.encoder3 = self.resnet.layer3 # 256 43 | self.encoder4 = self.resnet.layer4 # 512 44 | 45 | self.center = nn.Sequential( 46 | conv(512, 512, stride=2), 47 | conv(512, 1024) 48 | ) 49 | 50 | self.decoder4 = Decoder(1024, 512) 51 | self.decoder3 = Decoder(1024, 256) 52 | self.decoder2 = Decoder(512, 128) 53 | self.decoder1 = Decoder(256, 64) 54 | self.remove_artifacts = conv(128, 64) 55 | 56 | self.confidence = conv(64, 1, kernel_size=1, padding=0, bn=False, relu=False) 57 | self.distances = conv(64, 4, kernel_size=1, padding=0, bn=False, relu=False) 58 | self.angle = conv(64, 1, kernel_size=1, padding=0, bn=False, relu=False) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = F.max_pool2d(x, kernel_size=2, stride=2) 63 | 64 | e1 = self.encoder1(x) 65 | e2 = self.encoder2(e1) 66 | e3 = self.encoder3(e2) 67 | e4 = self.encoder4(e3) 68 | 69 | f = self.center(e4) 70 | 71 | d4 = self.decoder4(f, e4) 72 | d3 = self.decoder3(d4, e3) 73 | d2 = self.decoder2(d3, e2) 74 | d1 = self.decoder1(d2, e1) 75 | 76 | final = self.remove_artifacts(d1) 77 | 78 | confidence = self.confidence(final) 79 | distances = self.distances(final) 80 | distances = torch.sigmoid(distances) * self.crop_height 81 | angle = self.angle(final) 82 | angle = torch.sigmoid(angle) * np.pi / 2 83 | 84 | return confidence, distances, angle 85 | 86 | 87 | # class FOTSModel(nn.Module): 88 | # """This model is described in the paper, but it trains slower and gives slightly worse results""" 89 | # def __init__(self, crop_height=640): 90 | # super().__init__() 91 | # self.crop_height = crop_height 92 | # self.resnet = torchvision.models.resnet50(pretrained=True) 93 | # self.conv1 = nn.Sequential( 94 | # self.resnet.conv1, 95 | # self.resnet.bn1, 96 | # self.resnet.relu, 97 | # ) # 64 * 4 98 | # self.encoder1 = self.resnet.layer1 # 64 * 4 99 | # self.encoder2 = self.resnet.layer2 # 128 * 4 100 | # self.encoder3 = self.resnet.layer3 # 256 * 4 101 | # self.encoder4 = self.resnet.layer4 # 512 * 4 102 | 103 | # self.decoder3 = Decoder(512 * 4, 256 * 4) 104 | # self.decoder2 = Decoder(256 * 4 * 2, 128 
* 4) 105 | # self.decoder1 = Decoder(128 * 4 * 2, 64 * 4) 106 | 107 | # self.confidence = conv(64 * 4 * 2, 1, kernel_size=1, padding=0, bn=False, relu=False) 108 | # self.distances = conv(64 * 4 * 2, 4, kernel_size=1, padding=0, bn=False, relu=False) 109 | # self.angle = conv(64 * 4 * 2, 1, kernel_size=1, padding=0, bn=False, relu=False) 110 | 111 | # def forward(self, x): 112 | # x = self.conv1(x) 113 | # x = F.max_pool2d(x, kernel_size=2, stride=2) 114 | 115 | # e1 = self.encoder1(x) 116 | # e2 = self.encoder2(e1) 117 | # e3 = self.encoder3(e2) 118 | # e4 = self.encoder4(e3) 119 | 120 | # d3 = self.decoder3(e4, e3) 121 | # d2 = self.decoder2(d3, e2) 122 | # d1 = self.decoder1(d2, e1) 123 | 124 | # confidence = self.confidence(d1) 125 | # distances = self.distances(d1) 126 | # distances = torch.sigmoid(distances) * self.crop_height 127 | # angle = self.angle(d1) 128 | # angle = torch.sigmoid(angle) * np.pi / 2 129 | 130 | # return confidence, distances, angle 131 | -------------------------------------------------------------------------------- /modules/nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from shapely.geometry import Polygon 3 | 4 | 5 | def intersection(g, p): 6 | g = Polygon(g[:8].reshape((4, 2))) 7 | p = Polygon(p[:8].reshape((4, 2))) 8 | if not g.is_valid or not p.is_valid: 9 | return 0 10 | inter = Polygon(g).intersection(Polygon(p)).area 11 | union = g.area + p.area - inter 12 | if union == 0: 13 | return 0 14 | else: 15 | return inter/union 16 | 17 | 18 | def weighted_merge(g, p): 19 | g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8]) 20 | g[8] = (g[8] + p[8]) 21 | return g 22 | 23 | 24 | def standard_nms(S, thres=0.3): 25 | if 0 == len(S): 26 | return np.array([]) 27 | order = np.argsort(S[:, 8])[::-1] 28 | keep = [] 29 | while order.size > 0: 30 | i = order[0] 31 | keep.append(i) 32 | ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) 33 | 34 | inds = np.where(ovr <= thres)[0] 35 | order = order[inds+1] 36 | 37 | return S[keep] 38 | 39 | 40 | def nms_locality(polys, thres=0.3): 41 | ''' 42 | locality aware nms of EAST 43 | :param polys: a N*9 numpy array. 
first 8 coordinates, then prob 44 | :return: boxes after nms 45 | ''' 46 | S = [] 47 | p = None 48 | for g in polys: 49 | if p is not None and intersection(g, p) > thres: 50 | p = weighted_merge(g, p) 51 | else: 52 | if p is not None: 53 | S.append(p) 54 | p = g 55 | if p is not None: 56 | S.append(p) 57 | 58 | if len(S) == 0: 59 | return np.array([]) 60 | return standard_nms(np.array(S), thres) 61 | 62 | 63 | if __name__ == '__main__': 64 | # 343,350,448,135,474,143,369,359 65 | print(Polygon(np.array([[343, 350], [448, 135], 66 | [474, 143], [369, 359]])).area) 67 | 68 | -------------------------------------------------------------------------------- /modules/parse_polys.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from modules.nms import nms_locality, standard_nms 5 | 6 | 7 | def parse_polys(cls, distances, angle, confidence_threshold=0.5, intersection_threshold=0.3, img=None): 8 | polys = [] 9 | height, width = cls.shape 10 | 11 | #bin_cls = cls > confidence_threshold 12 | #t_dist = distances[0].copy() 13 | #t_dist[bin_cls == False] = 0 14 | #b_dist = distances[2].copy() 15 | #b_dist[bin_cls == False] = 0 16 | #cv2.imshow('cls', cls) 17 | #cv2.imshow('t_dist', t_dist.astype(np.uint8)) 18 | #cv2.imshow('b_dist', b_dist.astype(np.uint8)) 19 | # 20 | #thr_cls = cls.copy() 21 | #thr_cls[bin_cls == False] = 0 22 | #thr_cls = (thr_cls * 255).astype(np.uint8) 23 | #thr_cls = cv2.cvtColor(thr_cls, cv2.COLOR_GRAY2BGR) 24 | 25 | IN_OUT_RATIO = 4 26 | for y in range(height): 27 | for x in range(width): 28 | if cls[y, x] < confidence_threshold: 29 | continue 30 | #thr_cls_copy = thr_cls.copy() 31 | #thr_cls_copy[y, x] = [0, 0, 255] 32 | 33 | #a,b,c,d = distances[0, y, x], distances[1, y, x], distances[2, y, x], distances[3, y, x] 34 | poly_height = distances[0, y, x] + distances[2, y, x] 35 | poly_width = distances[1, y, x] + distances[3, y, x] 36 | #cv2.line(thr_cls_copy, (10, 10), (10 + int(poly_width), 10), (0, 255, 0)) 37 | #cv2.line(thr_cls_copy, (10, 10), (10, int(10 + poly_height)), (0, 255, 0)) 38 | 39 | poly_angle = angle[y, x] - np.pi / 4 40 | x_rot = x * np.cos(-poly_angle) + y * np.sin(-poly_angle) 41 | y_rot = -x * np.sin(-poly_angle) + y * np.cos(-poly_angle) 42 | poly_y_center = y_rot * IN_OUT_RATIO + (poly_height / 2 - distances[0, y, x]) 43 | poly_x_center = x_rot * IN_OUT_RATIO - (poly_width / 2 - distances[1, y, x]) 44 | poly = [ 45 | int(((poly_x_center - poly_width / 2) * np.cos(poly_angle) + (poly_y_center - poly_height / 2) * np.sin(poly_angle))), 46 | int((-(poly_x_center - poly_width / 2) * np.sin(poly_angle) + (poly_y_center - poly_height / 2) * np.cos(poly_angle))), 47 | int(((poly_x_center + poly_width / 2) * np.cos(poly_angle) + (poly_y_center - poly_height / 2) * np.sin(poly_angle))), 48 | int((-(poly_x_center + poly_width / 2) * np.sin(poly_angle) + (poly_y_center - poly_height / 2) * np.cos(poly_angle))), 49 | int(((poly_x_center + poly_width / 2) * np.cos(poly_angle) + (poly_y_center + poly_height / 2) * np.sin(poly_angle))), 50 | int((-(poly_x_center + poly_width / 2) * np.sin(poly_angle) + (poly_y_center + poly_height / 2) * np.cos(poly_angle))), 51 | int(((poly_x_center - poly_width / 2) * np.cos(poly_angle) + (poly_y_center + poly_height / 2) * np.sin(poly_angle))), 52 | int((-(poly_x_center - poly_width / 2) * np.sin(poly_angle) + (poly_y_center + poly_height / 2) * np.cos(poly_angle))), 53 | cls[y, x] 54 | ] 55 | #pts = np.array(poly[:8]).reshape((4, 2)).astype(np.int32) 56 
| #cv2.line(thr_cls_copy, (pts[0, 0], pts[0, 1]), (pts[1, 0], pts[1, 1]), color=(0, 255, 0)) 57 | #cv2.line(thr_cls_copy, (pts[1, 0], pts[1, 1]), (pts[2, 0], pts[2, 1]), color=(0, 255, 0)) 58 | #cv2.line(thr_cls_copy, (pts[2, 0], pts[2, 1]), (pts[3, 0], pts[3, 1]), color=(0, 255, 0)) 59 | #cv2.line(thr_cls_copy, (pts[3, 0], pts[3, 1]), (pts[0, 0], pts[0, 1]), color=(0, 255, 0)) 60 | #cv2.imshow('tmp', thr_cls_copy) 61 | #cv2.waitKey() 62 | 63 | polys.append(poly) 64 | 65 | polys = nms_locality(np.array(polys), intersection_threshold) 66 | if img is not None: 67 | for poly in polys: 68 | pts = np.array(poly[:8]).reshape((4, 2)).astype(np.int32) 69 | cv2.line(img, (pts[0, 0], pts[0, 1]), (pts[1, 0], pts[1, 1]), color=(0, 255, 0)) 70 | cv2.line(img, (pts[1, 0], pts[1, 1]), (pts[2, 0], pts[2, 1]), color=(0, 255, 0)) 71 | cv2.line(img, (pts[2, 0], pts[2, 1]), (pts[3, 0], pts[3, 1]), color=(0, 255, 0)) 72 | cv2.line(img, (pts[3, 0], pts[3, 1]), (pts[0, 0], pts[0, 1]), color=(0, 255, 0)) 73 | cv2.imshow('polys', img) 74 | cv2.waitKey() 75 | return polys 76 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from model import FOTSModel 9 | from modules.parse_polys import parse_polys 10 | import re 11 | import tqdm 12 | 13 | 14 | def test(net, images_folder, output_folder, scaled_height): 15 | pbar = tqdm.tqdm(os.listdir(images_folder), desc='Test', ncols=80) 16 | for image_name in pbar: 17 | prefix = image_name[:image_name.rfind('.')] 18 | image = cv2.imread(os.path.join(images_folder, image_name), cv2.IMREAD_COLOR) 19 | # due to bad net arch sizes have to be mult of 32, so hardcode it 20 | scale_x = 2240 / image.shape[1] # 2240 # 1280 21 | scale_y = 1248 / image.shape[0] # 1248 # 704 22 | scaled_image = cv2.resize(image, dsize=(0, 0), fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC) 23 | orig_scaled_image = scaled_image.copy() 24 | 25 | scaled_image = scaled_image[:, :, ::-1].astype(np.float32) 26 | scaled_image = (scaled_image / 255 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225]) 27 | image_tensor = torch.from_numpy(np.expand_dims(np.transpose(scaled_image, axes=(2, 0, 1)), axis=0)).float() 28 | 29 | confidence, distances, angle = net(image_tensor.cuda()) 30 | confidence = torch.sigmoid(confidence).squeeze().data.cpu().numpy() 31 | distances = distances.squeeze().data.cpu().numpy() 32 | angle = angle.squeeze().data.cpu().numpy() 33 | polys = parse_polys(confidence, distances, angle, 0.95, 0.3)#, img=orig_scaled_image) 34 | with open('{}'.format(os.path.join(output_folder, 'res_{}.txt'.format(prefix))), 'w') as f: 35 | for id in range(polys.shape[0]): 36 | f.write('{}, {}, {}, {}, {}, {}, {}, {}\n'.format( 37 | int(polys[id, 0] / scale_x), int(polys[id, 1] / scale_y), int(polys[id, 2] / scale_x), int(polys[id, 3] / scale_y), 38 | int(polys[id, 4] / scale_x), int(polys[id, 5] / scale_y), int(polys[id, 6] / scale_x), int(polys[id, 7] / scale_y) 39 | )) 40 | pbar.set_postfix_str(image_name, refresh=False) 41 | # visualize 42 | # reshaped_pred_polys = [] 43 | # for id in range(polys.shape[0]): 44 | # reshaped_pred_polys.append(np.array([int(polys[id, 0] / scale_x), int(polys[id, 1] / scale_y), int(polys[id, 2] / scale_x), int(polys[id, 3] / scale_y), 45 | # int(polys[id, 4] / scale_x), int(polys[id, 5] / scale_y), int(polys[id, 6] / scale_x), 
int(polys[id, 7] / scale_y)]).reshape((4, 2))) 46 | # numpy_reshaped_pred_polys = np.stack(reshaped_pred_polys) 47 | # strong_gt_quads = [] 48 | # weak_gt_quads = [] 49 | # lines = [line.rstrip('\n') for line in open(os.path.join(os.path.join(images_folder, '../Challenge4_Test_Task4_GT'), 'gt_' + image_name[:-4] + '.txt'), 50 | # encoding='utf-8-sig')] 51 | # pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$') 52 | # for line in lines: 53 | # matches = pattern.findall(line)[0] 54 | # numbers = np.array(matches[:8], dtype=float) 55 | # if '###' == matches[8]: 56 | # weak_gt_quads.append(numbers.reshape((4, 2))) 57 | # else: 58 | # strong_gt_quads.append(numbers.reshape((4, 2))) 59 | # if len(strong_gt_quads): 60 | # numpy_strong_gt_quads = np.stack(strong_gt_quads) 61 | # cv2.polylines(image, numpy_strong_gt_quads.round().astype(int), True, (0, 0, 255)) 62 | # if len(weak_gt_quads): 63 | # numpy_weak_gt_quads = np.stack(weak_gt_quads) 64 | # cv2.polylines(image, numpy_weak_gt_quads.round().astype(int), True, (0, 255, 255)) 65 | # cv2.polylines(image, numpy_reshaped_pred_polys.round().astype(int), True, (255, 0, 0)) 66 | # cv2.imshow('img', image) 67 | # print(image_name) 68 | # cv2.waitKey(0) 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--images-folder', type=str, required=True, help='path to the folder with test images') 74 | parser.add_argument('--output-folder', type=str, default='fots_test_results', 75 | help='path to the output folder with result labels') 76 | parser.add_argument('--checkpoint', type=str, required=True, help='path to the checkpoint to test') 77 | parser.add_argument('--height-size', type=int, default=1260, help='height size to resize input image') 78 | args = parser.parse_args() 79 | 80 | if not os.path.exists(args.output_folder): 81 | os.makedirs(args.output_folder) 82 | 83 | net = FOTSModel() 84 | checkpoint = torch.load(args.checkpoint) 85 | print('Epoch ', checkpoint['epoch']) 86 | net.load_state_dict(checkpoint['model_state_dict']) 87 | net = net.eval().cuda() 88 | with torch.no_grad(): 89 | test(net, args.images_folder, args.output_folder, args.height_size) 90 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | 4 | import cv2 5 | import numpy as np 6 | import numpy.random as nprnd 7 | import os 8 | import torch 9 | import torch.utils.data 10 | import tqdm 11 | 12 | import datasets 13 | from model import FOTSModel 14 | from modules.parse_polys import parse_polys 15 | 16 | 17 | def restore_checkpoint(folder, continue_training): 18 | model = FOTSModel().to(torch.device("cuda")) 19 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5) 20 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=32, verbose=True, threshold=0.05, threshold_mode='rel') 21 | 22 | checkpoint_name = os.path.join(folder, 'epoch_8_checkpoint.pt') 23 | if os.path.isfile(checkpoint_name) and continue_training: 24 | checkpoint = torch.load(checkpoint_name) 25 | model.load_state_dict(checkpoint['model_state_dict']) 26 | # return 0, model, optimizer, lr_scheduler, +math.inf 27 | epoch = checkpoint['epoch'] + 1 28 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 29 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) 30 | best_score = checkpoint['best_score'] 31 | return epoch, model,
optimizer, lr_scheduler, best_score 32 | else: 33 | return 0, model, optimizer, lr_scheduler, +math.inf 34 | 35 | 36 | def save_checkpoint(epoch, model, optimizer, lr_scheduler, best_score, folder, save_as_best): 37 | if not os.path.exists(folder): 38 | os.makedirs(folder) 39 | # if epoch > 60 and epoch % 6 == 0: 40 | if True: 41 | torch.save({ 42 | 'epoch': epoch, 43 | 'model_state_dict': model.module.state_dict(), 44 | 'optimizer_state_dict': optimizer.state_dict(), 45 | 'lr_scheduler_state_dict': lr_scheduler.state_dict(), 46 | 'best_score': best_score # not current score 47 | }, os.path.join(folder, 'epoch_{}_checkpoint.pt'.format(epoch))) 48 | 49 | if save_as_best: 50 | torch.save({ 51 | 'epoch': epoch, 52 | 'model_state_dict': model.module.state_dict(), 53 | 'optimizer_state_dict': optimizer.state_dict(), 54 | 'lr_scheduler_state_dict': lr_scheduler.state_dict(), 55 | 'best_score': best_score # not current score 56 | }, os.path.join(folder, 'best_checkpoint.pt')) 57 | print('Updated best_model') 58 | torch.save({ 59 | 'epoch': epoch, 60 | 'model_state_dict': model.module.state_dict(), 61 | 'optimizer_state_dict': optimizer.state_dict(), 62 | 'lr_scheduler_state_dict': lr_scheduler.state_dict(), 63 | 'best_score': best_score # not current score 64 | }, os.path.join(folder, 'last_checkpoint.pt')) 65 | 66 | 67 | def fill_ohem_mask(raw_loss, ohem_mask, num_samples_total, max_hard_samples, max_rnd_samples): 68 | h, w = raw_loss.shape 69 | if num_samples_total != 0: 70 | top_val, top_idx = torch.topk(raw_loss.view(-1), num_samples_total) 71 | num_hard_samples = int(min(max_hard_samples, num_samples_total)) 72 | 73 | num_rnd_samples = max_hard_samples + max_rnd_samples - num_hard_samples 74 | num_rnd_samples = min(num_rnd_samples, num_samples_total - num_hard_samples) 75 | weight = num_hard_samples + num_rnd_samples 76 | 77 | for id in range(min(len(top_idx), num_hard_samples)): 78 | val = top_idx[id] 79 | y = val // w 80 | x = val - y * w 81 | ohem_mask[y, x] = 1 #/ weight 82 | 83 | if num_rnd_samples != 0: 84 | for id in nprnd.randint(num_hard_samples, num_hard_samples + num_rnd_samples, num_rnd_samples): 85 | val = top_idx[id] 86 | y = val // w 87 | x = val - y * w 88 | ohem_mask[y, x] = 1 #/ weight 89 | 90 | 91 | def detection_loss(pred, gt): 92 | y_pred_cls, y_pred_geo, theta_pred = pred 93 | y_true_cls, y_true_geo, theta_gt, training_mask = gt 94 | y_true_cls, theta_gt = y_true_cls.unsqueeze(1), theta_gt.unsqueeze(1) 95 | y_true_cls, y_true_geo, theta_gt = y_true_cls.to('cuda'), y_true_geo.to('cuda'), theta_gt.to('cuda') 96 | 97 | raw_cls_loss = torch.nn.functional.binary_cross_entropy_with_logits(input=y_pred_cls, target=y_true_cls, weight=None, reduction='none') 98 | 99 | d1_gt, d2_gt, d3_gt, d4_gt = torch.split(y_true_geo, 1, 1) 100 | d1_pred, d2_pred, d3_pred, d4_pred = torch.split(y_pred_geo, 1, 1) 101 | area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt) 102 | area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred) 103 | w_intersect = torch.min(d2_gt, d2_pred) + torch.min(d4_gt, d4_pred) 104 | h_intersect = torch.min(d1_gt, d1_pred) + torch.min(d3_gt, d3_pred) 105 | area_intersect = w_intersect * h_intersect 106 | area_union = area_gt + area_pred - area_intersect 107 | raw_tensor_loss = -torch.log((area_intersect+1) / (area_union+1)) + 10 * (1 - torch.cos(theta_pred - theta_gt)) 108 | 109 | ohem_cls_mask = np.zeros(raw_cls_loss.shape, dtype=np.float32) 110 | ohem_reg_mask = np.zeros(raw_cls_loss.shape, dtype=np.float32) 111 | for batch_id in range(y_true_cls.shape[0]): 112 | y_true = 
y_true_cls[batch_id].squeeze().data.cpu().numpy().astype(np.uint8) 113 | mask = training_mask[batch_id].squeeze().data.cpu().numpy().astype(np.uint8) 114 | shrunk_mask = y_true & mask 115 | neg_mask = y_true.copy() 116 | neg_mask[y_true == 1] = 0 117 | neg_mask[y_true == 0] = 1 118 | neg_mask[mask == 0] = 0 119 | 120 | shrunk_sum = int(shrunk_mask.sum()) 121 | if shrunk_sum != 0: 122 | ohem_cls_mask[batch_id, 0, shrunk_mask == 1] = 1 #/ shrunk_sum 123 | raw_loss = raw_cls_loss[batch_id].squeeze().data.cpu().numpy() 124 | raw_loss[neg_mask == 0] = 0 125 | raw_loss = torch.from_numpy(raw_loss) 126 | num_neg = int(neg_mask.sum()) 127 | fill_ohem_mask(raw_loss, ohem_cls_mask[batch_id, 0], num_neg, 512, 512) 128 | 129 | raw_loss = raw_tensor_loss[batch_id].squeeze().data.cpu().numpy() 130 | raw_loss[shrunk_mask == 0] = 0 131 | raw_loss = torch.from_numpy(raw_loss) 132 | num_pos = int(shrunk_mask.sum()) 133 | fill_ohem_mask(raw_loss, ohem_reg_mask[batch_id, 0], num_pos, 128, 128) 134 | 135 | if 0: 136 | for batch_id in range(y_true_cls.shape[0]): 137 | y_true = y_true_cls[batch_id].squeeze().data.cpu().numpy().astype(np.uint8) 138 | cv2.imshow('y_true', y_true*255) 139 | mask = training_mask[batch_id].squeeze().data.cpu().numpy().astype(np.uint8) 140 | cv2.imshow('mask', mask*255) 141 | 142 | shrunk_mask = y_true & mask 143 | cv2.imshow('shrunk pos', shrunk_mask*255) 144 | neg_mask = y_true.copy() 145 | neg_mask[y_true == 1] = 0 146 | neg_mask[y_true == 0] = 1 147 | neg_mask[mask == 0] = 0 148 | cv2.imshow('neg', neg_mask*255) 149 | 150 | cv2.imshow('ohem_cls', ohem_cls_mask[batch_id, 0]) 151 | cv2.imshow('ohem_reg', ohem_reg_mask[batch_id, 0]) 152 | 153 | cv2.waitKey() 154 | ohem_cls_mask_sum = int(ohem_cls_mask.sum()) 155 | ohem_reg_mask_sum = int(ohem_reg_mask.sum()) 156 | if 0 != ohem_cls_mask_sum: 157 | raw_cls_loss = raw_cls_loss * torch.from_numpy(ohem_cls_mask).cuda() 158 | raw_cls_loss = raw_cls_loss.sum() / ohem_cls_mask_sum 159 | else: 160 | raw_cls_loss = 0 161 | 162 | if 0 != ohem_reg_mask_sum: 163 | raw_tensor_loss = raw_tensor_loss * torch.from_numpy(ohem_reg_mask).cuda() 164 | reg_loss = raw_tensor_loss.sum() / ohem_reg_mask_sum 165 | else: 166 | reg_loss = 0 167 | return reg_loss + raw_cls_loss 168 | 169 | 170 | def show_tensors(cropped, classification, regression, thetas, training_mask, file_names): 171 | print(file_names[0]) 172 | cropped = cropped[0].to('cpu').numpy() 173 | cropped = np.transpose(cropped, (1, 2, 0)) 174 | cropped = cv2.resize(cropped, None, fx=0.25, fy=0.25) / 255 175 | 176 | d1, d2, d3, d4 = torch.split(regression.to('cpu'), 1, 1) 177 | d1, d2, d3, d4 = d1[0].view(160, 160).detach().numpy(), d2[0].view(160, 160).detach().numpy(), d3[0].view(160, 160).detach().numpy(), d4[0].view(160, 160).detach().numpy() 178 | 179 | thetas = thetas[0].view(160, 160).to('cpu').detach().numpy() 180 | 181 | cv2.imshow('', cropped) 182 | cv2.waitKey(0) 183 | cv2.imshow('', classification[0].view(160, 160).to('cpu').detach().numpy()) 184 | cv2.waitKey(0) 185 | cv2.imshow('', d1 / np.amax(d1)) 186 | cv2.waitKey(0) 187 | cv2.imshow('', d2 / np.amax(d2)) 188 | cv2.waitKey(0) 189 | cv2.imshow('', d3 / np.amax(d3)) 190 | cv2.waitKey(0) 191 | cv2.imshow('', d4 / np.amax(d4)) 192 | cv2.waitKey(0) 193 | cv2.imshow('', thetas / np.amin(thetas)) 194 | cv2.waitKey(0) 195 | cv2.imshow('', training_mask[0].to('cpu').detach().numpy()) 196 | cv2.waitKey(0) 197 | 198 | 199 | def fit(start_epoch, model, loss_func, opt, lr_scheduler, best_score, max_batches_per_iter_cnt, checkpoint_dir, train_dl, 
valid_dl): 200 | batch_per_iter_cnt = 0 201 | for epoch in range(start_epoch, 999): 202 | model.train() 203 | train_loss_stats = 0.0 204 | loss_count_stats = 0 205 | pbar = tqdm.tqdm(train_dl, 'Epoch ' + str(epoch), ncols=80) 206 | for cropped, classification, regression, thetas, training_mask in pbar: 207 | if batch_per_iter_cnt == 0: 208 | opt.zero_grad() 209 | prediction = model(cropped.to('cuda')) 210 | 211 | if 0: 212 | for batch_id in range(cropped.shape[0]): 213 | img = cropped[batch_id].data.cpu().numpy().transpose((1, 2, 0)) * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406] 214 | cv2.imshow('img', img[:, :, ::-1]) 215 | 216 | cls = np.squeeze(prediction[0][batch_id].data.cpu().numpy()) 217 | #cls = cv2.resize(cls, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 218 | 219 | mask = training_mask[batch_id].data.cpu().numpy() 220 | #mask = cv2.resize(mask, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 221 | cv2.imshow('mask', mask) 222 | #cv2.imshow('cls', cls*mask) 223 | cls_bin = cls > 0.5 224 | 225 | cls2 = cls.copy() 226 | cls2[cls_bin != True] = 0 227 | 228 | cv2.imshow('cls', cls) 229 | 230 | #res = parse_polys(cls2, 231 | # prediction[1][batch_id].data.cpu().numpy(), 232 | # np.squeeze(prediction[2][batch_id].data.cpu().numpy()), img=img.copy()) 233 | 234 | #top_dist = regression[batch_id, 0].data.cpu().numpy() 235 | #top_dist /= top_dist.max() 236 | ##top_dist = cv2.resize(top_dist, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 237 | #cv2.imshow('top_dist', top_dist) 238 | # 239 | #right_dist = regression[batch_id, 1].data.cpu().numpy() 240 | #right_dist /= right_dist.max() 241 | ##right_dist = cv2.resize(right_dist, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 242 | #cv2.imshow('right_dist', right_dist) 243 | # 244 | #bottom_dist = regression[batch_id, 2].data.cpu().numpy() 245 | #bottom_dist /= bottom_dist.max() 246 | ##bottom_dist = cv2.resize(bottom_dist, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 247 | #cv2.imshow('bottom_dist', bottom_dist) 248 | # 249 | #left_dist = regression[batch_id, 3].data.cpu().numpy() 250 | #left_dist /= left_dist.max() 251 | ##left_dist = cv2.resize(left_dist, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 252 | #cv2.imshow('left_dist', left_dist) 253 | # 254 | # 255 | ##angle = thetas[batch_id].data.cpu().numpy() 256 | #angle = prediction[2][batch_id].squeeze().data.cpu().numpy() 257 | ##angle /= angle.max() 258 | ##left_dist = cv2.resize(left_dist, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_AREA) 259 | #cv2.imshow('angle', (angle * cls_bin / np.pi * 180).astype(np.uint8)) 260 | cv2.waitKey() 261 | 262 | # show_tensors(cropped, classification, regression, thetas, training_mask, file_names) 263 | 264 | # show_tensors(cropped, *prediction, training_mask, file_names) 265 | 266 | loss = loss_func(prediction, (classification, regression, thetas, training_mask)) / max_batches_per_iter_cnt 267 | train_loss_stats += loss.item() 268 | loss.backward() 269 | batch_per_iter_cnt += 1 270 | if batch_per_iter_cnt == max_batches_per_iter_cnt: 271 | opt.step() 272 | batch_per_iter_cnt = 0 273 | loss_count_stats += 1 274 | mean_loss = train_loss_stats / loss_count_stats 275 | pbar.set_postfix({'Mean loss': f'{mean_loss:.5f}'}, refresh=False) 276 | lr_scheduler.step(mean_loss, epoch) 277 | 278 | if valid_dl is None: 279 | val_loss = train_loss_stats / loss_count_stats 280 | else: 281 | model.eval() 282 | with torch.no_grad(): 283 | val_loss = 0.0 284 | val_loss_count = 0 285 | for cropped, classification, regression, thetas,
training_mask in valid_dl: 286 | prediction = model(cropped.to('cuda')) 287 | loss = loss_func(prediction, (classification, regression, thetas, training_mask)) 288 | val_loss += loss.item() 289 | val_loss_count += len(cropped) 290 | val_loss /= val_loss_count 291 | # print('Val loss: ', val_loss) 292 | 293 | if best_score > val_loss: 294 | best_score = val_loss 295 | save_as_best = True 296 | else: 297 | save_as_best = False 298 | save_checkpoint(epoch, model, opt, lr_scheduler, best_score, checkpoint_dir, save_as_best) 299 | 300 | 301 | if __name__ == '__main__': 302 | parser = argparse.ArgumentParser() 303 | parser.add_argument('--train-folder', type=str, required=True, help='Path to folder with train images and labels') 304 | parser.add_argument('--batch-size', type=int, default=21, help='Batch size') 305 | parser.add_argument('--batches-before-train', type=int, default=2, help='Number of batches to process before train step') 306 | parser.add_argument('--num-workers', type=int, default=8, help='Number of data loader workers') 307 | parser.add_argument('--continue-training', action='store_true', help='continue training') 308 | args = parser.parse_args() 309 | 310 | data_set = datasets.SynthText(args.train_folder, datasets.transform) 311 | # data_set = datasets.ICDAR2015(args.train_folder, datasets.transform) 312 | 313 | # SynthText and ICDAR2015 have different layouts. One will probably need to provide two different paths to train 314 | # on a concatenation of these two data sets. But the paper doesn't concat them, so neither do I 315 | # data_set = torch.utils.data.ConcatDataset((synth, icdar)) 316 | 317 | dl = torch.utils.data.DataLoader(data_set, batch_size=args.batch_size, shuffle=True, 318 | sampler=None, batch_sampler=None, num_workers=args.num_workers) 319 | checkpoint_dir = 'runs' 320 | epoch, model, optimizer, lr_scheduler, best_score = restore_checkpoint(checkpoint_dir, args.continue_training) 321 | model = torch.nn.DataParallel(model) 322 | fit(epoch, model, detection_loss, optimizer, lr_scheduler, best_score, args.batches_before_train, checkpoint_dir, dl, None) 323 | --------------------------------------------------------------------------------
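A note on the geometry loss in [`train.py`](./train.py): `detection_loss` scores the four predicted per-pixel edge distances with a closed-form IoU between two boxes described relative to the same pixel, plus a cosine penalty on the angle. The sketch below restates only that computation on plain tensors; the function name `geo_iou_angle_loss` and the standalone framing are illustrative, not part of the repository.
```python
import torch

def geo_iou_angle_loss(y_true_geo, y_pred_geo, theta_gt, theta_pred):
    """Per-pixel -log(IoU) + angle term, as computed inside train.py's detection_loss.

    y_true_geo, y_pred_geo: (N, 4, H, W) distances from each pixel to the four
    box edges; theta_gt, theta_pred: (N, 1, H, W) angles.
    """
    d1_gt, d2_gt, d3_gt, d4_gt = torch.split(y_true_geo, 1, 1)
    d1_pred, d2_pred, d3_pred, d4_pred = torch.split(y_pred_geo, 1, 1)
    # box areas from the sums of opposite-edge distances
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    # overlap of two boxes anchored at the same pixel
    w_intersect = torch.min(d2_gt, d2_pred) + torch.min(d4_gt, d4_pred)
    h_intersect = torch.min(d1_gt, d1_pred) + torch.min(d3_gt, d3_pred)
    area_intersect = w_intersect * h_intersect
    area_union = area_gt + area_pred - area_intersect
    # the +1 smooths the ratio so degenerate boxes don't produce log(0)
    return -torch.log((area_intersect + 1) / (area_union + 1)) \
        + 10 * (1 - torch.cos(theta_pred - theta_gt))
```
In `detection_loss` this per-pixel tensor is then masked by OHEM before averaging, so only the selected hard and random positive pixels contribute to the regression term.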
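Similarly, a minimal usage sketch for `modules/nms.py`, with two made-up overlapping quads: `nms_locality` takes an N*9 array (eight quad coordinates followed by a confidence score), first merges consecutive overlapping quads by confidence-weighted averaging, and then applies standard NMS.
```python
import numpy as np
from modules.nms import nms_locality

# two nearly identical quads: x1, y1, x2, y2, x3, y3, x4, y4, confidence
polys = np.array([
    [10, 10, 100, 10, 100, 40, 10, 40, 0.9],
    [12, 11, 101, 11, 101, 42, 12, 42, 0.8],
], dtype=np.float64)

kept = nms_locality(polys, thres=0.3)
print(kept)  # one confidence-weighted quad; column 8 holds the summed scores
```
Note that `weighted_merge` accumulates confidences in column 8, so the score of a merged quad can exceed 1; `test.py` only writes the eight coordinates afterwards, so this does not affect the saved results.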