├── .gitignore
├── BASNet
│   ├── basnet_test.py
│   ├── data_loader.py
│   ├── model
│   │   ├── BASNet.py
│   │   ├── __init__.py
│   │   └── resnet_model.py
│   ├── pytorch_iou
│   │   └── __init__.py
│   ├── pytorch_ssim
│   │   └── __init__.py
│   └── utils.py
├── LICENSE
├── README.md
├── ShuffleNetV2.py
├── SmartText_demo.py
├── cal_color.py
├── figures
│   └── example.png
├── generate_candidates.py
├── make_all.sh
├── mobilenetv2.py
├── option.py
├── rod_align
│   ├── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── rod_align.py
│   ├── make.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── rod_align.py
│   ├── setup.py
│   └── src
│       ├── rod_align.cpp
│       ├── rod_align.h
│       ├── rod_align_cuda.cpp
│       ├── rod_align_cuda.h
│       ├── rod_align_kernel.cu
│       └── rod_align_kernel.h
├── roi_align
│   ├── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── roi_align.py
│   ├── make.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── roi_align.py
│   ├── setup.py
│   └── src
│       ├── roi_align.cpp
│       ├── roi_align.h
│       ├── roi_align_cuda.cpp
│       ├── roi_align_cuda.h
│       ├── roi_align_kernel.cu
│       ├── roi_align_kernel.cu.o
│       └── roi_align_kernel.h
├── smtDataset.py
├── smtModel.py
├── test_data
│   ├── Fonts
│   │   └── verdanab.ttf
│   └── SMT
│       ├── 00001.jpg
│       ├── 00002.png
│       ├── 00003.jpg
│       ├── 00004.jpg
│       ├── 00005.jpg
│       └── 00006.jpg
├── test_opt.yml
└── util_cal.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # user define ----------------------------- 132 | # dataset 133 | dataset/ 134 | test_data/test_modi 135 | test_result/ 136 | 137 | # logs 138 | logs*/ 139 | runs*/ 140 | tb_logger/ 141 | 142 | # model 143 | training/ 144 | experiments/ 145 | pretrained/ 146 | 147 | .vscode/ 148 | -------------------------------------------------------------------------------- /BASNet/basnet_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from .data_loader import RescaleT 12 | from .data_loader import ToTensorLab 13 | from .data_loader import SalObjDataset 14 | 15 | from . import utils 16 | import cv2 17 | 18 | 19 | def normPRED(d): 20 | ma = torch.max(d) 21 | mi = torch.min(d) 22 | dn = (d - mi) / (ma - mi) 23 | return dn 24 | 25 | 26 | def save_output(image_name, pred, d_dir, d_dir_ovl): 27 | # overlay the importance map on the input image 28 | predict = pred.squeeze() 29 | predict_np = predict.cpu().data.numpy() 30 | 31 | im = Image.fromarray(predict_np * 255).convert('RGB') 32 | img_name = image_name.split("/")[-1] 33 | image = io.imread(image_name) 34 | imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR) 35 | pb_np = np.array(imo) 36 | pb_np = pb_np[:, :, :1].squeeze() 37 | 38 | aaa = img_name.split(".") 39 | bbb = aaa[0:-1] 40 | imidx = bbb[0] 41 | for i in range(1, len(bbb)): 42 | imidx = imidx + "." 
+ bbb[i] 43 | 44 | # save the importance map 45 | pred_fp = d_dir + imidx + '.png' 46 | imo.save(pred_fp) 47 | 48 | img_ini = Image.open(image_name).convert('RGB') 49 | img_imp = cv2.imread(d_dir + imidx + '.png') 50 | img_imp = img_imp[:, :, :1] 51 | img_imp = img_imp.squeeze() 52 | fname = os.path.join(d_dir_ovl + imidx + '_ovl' + '.png') 53 | utils.overlay_imp_on_img(img_ini, img_imp, fname, colormap='jet') 54 | 55 | return pred_fp, pb_np 56 | 57 | 58 | def get_imp(img_name, prediction_dir, prediction_dir_ovl, visimp_model): 59 | 60 | # predict the importance map 61 | os.makedirs(prediction_dir, exist_ok=True) 62 | os.makedirs(prediction_dir_ovl, exist_ok=True) 63 | img_name_list = [img_name] 64 | 65 | # --------- dataloader ----------- 66 | test_salobj_dataset = SalObjDataset(img_name_list=img_name_list, 67 | lbl_name_list=[], 68 | transform=transforms.Compose([RescaleT(256), ToTensorLab(flag=0)])) 69 | test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers=0) 70 | 71 | net = visimp_model 72 | 73 | # --------- inference for each image --------- 74 | for i_test, data_test in enumerate(test_salobj_dataloader): 75 | print("inferencing:", img_name_list[i_test].split("/")[-1]) 76 | 77 | inputs_test = data_test['image'] 78 | inputs_test = inputs_test.type(torch.FloatTensor) 79 | 80 | if torch.cuda.is_available(): 81 | inputs_test = Variable(inputs_test.cuda()) 82 | else: 83 | inputs_test = Variable(inputs_test) 84 | 85 | d1, _, _, _, _, _, _, _ = net(inputs_test) 86 | 87 | # normalization 88 | pred = d1[:, 0, :, :] 89 | pred = normPRED(pred) 90 | 91 | # save results to prediction_dir folder 92 | pred_name, pred_np = save_output(img_name_list[i_test], pred, prediction_dir, prediction_dir_ovl) 93 | 94 | return pred_name, pred_np 95 | -------------------------------------------------------------------------------- /BASNet/data_loader.py: -------------------------------------------------------------------------------- 1 | # data loader 2 | from __future__ import print_function, division 3 | import glob 4 | import torch 5 | from skimage import io, transform, color 6 | import numpy as np 7 | import math 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import transforms, utils 11 | from PIL import Image 12 | #==========================dataset load========================== 13 | 14 | 15 | class RescaleT(object): 16 | 17 | def __init__(self, output_size): 18 | assert isinstance(output_size, (int, tuple)) 19 | self.output_size = output_size 20 | 21 | def __call__(self, sample): 22 | image, label = sample['image'], sample['label'] 23 | 24 | h, w = image.shape[:2] 25 | 26 | if isinstance(self.output_size, int): 27 | if h > w: 28 | new_h, new_w = self.output_size * h / w, self.output_size 29 | else: 30 | new_h, new_w = self.output_size, self.output_size * w / h 31 | else: 32 | new_h, new_w = self.output_size 33 | 34 | new_h, new_w = int(new_h), int(new_w) 35 | 36 | # resize the image to new_h x new_w and convert image from range [0,255] to [0,1] 37 | # img = transform.resize(image,(new_h,new_w),mode='constant') 38 | # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True) 39 | 40 | img = transform.resize(image, (self.output_size, self.output_size), mode='constant') 41 | lbl = transform.resize(label, (self.output_size, self.output_size), 42 | mode='constant', 43 | order=0, 44 | preserve_range=True) 45 | 46 | return {'image': img, 'label': lbl} 47 | 48 | 49 | class 
Rescale(object): 50 | 51 | def __init__(self, output_size): 52 | assert isinstance(output_size, (int, tuple)) 53 | self.output_size = output_size 54 | 55 | def __call__(self, sample): 56 | image, label = sample['image'], sample['label'] 57 | 58 | h, w = image.shape[:2] 59 | 60 | if isinstance(self.output_size, int): 61 | if h > w: 62 | new_h, new_w = self.output_size * h / w, self.output_size 63 | else: 64 | new_h, new_w = self.output_size, self.output_size * w / h 65 | else: 66 | new_h, new_w = self.output_size 67 | 68 | new_h, new_w = int(new_h), int(new_w) 69 | 70 | # resize the image to new_h x new_w and convert image from range [0,255] to [0,1] 71 | img = transform.resize(image, (new_h, new_w), mode='constant') 72 | lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True) 73 | 74 | return {'image': img, 'label': lbl} 75 | 76 | 77 | class CenterCrop(object): 78 | 79 | def __init__(self, output_size): 80 | assert isinstance(output_size, (int, tuple)) 81 | if isinstance(output_size, int): 82 | self.output_size = (output_size, output_size) 83 | else: 84 | assert len(output_size) == 2 85 | self.output_size = output_size 86 | 87 | def __call__(self, sample): 88 | image, label = sample['image'], sample['label'] 89 | 90 | h, w = image.shape[:2] 91 | new_h, new_w = self.output_size 92 | 93 | # print("h: %d, w: %d, new_h: %d, new_w: %d"%(h, w, new_h, new_w)) 94 | assert ((h >= new_h) and (w >= new_w)) 95 | 96 | h_offset = int(math.floor((h - new_h) / 2)) 97 | w_offset = int(math.floor((w - new_w) / 2)) 98 | 99 | image = image[h_offset:h_offset + new_h, w_offset:w_offset + new_w] 100 | label = label[h_offset:h_offset + new_h, w_offset:w_offset + new_w] 101 | 102 | return {'image': image, 'label': label} 103 | 104 | 105 | class RandomCrop(object): 106 | 107 | def __init__(self, output_size): 108 | assert isinstance(output_size, (int, tuple)) 109 | if isinstance(output_size, int): 110 | self.output_size = (output_size, output_size) 111 | else: 112 | assert len(output_size) == 2 113 | self.output_size = output_size 114 | 115 | def __call__(self, sample): 116 | image, label = sample['image'], sample['label'] 117 | 118 | h, w = image.shape[:2] 119 | new_h, new_w = self.output_size 120 | 121 | top = np.random.randint(0, h - new_h) 122 | left = np.random.randint(0, w - new_w) 123 | 124 | image = image[top:top + new_h, left:left + new_w] 125 | label = label[top:top + new_h, left:left + new_w] 126 | 127 | return {'image': image, 'label': label} 128 | 129 | 130 | class ToTensor(object): 131 | """Convert ndarrays in sample to Tensors.""" 132 | 133 | def __call__(self, sample): 134 | 135 | image, label = sample['image'], sample['label'] 136 | 137 | tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) 138 | tmpLbl = np.zeros(label.shape) 139 | 140 | image = image / np.max(image) 141 | if (np.max(label) < 1e-6): 142 | label = label 143 | else: 144 | label = label / np.max(label) 145 | 146 | if image.shape[2] == 1: 147 | tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 148 | tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 149 | tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 150 | else: 151 | tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 152 | tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 153 | tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 154 | 155 | tmpLbl[:, :, 0] = label[:, :, 0] 156 | 157 | # change the r,g,b to b,r,g from [0,255] to [0,1] 158 | # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) 159 | tmpImg = 
tmpImg.transpose((2, 0, 1)) 160 | tmpLbl = label.transpose((2, 0, 1)) 161 | 162 | return {'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)} 163 | 164 | 165 | class ToTensorLab(object): 166 | """Convert ndarrays in sample to Tensors.""" 167 | 168 | def __init__(self, flag=0): 169 | self.flag = flag 170 | 171 | def __call__(self, sample): 172 | 173 | image, label = sample['image'], sample['label'] 174 | 175 | tmpLbl = np.zeros(label.shape) 176 | 177 | if (np.max(label) < 1e-6): 178 | label = label 179 | else: 180 | label = label / np.max(label) 181 | 182 | # change the color space 183 | if self.flag == 2: # with rgb and Lab colors 184 | tmpImg = np.zeros((image.shape[0], image.shape[1], 6)) 185 | tmpImgt = np.zeros((image.shape[0], image.shape[1], 3)) 186 | if image.shape[2] == 1: 187 | tmpImgt[:, :, 0] = image[:, :, 0] 188 | tmpImgt[:, :, 1] = image[:, :, 0] 189 | tmpImgt[:, :, 2] = image[:, :, 0] 190 | else: 191 | tmpImgt = image 192 | tmpImgtl = color.rgb2lab(tmpImgt) 193 | 194 | # nomalize image to range [0,1] 195 | tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (np.max(tmpImgt[:, :, 0]) - 196 | np.min(tmpImgt[:, :, 0])) 197 | tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (np.max(tmpImgt[:, :, 1]) - 198 | np.min(tmpImgt[:, :, 1])) 199 | tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (np.max(tmpImgt[:, :, 2]) - 200 | np.min(tmpImgt[:, :, 2])) 201 | tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (np.max(tmpImgtl[:, :, 0]) - 202 | np.min(tmpImgtl[:, :, 0])) 203 | tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (np.max(tmpImgtl[:, :, 1]) - 204 | np.min(tmpImgtl[:, :, 1])) 205 | tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (np.max(tmpImgtl[:, :, 2]) - 206 | np.min(tmpImgtl[:, :, 2])) 207 | 208 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) 209 | 210 | tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0]) 211 | tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1]) 212 | tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2]) 213 | tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3]) 214 | tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4]) 215 | tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5]) 216 | 217 | elif self.flag == 1: # with Lab color 218 | tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) 219 | 220 | if image.shape[2] == 1: 221 | tmpImg[:, :, 0] = image[:, :, 0] 222 | tmpImg[:, :, 1] = image[:, :, 0] 223 | tmpImg[:, :, 2] = image[:, :, 0] 224 | else: 225 | tmpImg = image 226 | 227 | tmpImg = color.rgb2lab(tmpImg) 228 | 229 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) 230 | 231 | tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (np.max(tmpImg[:, :, 0]) - 232 | np.min(tmpImg[:, :, 0])) 233 | tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (np.max(tmpImg[:, :, 1]) - 234 | np.min(tmpImg[:, :, 1])) 235 | tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (np.max(tmpImg[:, :, 2]) - 236 | np.min(tmpImg[:, :, 2])) 237 | 238 | tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0]) 239 | tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1]) 240 | tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / 
np.std(tmpImg[:, :, 2]) 241 | 242 | else: # with rgb color 243 | tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) 244 | image = image / np.max(image) 245 | if image.shape[2] == 1: 246 | tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 247 | tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 248 | tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 249 | else: 250 | tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 251 | tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 252 | tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 253 | 254 | tmpLbl[:, :, 0] = label[:, :, 0] 255 | 256 | # change the r,g,b to b,r,g from [0,255] to [0,1] 257 | # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) 258 | tmpImg = tmpImg.transpose((2, 0, 1)) 259 | tmpLbl = label.transpose((2, 0, 1)) 260 | 261 | return {'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)} 262 | 263 | 264 | class SalObjDataset(Dataset): 265 | 266 | def __init__(self, img_name_list, lbl_name_list, transform=None): 267 | # self.root_dir = root_dir 268 | # self.image_name_list = glob.glob(image_dir+'*.png') 269 | # self.label_name_list = glob.glob(label_dir+'*.png') 270 | self.image_name_list = img_name_list 271 | self.label_name_list = lbl_name_list 272 | self.transform = transform 273 | 274 | def __len__(self): 275 | return len(self.image_name_list) 276 | 277 | def __getitem__(self, idx): 278 | 279 | # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx]) 280 | # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx]) 281 | 282 | # image = io.imread(self.image_name_list[idx]) 283 | image = Image.open(self.image_name_list[idx]).convert('RGB') 284 | image = np.array(image, dtype=np.uint8) 285 | 286 | if (0 == len(self.label_name_list)): 287 | label_3 = np.zeros(image.shape) 288 | else: 289 | # label_3 = io.imread(self.label_name_list[idx]) 290 | label_3 = Image.open(self.label_name_list[idx]).convert('RGB') 291 | label_3 = np.array(label_3, dtype=np.uint8) 292 | 293 | label = np.zeros(label_3.shape[0:2]) 294 | if (3 == len(label_3.shape)): 295 | label = label_3[:, :, 0] 296 | elif (2 == len(label_3.shape)): 297 | label = label_3 298 | 299 | if (3 == len(image.shape) and 2 == len(label.shape)): 300 | label = label[:, :, np.newaxis] 301 | elif (2 == len(image.shape) and 2 == len(label.shape)): 302 | image = image[:, :, np.newaxis] 303 | label = label[:, :, np.newaxis] 304 | 305 | # #vertical flipping 306 | # # fliph = np.random.randn(1) 307 | # flipv = np.random.randn(1) 308 | # 309 | # if flipv>0: 310 | # image = image[::-1,:,:] 311 | # label = label[::-1,:,:] 312 | # #vertical flip 313 | 314 | sample = {'image': image, 'label': label} 315 | 316 | if self.transform: 317 | sample = self.transform(sample) 318 | 319 | return sample 320 | -------------------------------------------------------------------------------- /BASNet/model/BASNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | import torch.nn.functional as F 5 | 6 | from .resnet_model import * 7 | 8 | 9 | class RefUnet(nn.Module): 10 | 11 | def __init__(self, in_ch, inc_ch): 12 | super(RefUnet, self).__init__() 13 | 14 | self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1) 15 | 16 | self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1) 17 | self.bn1 = nn.BatchNorm2d(64) 18 | self.relu1 = nn.ReLU(inplace=True) 19 | 20 | self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True) 
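        # Each pool below halves the spatial resolution, so conv5 (the bottleneck)
        # sees a 16x-downsampled map; the decoder then upsamples with upscore2 and
        # concatenates the matching encoder features, which is why conv_d4..conv_d1
        # take 128 input channels (64 skip + 64 upsampled) down to 64.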
21 | 22 | self.conv2 = nn.Conv2d(64, 64, 3, padding=1) 23 | self.bn2 = nn.BatchNorm2d(64) 24 | self.relu2 = nn.ReLU(inplace=True) 25 | 26 | self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True) 27 | 28 | self.conv3 = nn.Conv2d(64, 64, 3, padding=1) 29 | self.bn3 = nn.BatchNorm2d(64) 30 | self.relu3 = nn.ReLU(inplace=True) 31 | 32 | self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True) 33 | 34 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1) 35 | self.bn4 = nn.BatchNorm2d(64) 36 | self.relu4 = nn.ReLU(inplace=True) 37 | 38 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 39 | 40 | ##### 41 | 42 | self.conv5 = nn.Conv2d(64, 64, 3, padding=1) 43 | self.bn5 = nn.BatchNorm2d(64) 44 | self.relu5 = nn.ReLU(inplace=True) 45 | 46 | ##### 47 | 48 | self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1) 49 | self.bn_d4 = nn.BatchNorm2d(64) 50 | self.relu_d4 = nn.ReLU(inplace=True) 51 | 52 | self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1) 53 | self.bn_d3 = nn.BatchNorm2d(64) 54 | self.relu_d3 = nn.ReLU(inplace=True) 55 | 56 | self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1) 57 | self.bn_d2 = nn.BatchNorm2d(64) 58 | self.relu_d2 = nn.ReLU(inplace=True) 59 | 60 | self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1) 61 | self.bn_d1 = nn.BatchNorm2d(64) 62 | self.relu_d1 = nn.ReLU(inplace=True) 63 | 64 | self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1) 65 | 66 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 67 | # self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') 68 | 69 | def forward(self, x): 70 | 71 | hx = x 72 | hx = self.conv0(hx) 73 | 74 | hx1 = self.relu1(self.bn1(self.conv1(hx))) 75 | hx = self.pool1(hx1) 76 | 77 | hx2 = self.relu2(self.bn2(self.conv2(hx))) 78 | hx = self.pool2(hx2) 79 | 80 | hx3 = self.relu3(self.bn3(self.conv3(hx))) 81 | hx = self.pool3(hx3) 82 | 83 | hx4 = self.relu4(self.bn4(self.conv4(hx))) 84 | hx = self.pool4(hx4) 85 | 86 | hx5 = self.relu5(self.bn5(self.conv5(hx))) 87 | 88 | hx = self.upscore2(hx5) 89 | 90 | d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1)))) 91 | hx = self.upscore2(d4) 92 | 93 | d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1)))) 94 | hx = self.upscore2(d3) 95 | 96 | d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1)))) 97 | hx = self.upscore2(d2) 98 | 99 | d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1)))) 100 | 101 | residual = self.conv_d0(d1) 102 | 103 | return x + residual 104 | 105 | 106 | class BASNet(nn.Module): 107 | 108 | def __init__(self, n_channels, n_classes): 109 | super(BASNet, self).__init__() 110 | 111 | resnet = models.resnet34(pretrained=True) 112 | 113 | ## -------------Encoder-------------- 114 | 115 | self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1) 116 | self.inbn = nn.BatchNorm2d(64) 117 | self.inrelu = nn.ReLU(inplace=True) 118 | 119 | #stage 1 120 | self.encoder1 = resnet.layer1 #256 121 | #stage 2 122 | self.encoder2 = resnet.layer2 #128 123 | #stage 3 124 | self.encoder3 = resnet.layer3 #64 125 | #stage 4 126 | self.encoder4 = resnet.layer4 #32 127 | 128 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 129 | 130 | #stage 5 131 | self.resb5_1 = BasicBlock(512, 512) 132 | self.resb5_2 = BasicBlock(512, 512) 133 | self.resb5_3 = BasicBlock(512, 512) #16 134 | 135 | self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True) 136 | 137 | #stage 6 138 | self.resb6_1 = BasicBlock(512, 512) 139 | self.resb6_2 = BasicBlock(512, 512) 140 | self.resb6_3 = BasicBlock(512, 512) #8 141 | 142 | ## -------------Bridge-------------- 143 | 144 | #stage Bridge 145 | 
self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) # 8 146 | self.bnbg_1 = nn.BatchNorm2d(512) 147 | self.relubg_1 = nn.ReLU(inplace=True) 148 | self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2) 149 | self.bnbg_m = nn.BatchNorm2d(512) 150 | self.relubg_m = nn.ReLU(inplace=True) 151 | self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) 152 | self.bnbg_2 = nn.BatchNorm2d(512) 153 | self.relubg_2 = nn.ReLU(inplace=True) 154 | 155 | ## -------------Decoder-------------- 156 | 157 | #stage 6d 158 | self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 16 159 | self.bn6d_1 = nn.BatchNorm2d(512) 160 | self.relu6d_1 = nn.ReLU(inplace=True) 161 | 162 | self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2) ### 163 | self.bn6d_m = nn.BatchNorm2d(512) 164 | self.relu6d_m = nn.ReLU(inplace=True) 165 | 166 | self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) 167 | self.bn6d_2 = nn.BatchNorm2d(512) 168 | self.relu6d_2 = nn.ReLU(inplace=True) 169 | 170 | #stage 5d 171 | self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 16 172 | self.bn5d_1 = nn.BatchNorm2d(512) 173 | self.relu5d_1 = nn.ReLU(inplace=True) 174 | 175 | self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1) ### 176 | self.bn5d_m = nn.BatchNorm2d(512) 177 | self.relu5d_m = nn.ReLU(inplace=True) 178 | 179 | self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1) 180 | self.bn5d_2 = nn.BatchNorm2d(512) 181 | self.relu5d_2 = nn.ReLU(inplace=True) 182 | 183 | #stage 4d 184 | self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 32 185 | self.bn4d_1 = nn.BatchNorm2d(512) 186 | self.relu4d_1 = nn.ReLU(inplace=True) 187 | 188 | self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1) ### 189 | self.bn4d_m = nn.BatchNorm2d(512) 190 | self.relu4d_m = nn.ReLU(inplace=True) 191 | 192 | self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1) 193 | self.bn4d_2 = nn.BatchNorm2d(256) 194 | self.relu4d_2 = nn.ReLU(inplace=True) 195 | 196 | #stage 3d 197 | self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1) # 64 198 | self.bn3d_1 = nn.BatchNorm2d(256) 199 | self.relu3d_1 = nn.ReLU(inplace=True) 200 | 201 | self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1) ### 202 | self.bn3d_m = nn.BatchNorm2d(256) 203 | self.relu3d_m = nn.ReLU(inplace=True) 204 | 205 | self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1) 206 | self.bn3d_2 = nn.BatchNorm2d(128) 207 | self.relu3d_2 = nn.ReLU(inplace=True) 208 | 209 | #stage 2d 210 | 211 | self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1) # 128 212 | self.bn2d_1 = nn.BatchNorm2d(128) 213 | self.relu2d_1 = nn.ReLU(inplace=True) 214 | 215 | self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1) ### 216 | self.bn2d_m = nn.BatchNorm2d(128) 217 | self.relu2d_m = nn.ReLU(inplace=True) 218 | 219 | self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1) 220 | self.bn2d_2 = nn.BatchNorm2d(64) 221 | self.relu2d_2 = nn.ReLU(inplace=True) 222 | 223 | #stage 1d 224 | self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1) # 256 225 | self.bn1d_1 = nn.BatchNorm2d(64) 226 | self.relu1d_1 = nn.ReLU(inplace=True) 227 | 228 | self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1) ### 229 | self.bn1d_m = nn.BatchNorm2d(64) 230 | self.relu1d_m = nn.ReLU(inplace=True) 231 | 232 | self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1) 233 | self.bn1d_2 = nn.BatchNorm2d(64) 234 | self.relu1d_2 = nn.ReLU(inplace=True) 235 | 236 | ## -------------Bilinear Upsampling-------------- 237 | self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False) ### 238 | self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False) 239 
|         self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
240 |         self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
241 |         self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
242 | 
243 |         # self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear') ###
244 |         # self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear')
245 |         # self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear')
246 |         # self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear')
247 |         # self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
248 | 
249 |         ## -------------Side Output--------------
250 |         self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
251 |         self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
252 |         self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
253 |         self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
254 |         self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
255 |         self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
256 |         self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)
257 | 
258 |         ## -------------Refine Module-------------
259 |         self.refunet = RefUnet(1, 64)
260 | 
261 |     def forward(self, x):
262 | 
263 |         hx = x
264 | 
265 |         ## -------------Encoder-------------
266 |         hx = self.inconv(hx)
267 |         hx = self.inbn(hx)
268 |         hx = self.inrelu(hx)
269 | 
270 |         h1 = self.encoder1(hx)  # 256
271 |         h2 = self.encoder2(h1)  # 128
272 |         h3 = self.encoder3(h2)  # 64
273 |         h4 = self.encoder4(h3)  # 32
274 | 
275 |         hx = self.pool4(h4)  # 16
276 | 
277 |         hx = self.resb5_1(hx)
278 |         hx = self.resb5_2(hx)
279 |         h5 = self.resb5_3(hx)
280 | 
281 |         hx = self.pool5(h5)  # 8
282 | 
283 |         hx = self.resb6_1(hx)
284 |         hx = self.resb6_2(hx)
285 |         h6 = self.resb6_3(hx)
286 | 
287 |         ## -------------Bridge-------------
288 |         hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
289 |         hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
290 |         hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
291 | 
292 |         ## -------------Decoder-------------
293 | 
294 |         hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
295 |         hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
296 |         hd6 = self.relu6d_2(self.bn5d_2(self.conv6d_2(hx)))  # NOTE: bn5d_2 (not bn6d_2) looks like a typo, but it matches the original BASNet release and is kept for pretrained-weight compatibility
297 | 
298 |         hx = self.upscore2(hd6)  # 8 -> 16
299 | 
300 |         hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
301 |         hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
302 |         hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
303 | 
304 |         hx = self.upscore2(hd5)  # 16 -> 32
305 | 
306 |         hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
307 |         hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
308 |         hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
309 | 
310 |         hx = self.upscore2(hd4)  # 32 -> 64
311 | 
312 |         hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
313 |         hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
314 |         hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
315 | 
316 |         hx = self.upscore2(hd3)  # 64 -> 128
317 | 
318 |         hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
319 |         hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
320 |         hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
321 | 
322 |         hx = self.upscore2(hd2)  # 128 -> 256
323 | 
324 |         hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
325 |         hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
326 |         hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
327 | 
328 |         ## -------------Side Output-------------
329 |         db = self.outconvb(hbg)
330 |         db = self.upscore6(db)  # 8->256
331 | 
332 |         d6 = 
self.outconv6(hd6) 333 | d6 = self.upscore6(d6) # 8->256 334 | 335 | d5 = self.outconv5(hd5) 336 | d5 = self.upscore5(d5) # 16->256 337 | 338 | d4 = self.outconv4(hd4) 339 | d4 = self.upscore4(d4) # 32->256 340 | 341 | d3 = self.outconv3(hd3) 342 | d3 = self.upscore3(d3) # 64->256 343 | 344 | d2 = self.outconv2(hd2) 345 | d2 = self.upscore2(d2) # 128->256 346 | 347 | d1 = self.outconv1(hd1) # 256 348 | 349 | ## -------------Refine Module------------- 350 | dout = self.refunet(d1) # 256 351 | 352 | return torch.sigmoid(dout), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid( 353 | d4), torch.sigmoid(d5), torch.sigmoid(d6), torch.sigmoid(db) 354 | -------------------------------------------------------------------------------- /BASNet/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .BASNet import BASNet 2 | -------------------------------------------------------------------------------- /BASNet/model/resnet_model.py: -------------------------------------------------------------------------------- 1 | ## code from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import torchvision 7 | 8 | # __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | # 'resnet152', 'ResNet34P','ResNet50S','ResNet50P','ResNet101P'] 10 | # 11 | # resnet18_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet18-5c106cde.pth' 12 | # resnet34_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet34-333f7ec4.pth' 13 | # resnet50_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet50-19c8e357.pth' 14 | # resnet101_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet101-5d3b4d8f.pth' 15 | # 16 | # model_urls = { 17 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 18 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 20 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 21 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 22 | # } 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | "3x3 convolution with padding" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | class BasicBlockDe(nn.Module): 61 | expansion = 1 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(BasicBlockDe, self).__init__() 65 | 66 | self.convRes = conv3x3(inplanes,planes,stride) 67 | self.bnRes = nn.BatchNorm2d(planes) 68 | self.reluRes = 
nn.ReLU(inplace=True) 69 | 70 | self.conv1 = conv3x3(inplanes, planes, stride) 71 | self.bn1 = nn.BatchNorm2d(planes) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.conv2 = conv3x3(planes, planes) 74 | self.bn2 = nn.BatchNorm2d(planes) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = self.convRes(x) 80 | residual = self.bnRes(residual) 81 | residual = self.reluRes(residual) 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class Bottleneck(nn.Module): 100 | expansion = 4 101 | 102 | def __init__(self, inplanes, planes, stride=1, downsample=None): 103 | super(Bottleneck, self).__init__() 104 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 105 | self.bn1 = nn.BatchNorm2d(planes) 106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 107 | padding=1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(planes) 109 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 110 | self.bn3 = nn.BatchNorm2d(planes * 4) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.downsample = downsample 113 | self.stride = stride 114 | 115 | def forward(self, x): 116 | residual = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | 129 | if self.downsample is not None: 130 | residual = self.downsample(x) 131 | 132 | out += residual 133 | out = self.relu(out) 134 | 135 | return out 136 | -------------------------------------------------------------------------------- /BASNet/pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def _iou(pred, target, size_average = True): 7 | 8 | b = pred.shape[0] 9 | IoU = 0.0 10 | for i in range(0,b): 11 | #compute the IoU of the foreground 12 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 13 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 14 | IoU1 = Iand1/Ior1 15 | 16 | #IoU loss is (1-IoU1) 17 | IoU = IoU + (1-IoU1) 18 | 19 | return IoU/b 20 | 21 | class IOU(torch.nn.Module): 22 | def __init__(self, size_average = True): 23 | super(IOU, self).__init__() 24 | self.size_average = size_average 25 | 26 | def forward(self, pred, target): 27 | 28 | return _iou(pred, target, self.size_average) 29 | -------------------------------------------------------------------------------- /BASNet/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = 
_1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 69 | 70 | mu1_sq = mu1.pow(2) 71 | mu2_sq = mu2.pow(2) 72 | mu1_mu2 = mu1*mu2 73 | 74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 77 | 78 | C1 = 0.01**2 79 | C2 = 0.03**2 80 | 81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 83 | ssim_map = -torch.log(ssim_map + 1e-8) 84 | 85 | if size_average: 86 | return ssim_map.mean() 87 | else: 88 | return ssim_map.mean(1).mean(1).mean(1) 89 | 90 | class LOGSSIM(torch.nn.Module): 91 | def __init__(self, window_size = 11, size_average = True): 92 | super(LOGSSIM, self).__init__() 93 | self.window_size = window_size 94 | self.size_average = size_average 95 | self.channel = 1 96 | self.window = create_window(window_size, self.channel) 97 | 98 | def forward(self, img1, img2): 99 | (_, channel, _, _) = img1.size() 100 | 101 | if channel == self.channel and self.window.data.type() == img1.data.type(): 102 | window = self.window 103 | else: 104 | window = create_window(self.window_size, channel) 105 | 106 | if img1.is_cuda: 107 | window = window.cuda(img1.get_device()) 108 | window = window.type_as(img1) 
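            # cache the rebuilt window and its channel count so later calls with
            # the same input shape can reuse them instead of recomputing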
109 | 
110 |             self.window = window
111 |             self.channel = channel
112 | 
113 | 
114 |         return _logssim(img1, img2, window, self.window_size, channel, self.size_average)
115 | 
116 | 
117 | def ssim(img1, img2, window_size = 11, size_average = True):
118 |     (_, channel, _, _) = img1.size()
119 |     window = create_window(window_size, channel)
120 | 
121 |     if img1.is_cuda:
122 |         window = window.cuda(img1.get_device())
123 |     window = window.type_as(img1)
124 | 
125 |     return _ssim(img1, img2, window, window_size, channel, size_average)
126 | 
--------------------------------------------------------------------------------
/BASNet/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats.mstats
3 | import matplotlib.pyplot as plt
4 | from PIL import Image
5 | import os
6 | 
7 | def create_dir(dir_name):
8 |     if not os.path.exists(dir_name):
9 |         os.makedirs(dir_name)
10 | 
11 | def corr2_coeff(A, B):
12 |     # Subtract the global mean of each input array from the array itself
13 |     A_mA = A - A.mean()
14 |     B_mB = B - B.mean()
15 | 
16 |     # Mean of squared deviations
17 |     ssA = (A_mA ** 2).mean()
18 |     ssB = (B_mB ** 2).mean()
19 | 
20 |     # Finally get corr coeff
21 |     coef = (A_mA * B_mB).mean() / np.sqrt(ssA * ssB)
22 |     return coef
23 | 
24 | def r2coef(gt, pred):
25 |     gt_mean = gt.mean()
26 |     r2coef = 1 - np.sum((gt - pred) ** 2) / np.sum((gt - gt_mean) ** 2)
27 |     return r2coef
28 | 
29 | def get_rmse(gt, pred):
30 |     return np.sqrt(np.mean((gt - pred) ** 2))
31 | 
32 | def get_kl(gt, pred, chance=1e-5):  # Kullback-Leibler divergence
33 |     kl = np.sum(gt * np.where(gt > chance, np.log(gt), 0) - gt * np.where(pred >= chance, np.log(pred), 0))
34 |     return kl
35 | 
36 | def get_spearmanr(gt, pred):
37 |     try:
38 |         return scipy.stats.spearmanr(gt, pred)[0]
39 |     except Exception:
40 |         return 0
41 | 
42 | 
43 | def label_accuracy(label_trues, label_preds):
44 |     """Returns accuracy score evaluation result."""
45 | 
46 |     gt = label_trues.astype(np.float64)
47 |     pred = label_preds.astype(np.float64)
48 |     gt_1d, pred_1d = gt.flatten() / gt.sum(), pred.flatten() / pred.sum()
49 |     gt_1d_01, pred_1d_01 = gt.flatten() / 255.0, pred.flatten() / 255.0
50 | 
51 |     cc = corr2_coeff(gt_1d, pred_1d)
52 |     kl = get_kl(gt_1d, pred_1d)
53 |     kl_01 = get_kl(gt_1d_01, pred_1d_01)
54 |     spearman = get_spearmanr(gt_1d, pred_1d)
55 |     r2 = r2coef(gt_1d, pred_1d)
56 |     rmse = get_rmse(gt_1d_01, pred_1d_01)
57 | 
58 |     return kl, kl_01, cc, rmse, r2, spearman
59 | 
60 | def overlay_imp_on_img(img, imp, fname, colormap='jet'):
61 | 
62 |     cm = plt.get_cmap(colormap)  # https://matplotlib.org/examples/color/colormaps_reference.html
63 |     img2 = np.array(img, dtype=np.uint8)
64 |     imp2 = np.array(imp, dtype=np.uint8)
65 |     imp3 = (cm(imp2)[:, :, :3] * 255).astype(np.uint8)
66 |     img3 = Image.fromarray(img2)
67 |     imp3 = Image.fromarray(imp3)
68 |     im_alpha = Image.blend(img3, imp3, 0.5)
69 |     im_alpha.save(fname)
70 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Max deGroot, Ellis Brown
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Harmonious Textual Layout Generation over Natural Images via Deep Aesthetics Learning
2 | 
3 | **Code for the paper [Harmonious Textual Layout Generation over Natural Images via Deep Aesthetics Learning](http://chenhui.li/documents/TextualLayout_TMM2022.pdf) (TMM 2021).**
4 | 
5 | ![](./figures/example.png)
6 | 
7 | ## Introduction
8 | Automatic typography is important because it helps designers avoid highly repetitive tasks and amateur users achieve high-quality **textual layout** designs. However, automatic typography involves many parameters and complicated aesthetic rules that need to be adjusted. In this paper, we propose an efficient deep aesthetics learning approach to generate harmonious textual layout over natural images, which can be decomposed into two stages: saliency-aware text region proposal and aesthetics-based textual layout selection. Our method incorporates both semantic features and visual perception principles.
9 | First, we propose a semantic **visual saliency detection** network combined with a text region proposal algorithm to generate candidate text anchors with various positions and sizes. Second, a discriminative **deep aesthetics scoring** model is developed to assess the aesthetic quality of the candidate textual layouts. The results demonstrate that our method can generate harmonious textual layouts in various actual scenarios with better performance.
10 | 
11 | ## Dependencies and Installation
12 | + Python 3
13 | + PyTorch >= 1.0
14 | 
15 | ## Notes of compilation
16 | 
17 | 1. For ```Python3``` users, before you start to build the source code and install the packages, please specify the architecture of your GPU card and the CUDA_HOME path in both ```./roi_align/make.sh``` and ```./rod_align/make.sh```.
18 | 2. Build and install by running:
19 | ```bash
20 | bash make_all.sh
21 | ```
22 | 
23 | ## Usage
24 | 1. Download the source code and the pretrained models: [gdi-basnet](https://drive.google.com/file/d/1dN_lqywxefd_R4Q93lZck0kEkfKo-wkj/view?usp=sharing) and [SMT](https://drive.google.com/file/d/1zKVA9IGkPtmRkm-2_m7qriaEwVXBuaGX/view?usp=sharing).
25 | 
26 | 2. Make sure your device is CUDA enabled. Build and install the source code of ```roi_align_api``` and ```rod_align_api```.
27 | 
28 | 3. Run SmartText_demo.py to test the pretrained model on your images.
29 | ```bash
30 | python SmartText_demo.py -opt test_opt.yml
31 | ```
32 | 
33 | ## Acknowledgement
34 | 
35 | This work is the extension of our [conference version](http://chenhui.li/documents/SmartText_ICME2020.pdf) (ICME 2020).
36 | Some codes of this repository benefit from [BASNet](https://github.com/xuebinqin/BASNet) and [GAIC](https://github.com/lld533/Grid-Anchor-based-Image-Cropping-Pytorch). Thanks for their excellent work! 37 | 38 | ## Citation 39 | 40 | If you find this work useful, please cite our paper: 41 | 42 | ``` 43 | @article{li2021harmonious, 44 | title = {Harmonious Textual Layout Generation over Natural Images via Deep Aesthetics Learning}, 45 | author = {Li, Chenhui and Zhang, Peiying and Wang, Changbo}, 46 | journal = {IEEE Transactions on Multimedia}, 47 | volume = {24}, 48 | pages = {3416--3428}, 49 | year = {2021}, 50 | publisher = {IEEE}, 51 | } 52 | ``` 53 | 54 | ## Contact 55 | 56 | If you have any question, contact us through email at zhangpeiying17@gmail.com. 57 | -------------------------------------------------------------------------------- /ShuffleNetV2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | from torch.nn import init 7 | import math 8 | 9 | def conv_bn(inp, oup, stride): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | 17 | def conv_1x1_bn(inp, oup): 18 | return nn.Sequential( 19 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 20 | nn.BatchNorm2d(oup), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def channel_shuffle(x, groups): 25 | batchsize, num_channels, height, width = x.data.size() 26 | 27 | channels_per_group = num_channels // groups 28 | 29 | # reshape 30 | x = x.view(batchsize, groups, 31 | channels_per_group, height, width) 32 | 33 | x = torch.transpose(x, 1, 2).contiguous() 34 | 35 | # flatten 36 | x = x.view(batchsize, -1, height, width) 37 | 38 | return x 39 | 40 | class InvertedResidual(nn.Module): 41 | def __init__(self, inp, oup, stride, benchmodel): 42 | super(InvertedResidual, self).__init__() 43 | self.benchmodel = benchmodel 44 | self.stride = stride 45 | assert stride in [1, 2] 46 | 47 | oup_inc = oup//2 48 | 49 | if self.benchmodel == 1: 50 | #assert inp == oup_inc 51 | self.banch2 = nn.Sequential( 52 | # pw 53 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 54 | nn.BatchNorm2d(oup_inc), 55 | nn.ReLU(inplace=True), 56 | # dw 57 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 58 | nn.BatchNorm2d(oup_inc), 59 | # pw-linear 60 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 61 | nn.BatchNorm2d(oup_inc), 62 | nn.ReLU(inplace=True), 63 | ) 64 | else: 65 | self.banch1 = nn.Sequential( 66 | # dw 67 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 68 | nn.BatchNorm2d(inp), 69 | # pw-linear 70 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 71 | nn.BatchNorm2d(oup_inc), 72 | nn.ReLU(inplace=True), 73 | ) 74 | 75 | self.banch2 = nn.Sequential( 76 | # pw 77 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 78 | nn.BatchNorm2d(oup_inc), 79 | nn.ReLU(inplace=True), 80 | # dw 81 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 82 | nn.BatchNorm2d(oup_inc), 83 | # pw-linear 84 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 85 | nn.BatchNorm2d(oup_inc), 86 | nn.ReLU(inplace=True), 87 | ) 88 | 89 | @staticmethod 90 | def _concat(x, out): 91 | # concatenate along channel axis 92 | return torch.cat((x, out), 1) 93 | 94 | def forward(self, x): 95 | if 1==self.benchmodel: 96 | x1 = x[:, :(x.shape[1]//2), :, :] 
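            # channel split: x1 (the first half) is kept as the identity branch,
            # x2 (the second half) runs through banch2; channel_shuffle then remixes the halves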
97 | x2 = x[:, (x.shape[1]//2):, :, :] 98 | out = self._concat(x1, self.banch2(x2)) 99 | elif 2==self.benchmodel: 100 | out = self._concat(self.banch1(x), self.banch2(x)) 101 | 102 | return channel_shuffle(out, 2) 103 | 104 | 105 | class ShuffleNetV2(nn.Module): 106 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 107 | super(ShuffleNetV2, self).__init__() 108 | 109 | assert input_size % 32 == 0 110 | 111 | self.stage_repeats = [4, 8, 4] 112 | # index 0 is invalid and should never be called. 113 | # only used for indexing convenience. 114 | if width_mult == 0.5: 115 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 116 | elif width_mult == 1.0: 117 | self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] 118 | elif width_mult == 1.5: 119 | self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] 120 | elif width_mult == 2.0: 121 | self.stage_out_channels = [-1, 24, 224, 488, 976, 2048] 122 | else: 123 | raise ValueError( 124 | """{} groups is not supported for 125 | 1x1 Grouped Convolutions""".format(num_groups)) 126 | 127 | # building first layer 128 | input_channel = self.stage_out_channels[1] 129 | self.conv1 = conv_bn(3, input_channel, 2) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | 132 | self.features = [] 133 | # building inverted residual blocks 134 | for idxstage in range(len(self.stage_repeats)): 135 | numrepeat = self.stage_repeats[idxstage] 136 | output_channel = self.stage_out_channels[idxstage+2] 137 | for i in range(numrepeat): 138 | if i == 0: 139 | #inp, oup, stride, benchmodel): 140 | self.features.append(InvertedResidual(input_channel, output_channel, 2, 2)) 141 | else: 142 | self.features.append(InvertedResidual(input_channel, output_channel, 1, 1)) 143 | input_channel = output_channel 144 | 145 | 146 | # make it nn.Sequential 147 | self.features = nn.Sequential(*self.features) 148 | 149 | # building last several layers 150 | self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1]) 151 | self.globalpool = nn.Sequential(nn.AvgPool2d(input_size//32)) 152 | 153 | # building classifier 154 | self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class)) 155 | 156 | def forward(self, x): 157 | x = self.conv1(x) 158 | x = self.maxpool(x) 159 | x = self.features(x) 160 | x = self.conv_last(x) 161 | x = self.globalpool(x) 162 | x = x.view(-1, self.stage_out_channels[-1]) 163 | x = self.classifier(x) 164 | return x 165 | 166 | def shufflenetv2(width_mult=1.): 167 | model = ShuffleNetV2(width_mult=width_mult) 168 | return model 169 | 170 | if __name__ == "__main__": 171 | """Testing 172 | """ 173 | model = ShuffleNetV2() 174 | print(model) 175 | -------------------------------------------------------------------------------- /SmartText_demo.py: -------------------------------------------------------------------------------- 1 | from smtModel import build_smt_model 2 | from smtDataset import setup_test_dataset 3 | import os 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.backends.cudnn as cudnn 7 | import torch.utils.data as data 8 | import argparse 9 | import time 10 | import math 11 | 12 | from PIL import Image, ImageDraw, ImageFont 13 | import numpy as np 14 | import random 15 | import json 16 | from datetime import date 17 | 18 | from BASNet.model import BASNet 19 | from cal_color import cal_best_color, RGB_to_Hex 20 | import option 21 | from option import sv_json 22 | 23 | import warnings 24 | import torch.multiprocessing 25 | 26 | 
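# DataLoader workers in this demo pass large samples between processes; the
# 'file_system' sharing strategy below avoids "too many open files" errors
# that can occur with the default file-descriptor strategy.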
torch.multiprocessing.set_sharing_strategy('file_system') 27 | 28 | warnings.filterwarnings('ignore') 29 | 30 | SEED = 0 31 | np.random.seed(SEED) 32 | random.seed(SEED) 33 | MOS_MEAN = 2.95 34 | MOS_STD = 0.8 35 | RGB_MEAN = (0.485, 0.456, 0.406) 36 | RGB_STD = (0.229, 0.224, 0.225) 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') 40 | opt = option.parse(parser.parse_args().opt) 41 | opt = option.dict_to_nonedict(opt) 42 | 43 | today = date.today().strftime("%Y%m%d") 44 | proc_fa_dir = opt['res_dir'] + opt['model_type'] + '_' + today + '/' 45 | output_dir = proc_fa_dir + 'res/' 46 | 47 | if not os.path.exists(output_dir): 48 | os.makedirs(output_dir) 49 | 50 | if torch.cuda.is_available(): 51 | if opt['cuda']: 52 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 53 | if not opt['cuda']: 54 | print("WARNING: It looks like you have a CUDA device, but aren't " + 55 | "using CUDA.\nRun with --cuda for optimal training speed.") 56 | torch.set_default_tensor_type('torch.FloatTensor') 57 | else: 58 | torch.set_default_tensor_type('torch.FloatTensor') 59 | 60 | # --------- model define --------- 61 | print("...load SMTNet...") 62 | smt_net = build_smt_model(scale='multi', alignsize=9, reddim=8, loadweight=False, model='shufflenetv2', downsample=4) 63 | smt_net.load_state_dict(torch.load(opt['smt_model'])) 64 | smt_net.eval() 65 | 66 | print("...load BASNet...") 67 | visimp_net = BASNet(3, 1) 68 | visimp_net.load_state_dict(torch.load(opt['visimp_model'])) 69 | visimp_net.eval() 70 | 71 | if opt['cuda']: 72 | smt_net = torch.nn.DataParallel(smt_net, device_ids=[0]) 73 | cudnn.benchmark = True 74 | smt_net = smt_net.cuda() 75 | visimp_net = visimp_net.cuda() 76 | 77 | dataset = setup_test_dataset(usr_slogan=opt['usr_slogan'], 78 | font_fp=opt['font_fp'], 79 | visimp_model=visimp_net, 80 | proc_fa_dir=proc_fa_dir, 81 | is_devi=opt['is_devi'], 82 | dataset_dir=opt['input_dir'], 83 | model_type=opt['model_type'], 84 | ratio_list=opt['ratio_list'], 85 | text_spacing=opt['text_spacing'], 86 | exp_prop=opt['exp_prop'], 87 | grid_num=opt['grid_num'], 88 | sali_coef=opt['sali_coef'], 89 | max_text_area_coef=opt['max_text_area_coef'], 90 | min_text_area_coef=opt['min_text_area_coef'], 91 | min_font_size=opt['min_font_size'], 92 | max_font_size=opt['max_font_size'], 93 | font_inc_unit=opt['font_inc_unit']) 94 | 95 | 96 | def naive_collate(batch): 97 | return batch[0] 98 | 99 | 100 | data_loader = data.DataLoader(dataset, 101 | opt['batch_size'], 102 | num_workers=opt['num_workers'], 103 | collate_fn=naive_collate, 104 | shuffle=False) 105 | 106 | 107 | def draw_text_imgpath(imgpath, fsz, fontstr, top_box, res_text_loc, text_spacing, fontcolor, font_loc): 108 | pil_im = Image.open(imgpath) 109 | draw = ImageDraw.Draw(pil_im) 110 | font = ImageFont.truetype(font_loc, fsz, encoding="utf-8") 111 | draw.text((top_box[1], top_box[0]), fontstr, fontcolor, font=font, spacing=text_spacing) 112 | pil_im.save(res_text_loc) 113 | 114 | 115 | def draw_text_cont(pil_im, draw, fsz, fontstr, top_box, res_text_loc, text_spacing, fontcolor, font_loc): 116 | font = ImageFont.truetype(font_loc, fsz, encoding="utf-8") 117 | draw.text((top_box[1], top_box[0]), fontstr, fontcolor, font=font, spacing=text_spacing) 118 | # pil_im.save(res_text_loc) 119 | 120 | 121 | def output_file_name(input_path, sc, idx, dataset_name='SMT', R_type='RoD'): 122 | name = os.path.basename(input_path) 123 | segs = name.split('.') 124 | assert len(segs) >= 
2 125 | return '%s_%s_%s_%d_%s.%s' % ('.'.join(segs[:-1]), dataset_name, R_type, idx, sc, segs[-1]) 126 | 127 | 128 | def test_sep(st_id, ed_id, resized_images, bboxs): 129 | roi = [] 130 | st_flg = True 131 | i_cnt = 0 132 | for idx in range(st_id, ed_id): 133 | if (st_flg == True): 134 | in_imgs = torch.unsqueeze(torch.as_tensor(resized_images[idx]), 0) 135 | st_flg = False 136 | 137 | else: 138 | tp_img = torch.unsqueeze(torch.as_tensor(resized_images[idx]), 0) 139 | in_imgs = torch.cat((in_imgs, tp_img), 0) 140 | 141 | roi.append((i_cnt, bboxs['xmin'][idx], bboxs['ymin'][idx], bboxs['xmax'][idx], bboxs['ymax'][idx])) 142 | i_cnt += 1 143 | 144 | if opt['cuda']: 145 | in_imgs = Variable(in_imgs.cuda()) 146 | roi = Variable(torch.Tensor(roi)) 147 | else: 148 | in_imgs = Variable(in_imgs) 149 | roi = Variable(roi) 150 | 151 | out = smt_net(in_imgs, roi) 152 | return out 153 | 154 | 155 | def test(): 156 | 157 | for id, sample in enumerate(data_loader): 158 | st_time = time.time() 159 | imgpath = sample['imgpath'] 160 | bboxes = sample['sourceboxes'] 161 | resized_images = sample['resized_images'] 162 | tbboxes = sample['tbboxes'] 163 | box_list = sample['box_list'] 164 | 165 | len_tbboxes = len(tbboxes['xmin']) 166 | if (len_tbboxes == 0): 167 | continue 168 | 169 | if (opt['model_type'] == 'RoE'): 170 | bat_sz = 16 171 | te_cnt = math.ceil(len_tbboxes * 1.0 / bat_sz) 172 | st = 0 173 | for ite in range(te_cnt): 174 | ed = min(st + bat_sz, len_tbboxes) 175 | sep_out = test_sep(st, ed, resized_images, tbboxes) 176 | if (ite == 0): 177 | cat_out = torch.Tensor(sep_out) 178 | else: 179 | cat_out = torch.cat((cat_out, sep_out), 0) 180 | 181 | st = st + bat_sz 182 | 183 | out = torch.Tensor(cat_out) 184 | 185 | else: 186 | roi = [] 187 | for idx in range(0, len(tbboxes['xmin'])): 188 | roi.append((0, tbboxes['xmin'][idx], tbboxes['ymin'][idx], tbboxes['xmax'][idx], tbboxes['ymax'][idx])) 189 | 190 | resized_image = torch.unsqueeze(torch.as_tensor(resized_images), 0) 191 | if opt['cuda']: 192 | resized_image = Variable(resized_image.cuda()) 193 | roi = Variable(torch.Tensor(roi)) 194 | else: 195 | resized_image = Variable(resized_image) 196 | roi = Variable(torch.Tensor(roi)) 197 | out = smt_net(resized_image, roi) 198 | 199 | print('len_out =', len(out)) 200 | id_out = sorted(range(len(out)), key=lambda k: out[k], reverse=True) 201 | 202 | #--------------------------------- 203 | # find json file in box_dir 204 | base_dat_dir = proc_fa_dir 205 | base_box_dir = base_dat_dir + 'box_dir' + '/' 206 | img_name = imgpath.split('/')[-1] 207 | imgpre, _ = os.path.splitext(img_name) 208 | box_loc = base_box_dir + imgpre + '/' + imgpre + '.json' 209 | with open(box_loc, encoding="utf-8") as f: 210 | box_data = json.load(f) 211 | #--------------------------------- 212 | 213 | impre = imgpath.split('/')[-1].split('.')[0] 214 | len_bboxes = len(bboxes) 215 | for i in range(len_bboxes): 216 | tmp_sc = out[i].cpu().data.numpy().squeeze() 217 | tmp_sc = tmp_sc * MOS_STD + MOS_MEAN 218 | box_data[i][0]['score'] = tmp_sc 219 | 220 | sv_json(box_data, box_loc) 221 | 222 | candi_res = min(opt['candi_res'], len(id_out)) 223 | for id in range(0, candi_res): 224 | top_box = bboxes[id_out[id]] 225 | tmp_sc = str(box_data[id_out[id]][0]['score']) 226 | 227 | # draw each res in sep dir 228 | res_sep_dir = output_dir + impre + '/' 229 | os.makedirs(res_sep_dir, exist_ok=True) 230 | res_text_loc = os.path.join( 231 | res_sep_dir, 232 | output_file_name(input_path=imgpath, 233 | sc=tmp_sc, 234 | idx=id + 1, 235 | 
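The RoE branch above scores the candidate boxes in fixed-size chunks of 16 and concatenates the per-chunk outputs. The same pattern in isolation (the Linear model is a stand-in for smt_net, and the torch.no_grad wrapper is an addition that avoids autograd bookkeeping at test time):

import torch

def score_in_chunks(model, inputs, chunk=16):
    # run fixed-size chunks through the model and concatenate the outputs
    outs = []
    with torch.no_grad():
        for st in range(0, inputs.size(0), chunk):
            outs.append(model(inputs[st:st + chunk]))
    return torch.cat(outs, 0)

model = torch.nn.Linear(8, 1)           # stand-in for the scoring network
x = torch.randn(37, 8)                  # 37 candidates -> chunks of 16, 16, 5
print(score_in_chunks(model, x).shape)  # torch.Size([37, 1])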
dataset_name=opt['dataset_name'], 236 | R_type=opt['model_type'])) 237 | 238 | pil_im = Image.open(imgpath) 239 | draw = ImageDraw.Draw(pil_im) 240 | tl_cnt = box_list[id_out[id]][0]['tl_cnt'] 241 | 242 | if (id == 0): 243 | # im_np = cv2.cvtColor(np.array(pil_im), cv2.COLOR_RGB2BGR) # h, w, c 244 | im_np = np.array(pil_im) # h, w, c 245 | xl = box_list[id_out[id]][0]['xl'] 246 | xr = box_list[id_out[id]][0]['xr'] 247 | yl = box_list[id_out[id]][0]['yl'] 248 | yr = box_list[id_out[id]][0]['yr'] 249 | im_crop_np = im_np[xl:xr, yl:yr] 250 | 251 | # select text color 252 | color_candi = cal_best_color(im_np, im_crop_np, opt['contrast_threshold']) 253 | fontcolor = RGB_to_Hex(color_candi[0]['color']) 254 | print("fontcolor = " + fontcolor) 255 | 256 | for tx in range(1, tl_cnt + 1): 257 | fsz = box_list[id_out[id]][tx]['fsz'] 258 | fontstr = box_list[id_out[id]][tx]['fontstr'] 259 | top_box = [box_list[id_out[id]][tx]['xl'], box_list[id_out[id]][tx]['yl']] 260 | 261 | draw_text_cont(pil_im, 262 | draw, 263 | fsz=fsz, 264 | fontstr=fontstr, 265 | top_box=top_box, 266 | res_text_loc=res_text_loc, 267 | text_spacing=opt['text_spacing'], 268 | fontcolor=fontcolor, 269 | font_loc=opt['font_fp']) 270 | 271 | pil_im.save(res_text_loc) 272 | # draw best res 273 | if (id == 0): 274 | res_text_loc = os.path.join( 275 | output_dir, 276 | output_file_name(input_path=imgpath, 277 | sc=tmp_sc, 278 | idx=id + 1, 279 | dataset_name=opt['dataset_name'], 280 | R_type=opt['model_type'])) 281 | pil_im.save(res_text_loc) 282 | 283 | ed_time = time.time() 284 | print('timer: %.4f sec.' % (ed_time - st_time)) 285 | 286 | 287 | if __name__ == '__main__': 288 | test() 289 | -------------------------------------------------------------------------------- /cal_color.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import time 4 | import cv2 5 | 6 | 7 | def cal_domcolor(img_np, k): 8 | st = time.time() 9 | img_km = img_np.reshape((img_np.shape[0] * img_np.shape[1], img_np.shape[2])) 10 | estimator = KMeans(n_clusters=k, max_iter=300, n_init=2) 11 | estimator.fit(img_km) 12 | centroids = estimator.cluster_centers_ 13 | centroids = sorted(centroids, key=lambda x: (x[0], x[1], x[2])) 14 | ed = time.time() 15 | # print("KMeans.time = " + ed - st) 16 | return centroids 17 | 18 | 19 | def draw_domcolor(centroids, n_channels, sv_fp): 20 | result = [] 21 | res_width = 200 22 | res_height_per = 80 23 | k = len(centroids) 24 | for center_index in range(k): 25 | result.append(np.full((res_width * res_height_per, n_channels), centroids[center_index], dtype=int)) 26 | result = np.array(result, dtype=np.uint8) 27 | result = result.reshape((res_height_per * k, res_width, n_channels)) 28 | 29 | result_bgr = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) 30 | cv2.imwrite(sv_fp, result_bgr) 31 | 32 | 33 | def rgb_distance(rgb): 34 | return abs(rgb[0] - rgb[1]) + abs(rgb[0] - rgb[2]) + abs(rgb[2] - rgb[1]) 35 | 36 | 37 | def RGB_to_Hex(rgb): 38 | color_str = '#' 39 | for i in rgb: 40 | num = int(i) 41 | color_str += str(hex(num))[-2:].replace('x', '0').upper() 42 | return color_str 43 | 44 | 45 | def cal_luminance(rgb): 46 | for i in range(0, len(rgb)): 47 | if rgb[i] <= 0.03928: 48 | rgb[i] = rgb[i] / 12.92 49 | else: 50 | rgb[i] = pow(((rgb[i] + 0.055) / 1.055), 2.4) 51 | l = (0.2126 * rgb[0]) + (0.7152 * rgb[1]) + (0.0722 * rgb[2]) 52 | return l 53 | 54 | 55 | def cal_contrast_rate(rgbA, rgbB): 56 | ratio = 1 57 | l1 = cal_luminance([rgbA[0] / 
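cal_luminance and cal_contrast_rate above follow the WCAG 2.x definitions: each sRGB channel is linearized, relative luminance is the weighted sum 0.2126 R + 0.7152 G + 0.0722 B, and the contrast ratio is (L1 + 0.05) / (L2 + 0.05) with the brighter color on top. A compact worked example (an equivalent reimplementation for illustration, not this module's API):

def srgb_to_linear(c):
    # WCAG 2.x sRGB linearization, c in [0, 1]
    return c / 12.92 if c <= 0.03928 else ((c + 0.055) / 1.055) ** 2.4

def luminance(rgb):
    r, g, b = (srgb_to_linear(c / 255.0) for c in rgb)
    return 0.2126 * r + 0.7152 * g + 0.0722 * b

def contrast(rgb_a, rgb_b):
    l1, l2 = sorted((luminance(rgb_a), luminance(rgb_b)), reverse=True)
    return (l1 + 0.05) / (l2 + 0.05)

print(round(contrast((255, 255, 255), (0, 0, 0)), 2))        # 21.0, the maximum
print(round(contrast((255, 255, 255), (119, 119, 119)), 2))  # ~4.48, below the 5.5 threshold used here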
255, rgbA[1] / 255, rgbA[2] / 255])
58 |     l2 = cal_luminance([rgbB[0] / 255, rgbB[1] / 255, rgbB[2] / 255])
59 |     if l1 >= l2:
60 |         ratio = (l1 + .05) / (l2 + .05)
61 |     else:
62 |         ratio = (l2 + .05) / (l1 + .05)
63 |     ratio = round(ratio * 100) / 100
64 |     return ratio
65 |
66 |
67 | def cal_best_color(img, img_crop, contrast_threshold=5.5):
68 |     color_candidates = cal_domcolor(img, 6)
69 |     crop_color = cal_domcolor(img_crop, 1)[0]
70 |     color_choose = []
71 |     grey_flag = False
72 |     for color in color_candidates:
73 |         tmp_cr = cal_contrast_rate(color, crop_color)
74 |         if tmp_cr > contrast_threshold:
75 |             color_choose.append({"color": color, "contrast_rate": tmp_cr})
76 |
77 |     if len(color_choose) == 0:
78 |         grey_flag = True
79 |         grey_candidates = []
80 |         for i in range(0, 256, 50):
81 |             grey_candidates.append([i, i, i])
82 |
83 |         for grey_color in grey_candidates:
84 |             tmp_cr = cal_contrast_rate(grey_color, crop_color)
85 |             if tmp_cr > contrast_threshold:
86 |                 color_choose.append({"color": grey_color, "contrast_rate": tmp_cr})
87 |
88 |     if len(color_choose) == 0:
89 |         black_cr = cal_contrast_rate([0, 0, 0], crop_color)
90 |         white_cr = cal_contrast_rate([255, 255, 255], crop_color)
91 |         if (black_cr > white_cr):
92 |             color_choose.append({"color": [0, 0, 0], "contrast_rate": black_cr})
93 |         else:
94 |             color_choose.append({"color": [255, 255, 255], "contrast_rate": white_cr})
95 |
96 |     if grey_flag:
97 |         color_choose_sorted = sorted(color_choose, key=lambda x: x["contrast_rate"], reverse=True)
98 |
99 |     else:
100 |         color_choose_sorted = sorted(color_choose, key=lambda x: rgb_distance(x["color"]), reverse=True)
101 |     return color_choose_sorted
102 |
--------------------------------------------------------------------------------
/figures/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/figures/example.png
--------------------------------------------------------------------------------
/generate_candidates.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from PIL import Image, ImageDraw, ImageFont
6 |
7 | from option import sv_json
8 | import util_cal as uc
9 | from BASNet.basnet_test import get_imp
10 |
11 |
12 | def gen_boxes_multi(img_name,
13 |                     visimp_pred_dir,
14 |                     visimp_pred_dir_ovl,
15 |                     visimp_model,
16 |                     usr_slogan,
17 |                     font_fp,
18 |                     base_dat_dir,
19 |                     is_devi=False,
20 |                     ratio_list=[1, 1, 1, 1, 1],
21 |                     text_spacing=20,
22 |                     grid_num=120,
23 |                     sali_coef=2.6,
24 |                     max_text_area_coef=17,
25 |                     min_text_area_coef=7,
26 |                     min_font_size=10,
27 |                     max_font_size=500,
28 |                     font_inc_unit=5):
29 |     base_box_dir = base_dat_dir + 'box_dir' + '/'
30 |     im_wh_name = img_name.split('/')[-1]
31 |     imgpre, imgext = os.path.splitext(im_wh_name)
32 |     _, ini_visimp = get_imp(img_name=img_name,
33 |                             prediction_dir=visimp_pred_dir,
34 |                             prediction_dir_ovl=visimp_pred_dir_ovl,
35 |                             visimp_model=visimp_model)
36 |
37 |     rescaled = np.array(ini_visimp)
38 |     rerow = len(rescaled)
39 |     recol = len(rescaled[0])
40 |
41 |     # grid partition
42 |     grid_rsz = int(max(rerow, recol) * 1.0 / grid_num)
43 |     if (grid_rsz % 2 == 1):
44 |         grid_rsz = grid_rsz - 1
45 |     grid_csz = grid_rsz
46 |
47 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 |     x = torch.FloatTensor(rescaled).to(device)
49 |     h, w = x.shape
50 |     x = F.avg_pool2d(x.view(1, 1, h,
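The pooling step being assembled here turns the resized importance map into a coarse grid: avg_pool2d computes each cell's mean, and multiplying back by the cell area recovers the per-cell importance sum. A standalone sketch with made-up dimensions:

import torch
import torch.nn.functional as F

imp = torch.rand(240, 360)          # stand-in for the visual-importance map
cell = 24                           # grid cell size, i.e. grid_rsz
pooled = F.avg_pool2d(imp.view(1, 1, 240, 360), kernel_size=cell)
cell_sums = pooled.squeeze() * cell * cell   # cell average * cell area = cell sum
print(cell_sums.shape)              # torch.Size([10, 15]): a 10 x 15 grid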
w), kernel_size=grid_rsz) 51 | x = x.cpu().numpy() 52 | crop_mat = np.squeeze(x * grid_rsz * grid_csz) 53 | 54 | crop_row_num = crop_mat.shape[0] 55 | crop_col_num = crop_mat.shape[1] 56 | matrix1D = crop_mat.flatten() 57 | matrixcal = [[0.0 for i in range(crop_col_num)] for i in range(crop_row_num)] 58 | matrix1D = np.sort(matrix1D)[::-1] 59 | 60 | # the larger sali_coef, the smaller area defined as important of the image 61 | Kth = (int)(crop_row_num * crop_col_num / sali_coef) 62 | tmpval = matrix1D[Kth] 63 | 64 | INF = float(1000000007) 65 | for i in range(crop_row_num): 66 | for j in range(crop_col_num): 67 | if (crop_mat[i][j] > tmpval): 68 | matrixcal[i][j] = INF 69 | elif (i <= 3 or j <= 3 or i >= crop_row_num - 4 or j >= crop_col_num - 4): 70 | matrixcal[i][j] = INF 71 | else: 72 | matrixcal[i][j] = crop_mat[i][j] 73 | 74 | ini_tprob_map = np.array(uc.cal_imp_conv(crop_row_num, crop_col_num, crop_mat, matrixcal, matrix1D, INF)) 75 | 76 | min_text_area = rerow * recol / max_text_area_coef 77 | max_text_area = rerow * recol / min_text_area_coef 78 | 79 | slogan_list = usr_slogan.split('\n') 80 | len_slogan_list = len(slogan_list) 81 | if (len_slogan_list > len(ratio_list)): 82 | for i in range(len_slogan_list - len(ratio_list)): 83 | ratio_list.append(1) 84 | 85 | image_name = img_name 86 | rect_im = Image.open(image_name) 87 | draw_rect = ImageDraw.Draw(rect_im) 88 | 89 | box_dir = base_box_dir + imgpre + '/' 90 | os.makedirs(box_dir, exist_ok=True) 91 | 92 | fsz = min_font_size 93 | fsz_intv = font_inc_unit 94 | scnt = 0 95 | anno_list = [] 96 | now_idx = 0 97 | while fsz <= max_font_size: 98 | pil_im = Image.open(image_name) 99 | draw = ImageDraw.Draw(pil_im) 100 | 101 | txarea_x = -text_spacing 102 | txarea_y = 0.0 103 | for tli in range(len_slogan_list): 104 | tli_fsz = int(fsz * ratio_list[tli]) 105 | font = ImageFont.truetype(font_fp, tli_fsz, encoding="utf-8") 106 | fontstr = slogan_list[tli] 107 | tli_txsz = draw.textsize(fontstr, font=font, spacing=text_spacing) 108 | txarea_x = txarea_x + tli_txsz[1] + text_spacing 109 | txarea_y = max(txarea_y, tli_txsz[0]) 110 | 111 | txarea = txarea_x * txarea_y 112 | txsz = [txarea_y, txarea_x] 113 | if ((txarea > max_text_area) or (txarea < min_text_area) or (txarea_y >= recol) or (txarea_x >= rerow)): 114 | fsz += fsz_intv 115 | continue 116 | 117 | Kth_rect = 1 118 | st = uc.get_top_k_submatrix(ini_tprob_map, ((int)(txsz[1] / grid_rsz), (int)(txsz[0] / grid_csz)), 119 | Kth_rect, 120 | desc=False) 121 | 122 | for kth in range(Kth_rect): 123 | stx = st[kth].rx * grid_rsz 124 | sty = st[kth].cy * grid_csz 125 | if ((stx >= rerow) or (stx + txsz[1] >= rerow) or (sty >= recol) or (sty + txsz[0] >= recol)): 126 | continue 127 | 128 | stcol = sty 129 | strow = stx 130 | edcol = sty + txsz[0] 131 | edrow = stx + txsz[1] 132 | scnt += 1 133 | tmp_anno_list = [] 134 | tmp_anno_list.append({ 135 | 'idx': now_idx, 136 | 'xl': strow, 137 | 'yl': stcol, 138 | 'xr': edrow, 139 | 'yr': edcol, 140 | 'tl_cnt': len_slogan_list 141 | }) 142 | now_idx += 1 143 | 144 | stcol = sty 145 | strow = stx 146 | for tli in range(len_slogan_list): 147 | tli_fsz = int(fsz * ratio_list[tli]) 148 | font = ImageFont.truetype(font_fp, tli_fsz, encoding="utf-8") 149 | fontstr = slogan_list[tli] 150 | tli_txsz = draw.textsize(fontstr, font=font, spacing=text_spacing) 151 | edcol = stcol + tli_txsz[0] 152 | edrow = strow + tli_txsz[1] 153 | tmp_anno_list.append({ 154 | 'xl': strow, 155 | 'yl': stcol, 156 | 'xr': edrow, 157 | 'yr': edcol, 158 | 'fsz': tli_fsz, 159 | 
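One caveat on measuring text: draw.textsize, used throughout this loop, was removed in Pillow 10. On newer Pillow the equivalent measurement goes through textbbox, which returns a bounding box, handles multi-line strings, and accepts the same spacing parameter; a hedged sketch of the replacement (load_default stands in for the real truetype font):

from PIL import Image, ImageDraw, ImageFont

img = Image.new('RGB', (400, 200))
draw = ImageDraw.Draw(img)
font = ImageFont.load_default()  # swap in ImageFont.truetype(font_fp, fsz) in practice

# textbbox returns (left, top, right, bottom) for the laid-out text block
left, top, right, bottom = draw.textbbox((0, 0), "Summer\nSale", font=font, spacing=20)
w, h = right - left, bottom - top
print(w, h)  # width and height, where textsize used to return (w, h) directly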
'fontstr': fontstr 160 | }) 161 | strow = strow + tli_txsz[1] + text_spacing 162 | 163 | anno_list.append(tmp_anno_list) 164 | 165 | fsz += fsz_intv 166 | 167 | new_anno_list = [] 168 | len_anno_list = len(anno_list) 169 | if (is_devi): 170 | devi_direc = [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 1], [1, -1], [1, 0], [1, 1]] 171 | else: 172 | devi_direc = [] 173 | 174 | # deviation unit 175 | devi_unit = grid_rsz * 10 176 | each_box_num = len(devi_direc) 177 | 178 | for ai in range(len_anno_list): 179 | new_anno_list.append(anno_list[ai]) 180 | 181 | for gen_i in range(each_box_num): 182 | new_xl = anno_list[ai][0]['xl'] + devi_direc[gen_i][0] * devi_unit 183 | new_yl = anno_list[ai][0]['yl'] + devi_direc[gen_i][1] * devi_unit 184 | new_xr = new_xl + abs(anno_list[ai][0]['xr'] - anno_list[ai][0]['xl']) 185 | new_yr = new_yl + abs(anno_list[ai][0]['yr'] - anno_list[ai][0]['yl']) 186 | if (new_xl < 0 or (new_xl >= rerow) or (new_yl < 0) or (new_yl >= recol) or new_xr < 0 or (new_xr >= rerow) 187 | or (new_yr < 0) or (new_yr >= recol)): 188 | continue 189 | 190 | tmp_new_anno_list = [] 191 | tmp_new_anno_list.append({ 192 | 'idx': now_idx, 193 | 'xl': new_xl, 194 | 'yl': new_yl, 195 | 'xr': new_xr, 196 | 'yr': new_yr, 197 | 'tl_cnt': anno_list[ai][0]['tl_cnt'] 198 | }) 199 | now_idx += 1 200 | 201 | for tli in range(1, anno_list[ai][0]['tl_cnt'] + 1): 202 | tli_fsz = anno_list[ai][tli]['fsz'] 203 | fontstr = anno_list[ai][tli]['fontstr'] 204 | strow = anno_list[ai][tli]['xl'] + (new_xl - anno_list[ai][0]['xl']) 205 | stcol = anno_list[ai][tli]['yl'] + (new_yl - anno_list[ai][0]['yl']) 206 | edrow = anno_list[ai][tli]['xr'] + (new_xr - anno_list[ai][0]['xr']) 207 | edcol = anno_list[ai][tli]['yr'] + (new_yr - anno_list[ai][0]['yr']) 208 | tmp_new_anno_list.append({ 209 | 'xl': strow, 210 | 'yl': stcol, 211 | 'xr': edrow, 212 | 'yr': edcol, 213 | 'fsz': tli_fsz, 214 | 'fontstr': fontstr 215 | }) 216 | 217 | new_anno_list.append(tmp_new_anno_list) 218 | 219 | sv_json(new_anno_list, box_dir + imgpre + '.json') 220 | 221 | anno_dict = { 222 | 'img_name': img_name, 223 | 'usr_slogan': usr_slogan, 224 | 'font_loc': font_fp, 225 | 'scnt': scnt, 226 | 'now_idx': now_idx, 227 | 'new_anno_list': new_anno_list 228 | } 229 | 230 | return anno_dict, ini_visimp 231 | -------------------------------------------------------------------------------- /make_all.sh: -------------------------------------------------------------------------------- 1 | cd ./roi_align 2 | bash make.sh 3 | 4 | cd ../rod_align 5 | bash make.sh 6 | 7 | cd .. 8 | -------------------------------------------------------------------------------- /mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates a MobileNetV2 Model as defined in: 3 | Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. (2018). 4 | MobileNetV2: Inverted Residuals and Linear Bottlenecks 5 | arXiv preprint arXiv:1801.04381. 6 | import from https://github.com/tonylins/pytorch-mobilenet-v2 7 | """ 8 | 9 | import torch.nn as nn 10 | import math 11 | 12 | __all__ = ['mobilenetv2'] 13 | 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | """ 17 | This function is taken from the original tf repo. 
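The is_devi branch above jitters each accepted box along the eight compass directions by devi_unit and keeps only the shifts that stay inside the image. The same bounds-checked shifting, reduced to a function (the unit, rows, and cols values below are illustrative):

# 8-neighborhood offsets used to jitter each candidate box
devi_direc = [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 1], [1, -1], [1, 0], [1, 1]]

def shifted_boxes(box, unit, rows, cols):
    # box = (xl, yl, xr, yr); keep only shifts that remain inside the image
    xl, yl, xr, yr = box
    out = []
    for dx, dy in devi_direc:
        nxl, nyl = xl + dx * unit, yl + dy * unit
        nxr, nyr = nxl + (xr - xl), nyl + (yr - yl)
        if 0 <= nxl and nxr < rows and 0 <= nyl and nyr < cols:
            out.append((nxl, nyl, nxr, nyr))
    return out

print(shifted_boxes((100, 100, 200, 260), unit=40, rows=480, cols=640))  # up to 8 jittered copies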
18 | It ensures that all layers have a channel number that is divisible by 8 19 | It can be seen here: 20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 21 | :param v: 22 | :param divisor: 23 | :param min_value: 24 | :return: 25 | """ 26 | if min_value is None: 27 | min_value = divisor 28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 29 | # Make sure that round down does not go down by more than 10%. 30 | if new_v < 0.9 * v: 31 | new_v += divisor 32 | return new_v 33 | 34 | 35 | def conv_3x3_bn(inp, oup, stride): 36 | return nn.Sequential( 37 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 38 | nn.BatchNorm2d(oup), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | 43 | def conv_1x1_bn(inp, oup): 44 | return nn.Sequential( 45 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 46 | nn.BatchNorm2d(oup), 47 | nn.ReLU6(inplace=True) 48 | ) 49 | 50 | 51 | class InvertedResidual(nn.Module): 52 | def __init__(self, inp, oup, stride, expand_ratio): 53 | super(InvertedResidual, self).__init__() 54 | assert stride in [1, 2] 55 | 56 | hidden_dim = int(round(inp * expand_ratio)) 57 | self.identity = stride == 1 and inp == oup 58 | 59 | if expand_ratio == 1: 60 | self.conv = nn.Sequential( 61 | # dw 62 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 63 | nn.BatchNorm2d(hidden_dim), 64 | nn.ReLU6(inplace=True), 65 | # pw-linear 66 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 67 | nn.BatchNorm2d(oup), 68 | ) 69 | else: 70 | self.conv = nn.Sequential( 71 | # pw 72 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 73 | nn.BatchNorm2d(hidden_dim), 74 | nn.ReLU6(inplace=True), 75 | # dw 76 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 77 | nn.BatchNorm2d(hidden_dim), 78 | nn.ReLU6(inplace=True), 79 | # pw-linear 80 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 81 | nn.BatchNorm2d(oup), 82 | ) 83 | 84 | def forward(self, x): 85 | if self.identity: 86 | return x + self.conv(x) 87 | else: 88 | return self.conv(x) 89 | 90 | 91 | class MobileNetV2(nn.Module): 92 | def __init__(self, num_classes=1000, input_size=224, width_mult=1.): 93 | super(MobileNetV2, self).__init__() 94 | # setting of inverted residual blocks 95 | self.cfgs = [ 96 | # t, c, n, s 97 | [1, 16, 1, 1], 98 | [6, 24, 2, 2], 99 | [6, 32, 3, 2], 100 | [6, 64, 4, 2], 101 | [6, 96, 3, 1], 102 | [6, 160, 3, 2], 103 | [6, 320, 1, 1], 104 | ] 105 | 106 | # building first layer 107 | assert input_size % 32 == 0 108 | input_channel = _make_divisible(32 * width_mult, 8) 109 | layers = [conv_3x3_bn(3, input_channel, 2)] 110 | # building inverted residual blocks 111 | block = InvertedResidual 112 | for t, c, n, s in self.cfgs: 113 | output_channel = _make_divisible(c * width_mult, 8) 114 | layers.append(block(input_channel, output_channel, s, t)) 115 | input_channel = output_channel 116 | for i in range(1, n): 117 | layers.append(block(input_channel, output_channel, 1, t)) 118 | input_channel = output_channel 119 | self.features = nn.Sequential(*layers) 120 | # building last several layers 121 | output_channel = _make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280 122 | self.conv = conv_1x1_bn(input_channel, output_channel) 123 | self.avgpool = nn.AvgPool2d(input_size // 32, stride=1) 124 | self.classifier = nn.Linear(output_channel, num_classes) 125 | 126 | self._initialize_weights() 127 | 128 | def forward(self, x): 129 | x = self.features(x) 130 | x = self.conv(x) 131 | x = self.avgpool(x) 132 | x = 
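_make_divisible rounds a scaled channel count to the nearest multiple of the divisor, then bumps it up one step if rounding lost more than 10% of the requested width. Two worked calls (the function body is reproduced verbatim from above):

def _make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:   # never round down by more than 10%
        new_v += divisor
    return new_v

print(_make_divisible(32 * 0.75))  # 24: width_mult=0.75 keeps channels a multiple of 8
print(_make_divisible(35.9))       # 40: nearest rounding gives 32, a >10% drop, so bump up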
x.view(x.size(0), -1) 133 | x = self.classifier(x) 134 | return x 135 | 136 | def _initialize_weights(self): 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | m.weight.data.normal_(0, math.sqrt(2. / n)) 141 | if m.bias is not None: 142 | m.bias.data.zero_() 143 | elif isinstance(m, nn.BatchNorm2d): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | elif isinstance(m, nn.Linear): 147 | n = m.weight.size(1) 148 | m.weight.data.normal_(0, 0.01) 149 | m.bias.data.zero_() 150 | 151 | def mobilenetv2(**kwargs): 152 | """ 153 | Constructs a MobileNet V2 model 154 | """ 155 | return MobileNetV2(**kwargs) 156 | 157 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from collections import OrderedDict 4 | import numpy as np 5 | import json 6 | 7 | try: 8 | from yaml import CLoader as Loader, CDumper as Dumper 9 | except ImportError: 10 | from yaml import Loader, Dumper 11 | 12 | 13 | def OrderedYaml(): 14 | '''yaml orderedDict support''' 15 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 16 | 17 | def dict_representer(dumper, data): 18 | return dumper.represent_dict(data.items()) 19 | 20 | def dict_constructor(loader, node): 21 | return OrderedDict(loader.construct_pairs(node)) 22 | 23 | Dumper.add_representer(OrderedDict, dict_representer) 24 | Loader.add_constructor(_mapping_tag, dict_constructor) 25 | return Loader, Dumper 26 | 27 | 28 | Loader, Dumper = OrderedYaml() 29 | 30 | 31 | def parse(opt_path): 32 | with open(opt_path, mode='r') as f: 33 | opt = yaml.load(f, Loader=Loader) 34 | # export CUDA_VISIBLE_DEVICES 35 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 36 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 37 | # print("export CUDA_VISIBLE_DEVICES = " + gpu_list) 38 | return opt 39 | 40 | 41 | def dict2str(opt, indent_l=1): 42 | '''dict to string for logger''' 43 | msg = '' 44 | for k, v in opt.items(): 45 | if isinstance(v, dict): 46 | msg += ' ' * (indent_l * 2) + k + ':[\n' 47 | msg += dict2str(v, indent_l + 1) 48 | msg += ' ' * (indent_l * 2) + ']\n' 49 | else: 50 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 51 | return msg 52 | 53 | 54 | class NoneDict(dict): 55 | 56 | def __missing__(self, key): 57 | return None 58 | 59 | 60 | # convert to NoneDict, which return None for missing key. 
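The NoneDict defined here makes missing option keys read as None instead of raising KeyError, which is why the test scripts can freely probe opt['...'] for keys a given YAML file may not define. In isolation:

class NoneDict(dict):
    def __missing__(self, key):
        return None

opt = NoneDict(cuda=True)
print(opt['cuda'])        # True
print(opt['batch_size'])  # None instead of a KeyError for an undefined option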
61 | def dict_to_nonedict(opt): 62 | if isinstance(opt, dict): 63 | new_opt = dict() 64 | for key, sub_opt in opt.items(): 65 | new_opt[key] = dict_to_nonedict(sub_opt) 66 | return NoneDict(**new_opt) 67 | elif isinstance(opt, list): 68 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 69 | else: 70 | return opt 71 | 72 | 73 | class NpEncoder(json.JSONEncoder): 74 | 75 | def default(self, obj): 76 | if isinstance(obj, np.integer): 77 | return int(obj) 78 | elif isinstance(obj, np.floating): 79 | return float(obj) 80 | elif isinstance(obj, np.ndarray): 81 | return obj.tolist() 82 | else: 83 | return super(NpEncoder, self).default(obj) 84 | 85 | 86 | def sv_json(sv_list, sv_fp): 87 | json_str = json.dumps(sv_list, ensure_ascii=False, indent=1, cls=NpEncoder) 88 | with open(sv_fp, 'w', encoding='utf-8') as f: 89 | f.write(json_str) 90 | -------------------------------------------------------------------------------- /rod_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/rod_align/__init__.py -------------------------------------------------------------------------------- /rod_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/rod_align/functions/__init__.py -------------------------------------------------------------------------------- /rod_align/functions/rod_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import rod_align_api 4 | 5 | class RoDAlignFunction(Function): 6 | @staticmethod 7 | def forward(ctx, features, rois, aligned_width, aligned_height, spatial_scale): 8 | batch_size, num_channels, data_height, data_width = features.size() 9 | ctx.save_for_backward(rois, 10 | torch.IntTensor([int(batch_size), 11 | int(num_channels), 12 | int(data_height), 13 | int(data_width), 14 | int(aligned_width), 15 | int(aligned_height)]), 16 | torch.FloatTensor([float(spatial_scale)])) 17 | 18 | num_rois = rois.size(0) 19 | 20 | output = features.new(num_rois, 21 | num_channels, 22 | int(aligned_height), 23 | int(aligned_width)).zero_() 24 | 25 | rod_align_api.forward(int(aligned_height), 26 | int(aligned_width), 27 | float(spatial_scale), 28 | features, 29 | rois, output) 30 | 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | rois, core_size, scale = ctx.saved_tensors 36 | spatial_scale = scale[0] 37 | 38 | batch_size, num_channels, data_height, data_width, aligned_width, aligned_height = core_size 39 | 40 | grad_input = rois.new(batch_size, 41 | num_channels, 42 | data_height, 43 | data_width).zero_() 44 | 45 | rod_align_api.backward(int(aligned_height), 46 | int(aligned_width), 47 | float(spatial_scale), 48 | grad_output, 49 | rois, 50 | grad_input) 51 | 52 | return grad_input, None, None, None, None 53 | -------------------------------------------------------------------------------- /rod_align/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling rod_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 
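NpEncoder above is what lets sv_json serialize the numpy scalars and arrays that end up in the box annotations; plain json.dumps would raise TypeError on them. A usage sketch with dummy data:

import json
import numpy as np

class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

box = {'score': np.float32(2.95), 'xy': np.array([10, 20])}
print(json.dumps(box, cls=NpEncoder))  # numpy scalar and array serialize cleanly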
6 | # -arch=sm_75 is compatible with the following NV GPU cards:
7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, Tesla T4
8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile
9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75
10 |
11 | cd ../
12 | # Export CUDA_HOME, then build and install the library.
13 | export CUDA_HOME=/usr/local/cuda-11.1 && python3 setup.py install
14 |
--------------------------------------------------------------------------------
/rod_align/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/rod_align/modules/__init__.py
--------------------------------------------------------------------------------
/rod_align/modules/rod_align.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from torch.nn.functional import avg_pool2d, max_pool2d
3 | from ..functions.rod_align import RoDAlignFunction
4 |
5 |
6 | class RoDAlign(Module):
7 |     def __init__(self, aligned_height, aligned_width, spatial_scale):
8 |         super(RoDAlign, self).__init__()
9 |
10 |         self.aligned_width = int(aligned_width)
11 |         self.aligned_height = int(aligned_height)
12 |         self.spatial_scale = float(spatial_scale)
13 |
14 |     def forward(self, features, rois):
15 |         return RoDAlignFunction.apply(features,
16 |                                       rois,
17 |                                       self.aligned_height,
18 |                                       self.aligned_width,
19 |                                       self.spatial_scale)
20 |
21 | class RoDAlignAvg(Module):
22 |     def __init__(self, aligned_height, aligned_width, spatial_scale):
23 |         super(RoDAlignAvg, self).__init__()
24 |
25 |         self.aligned_width = int(aligned_width)
26 |         self.aligned_height = int(aligned_height)
27 |         self.spatial_scale = float(spatial_scale)
28 |
29 |     def forward(self, features, rois):
30 |         x = RoDAlignFunction.apply(features,
31 |                                    rois,
32 |                                    self.aligned_height+1,
33 |                                    self.aligned_width+1,
34 |                                    self.spatial_scale)
35 |         return avg_pool2d(x, kernel_size=2, stride=1)
36 |
37 | class RoDAlignMax(Module):
38 |     def __init__(self, aligned_height, aligned_width, spatial_scale):
39 |         super(RoDAlignMax, self).__init__()
40 |
41 |         self.aligned_width = int(aligned_width)
42 |         self.aligned_height = int(aligned_height)
43 |         self.spatial_scale = float(spatial_scale)
44 |
45 |     def forward(self, features, rois):
46 |         x = RoDAlignFunction.apply(features,
47 |                                    rois,
48 |                                    self.aligned_height+1,
49 |                                    self.aligned_width+1,
50 |                                    self.spatial_scale)
51 |         return max_pool2d(x, kernel_size=2, stride=1)
52 |
--------------------------------------------------------------------------------
/rod_align/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from pkg_resources import parse_version
5 |
6 | min_version = parse_version('1.0.0')
7 | current_version = parse_version(torch.__version__)
8 |
9 |
10 | if current_version < min_version:  # PyTorch before 1.0
11 |     from torch.utils.ffi import create_extension
12 |
13 |     sources = ['src/rod_align.c']
14 |     headers = ['src/rod_align.h']
15 |     extra_objects = []
16 |
17 |     defines = []
18 |     with_cuda = False
19 |
20 |     this_file = os.path.dirname(os.path.realpath(__file__))
21 |     print(this_file)
22 |
23 |     if torch.cuda.is_available():
24 |         print('Including CUDA code.')
25 |         sources +=
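If your GPU is not in the sm_75 list above, one way to find the right -arch value (an aside, not part of the build scripts) is to ask PyTorch for the device's compute capability:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"-arch=sm_{major}{minor}")  # e.g. sm_75 on an RTX 2080 Ti, sm_86 on an RTX 3090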
['src/rod_align_cuda.c'] 26 | headers += ['src/rod_align_cuda.h'] 27 | defines += [('WITH_CUDA', None)] 28 | with_cuda = True 29 | 30 | extra_objects = ['src/rod_align_kernel.cu.o'] 31 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 32 | 33 | ffi = create_extension( 34 | '_ext.rod_align', 35 | headers=headers, 36 | sources=sources, 37 | define_macros=defines, 38 | relative_to=__file__, 39 | with_cuda=with_cuda, 40 | extra_objects=extra_objects 41 | ) 42 | 43 | if __name__ == '__main__': 44 | ffi.build() 45 | else: # PyTorch 1.0 or later 46 | from setuptools import setup 47 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 48 | 49 | print('Including CUDA code.') 50 | 51 | current_dir = os.path.dirname(os.path.realpath(__file__)) 52 | 53 | setup( 54 | name='rod_align_api', 55 | ext_modules=[ 56 | CUDAExtension( 57 | name='rod_align_api', 58 | sources=['src/rod_align_cuda.cpp', 'src/rod_align_kernel.cu'], 59 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True) 60 | ) 61 | ], 62 | cmdclass={ 63 | 'build_ext': BuildExtension 64 | }) 65 | 66 | -------------------------------------------------------------------------------- /rod_align/src/rod_align.cpp: -------------------------------------------------------------------------------- 1 | #include "rod_align.h" 2 | 3 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 4 | const int height, const int width, const int channels, 5 | const int aligned_height, const int aligned_width, const float * bottom_rois, 6 | float* top_data); 7 | 8 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 9 | const int height, const int width, const int channels, 10 | const int aligned_height, const int aligned_width, const float * bottom_rois, 11 | float* top_data); 12 | 13 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale, 14 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 15 | { 16 | //Grab the input tensor 17 | //float * data_flat = THFloatTensor_data(features); 18 | //float * rois_flat = THFloatTensor_data(rois); 19 | auto data_flat = features.data(); 20 | auto rois_flat = rois.data(); 21 | 22 | //float * output_flat = THFloatTensor_data(output); 23 | auto output_flat = output.data(); 24 | 25 | // Number of ROIs 26 | //int num_rois = THFloatTensor_size(rois, 0); 27 | //int size_rois = THFloatTensor_size(rois, 1); 28 | auto rois_sz = rois.sizes(); 29 | int num_rois = rois_sz[0]; 30 | int size_rois = rois_sz[1]; 31 | 32 | if (size_rois != 5) 33 | { 34 | return 0; 35 | } 36 | 37 | // data height 38 | //int data_height = THFloatTensor_size(features, 2); 39 | // data width 40 | //int data_width = THFloatTensor_size(features, 3); 41 | // Number of channels 42 | //int num_channels = THFloatTensor_size(features, 1); 43 | auto feat_sz = features.sizes(); 44 | int data_height = feat_sz[2]; 45 | int data_width = feat_sz[3]; 46 | int num_channels = feat_sz[1]; 47 | 48 | // do ROIAlignForward 49 | RODAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels, 50 | aligned_height, aligned_width, rois_flat, output_flat); 51 | 52 | return 1; 53 | } 54 | 55 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale, 56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 57 | { 58 | //Grab the input tensor 59 | //float * top_grad_flat = THFloatTensor_data(top_grad); 60 | //float * 
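Both align extensions expect ROIs as a two-dimensional float tensor with five columns per row, (batch_index, xmin, ymin, xmax, ymax); that is what the size_rois == 5 guard above checks before touching the data, and it mirrors the roi tuples the test script builds. For illustration:

import torch

# each ROI row is (batch_index, xmin, ymin, xmax, ymax); the C++ entry points
# above bail out (return 0) unless rois has exactly 5 columns
rois = torch.tensor([[0., 10., 20., 110., 220.],
                     [0., 50., 60., 150., 260.]])
print(rois.shape)  # torch.Size([2, 5]) -> num_rois = 2, size_rois = 5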
rois_flat = THFloatTensor_data(rois); 61 | 62 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad); 63 | 64 | auto top_grad_flat = top_grad.data(); 65 | auto rois_flat = rois.data(); 66 | auto bottom_grad_flat = bottom_grad.data(); 67 | 68 | 69 | // Number of ROIs 70 | //int num_rois = THFloatTensor_size(rois, 0); 71 | //int size_rois = THFloatTensor_size(rois, 1); 72 | 73 | auto rois_sz = rois.sizes(); 74 | int num_rois = rois_sz[0]; 75 | int size_rois = rois_sz[1]; 76 | 77 | if (size_rois != 5) 78 | { 79 | return 0; 80 | } 81 | 82 | // batch size 83 | // int batch_size = THFloatTensor_size(bottom_grad, 0); 84 | // data height 85 | //int data_height = THFloatTensor_size(bottom_grad, 2); 86 | // data width 87 | //int data_width = THFloatTensor_size(bottom_grad, 3); 88 | // Number of channels 89 | //int num_channels = THFloatTensor_size(bottom_grad, 1); 90 | auto grad_sz = bottom_grad.sizes(); 91 | int data_height = grad_sz[2]; 92 | int data_width = grad_sz[3]; 93 | int num_channels = grad_sz[1]; 94 | 95 | // do ROIAlignBackward 96 | RODAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height, 97 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat); 98 | 99 | return 1; 100 | } 101 | 102 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 103 | const int height, const int width, const int channels, 104 | const int aligned_height, const int aligned_width, const float * bottom_rois, 105 | float* top_data) 106 | { 107 | const int output_size = num_rois * aligned_height * aligned_width * channels; 108 | 109 | int idx = 0; 110 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 111 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 112 | for (idx = 0; idx < output_size; ++idx) 113 | { 114 | // (n, c, ph, pw) is an element in the aligned output 115 | int pw = idx % aligned_width; 116 | int ph = (idx / aligned_width) % aligned_height; 117 | int c = (idx / aligned_width / aligned_height) % channels; 118 | int n = idx / aligned_width / aligned_height / channels; 119 | 120 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 121 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 122 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 123 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 124 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 125 | 126 | 127 | float h = (float)(ph) * bin_size_h; 128 | float w = (float)(pw) * bin_size_w; 129 | 130 | int hstart = fminf(floor(h), height - 2); 131 | int wstart = fminf(floor(w), width - 2); 132 | 133 | int img_start = roi_batch_ind * channels * height * width; 134 | 135 | // bilinear interpolation 136 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){ 137 | top_data[idx] = 0.; 138 | } else { 139 | float h_ratio = h - (float)(hstart); 140 | float w_ratio = w - (float)(wstart); 141 | int upleft = img_start + (c * height + hstart) * width + wstart; 142 | int upright = upleft + 1; 143 | int downleft = upleft + width; 144 | int downright = downleft + 1; 145 | 146 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 147 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 148 | + bottom_data[downleft] * h_ratio * (1. 
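The arithmetic in this CPU loop is plain bilinear interpolation: each output sample mixes the four neighboring pixels, weighted by its fractional offsets within the cell. The same computation in numpy (an illustrative reimplementation, not the extension's API):

import numpy as np

def bilinear(img, h, w):
    # sample img (H x W) at fractional position (h, w), as in RODAlignForwardCpu
    hs, ws = int(np.floor(h)), int(np.floor(w))
    hr, wr = h - hs, w - ws
    return (img[hs, ws] * (1 - hr) * (1 - wr)
            + img[hs, ws + 1] * (1 - hr) * wr
            + img[hs + 1, ws] * hr * (1 - wr)
            + img[hs + 1, ws + 1] * hr * wr)

img = np.arange(16, dtype=float).reshape(4, 4)
print(bilinear(img, 1.5, 2.5))  # 8.5, the average of the four surrounding pixels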
- w_ratio) 149 | + bottom_data[downright] * h_ratio * w_ratio; 150 | } 151 | } 152 | } 153 | 154 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 155 | const int height, const int width, const int channels, 156 | const int aligned_height, const int aligned_width, const float * bottom_rois, 157 | float* bottom_diff) 158 | { 159 | const int output_size = num_rois * aligned_height * aligned_width * channels; 160 | 161 | int idx = 0; 162 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 163 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 164 | for (idx = 0; idx < output_size; ++idx) 165 | { 166 | // (n, c, ph, pw) is an element in the aligned output 167 | int pw = idx % aligned_width; 168 | int ph = (idx / aligned_width) % aligned_height; 169 | int c = (idx / aligned_width / aligned_height) % channels; 170 | int n = idx / aligned_width / aligned_height / channels; 171 | 172 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 173 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 174 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 175 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 176 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 177 | 178 | float h = (float)(ph) * bin_size_h; 179 | float w = (float)(pw) * bin_size_w; 180 | 181 | int hstart = fminf(floor(h), height - 2); 182 | int wstart = fminf(floor(w), width - 2); 183 | 184 | int img_start = roi_batch_ind * channels * height * width; 185 | 186 | // bilinear interpolation 187 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) { 188 | float h_ratio = h - (float)(hstart); 189 | float w_ratio = w - (float)(wstart); 190 | int upleft = img_start + (c * height + hstart) * width + wstart; 191 | int upright = upleft + 1; 192 | int downleft = upleft + width; 193 | int downright = downleft + 1; 194 | 195 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio); 196 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio; 197 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. 
- w_ratio); 198 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio; 199 | } 200 | } 201 | } 202 | 203 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 204 | m.def("forward", &rod_align_forward, "rod_align forward"); 205 | m.def("backward", &rod_align_backward, "rod_align backward"); 206 | } 207 | -------------------------------------------------------------------------------- /rod_align/src/rod_align.h: -------------------------------------------------------------------------------- 1 | #ifndef ROD_ALIGN_H 2 | #define ROD_ALIGN_H 3 | 4 | #include 5 | 6 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "rod_align_kernel.h" 4 | #include "rod_align_cuda.h" 5 | 6 | 7 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 9 | { 10 | // Grab the input tensor 11 | //float * data_flat = THCudaTensor_data(state, features); 12 | //float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | //float * output_flat = THCudaTensor_data(state, output); 15 | 16 | auto data_flat = features.data(); 17 | auto rois_flat = rois.data(); 18 | auto output_flat = output.data(); 19 | 20 | // Number of ROIs 21 | //int num_rois = THCudaTensor_size(state, rois, 0); 22 | //int size_rois = THCudaTensor_size(state, rois, 1); 23 | 24 | auto rois_sz = rois.sizes(); 25 | int num_rois = rois_sz[0]; 26 | int size_rois = rois_sz[1]; 27 | 28 | if (size_rois != 5) 29 | { 30 | return 0; 31 | } 32 | 33 | // data height 34 | //int data_height = THCudaTensor_size(state, features, 2); 35 | // data width 36 | //int data_width = THCudaTensor_size(state, features, 3); 37 | // Number of channels 38 | //int num_channels = THCudaTensor_size(state, features, 1); 39 | auto feat_sz = features.sizes(); 40 | int data_height = feat_sz[2]; 41 | int data_width = feat_sz[3]; 42 | int num_channels = feat_sz[1]; 43 | 44 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 45 | 46 | RODAlignForwardLaucher( 47 | data_flat, spatial_scale, num_rois, data_height, 48 | data_width, num_channels, aligned_height, 49 | aligned_width, rois_flat, 50 | output_flat, stream); 51 | 52 | return 1; 53 | } 54 | 55 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 57 | { 58 | // Grab the input tensor 59 | //float * top_grad_flat = THCudaTensor_data(state, top_grad); 60 | //float * rois_flat = THCudaTensor_data(state, rois); 61 | 62 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 63 | auto top_grad_flat = top_grad.data(); 64 | auto rois_flat = rois.data(); 65 | auto bottom_grad_flat = bottom_grad.data(); 66 | 67 | // Number of ROIs 68 | //int num_rois = THCudaTensor_size(state, rois, 0); 69 | //int size_rois = THCudaTensor_size(state, rois, 1); 70 | auto rois_sz = rois.sizes(); 71 | int num_rois = rois_sz[0]; 72 | int size_rois = rois_sz[1]; 73 | if (size_rois != 5) 74 | { 75 | return 0; 76 | } 77 | 
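The backward pass scatters each output gradient back onto the four input pixels it was interpolated from, accumulating with += on the CPU and with atomicAdd in the CUDA kernel, since multiple outputs (and threads) can touch the same input cell. numpy's unbuffered equivalent of that scatter-accumulate:

import numpy as np

grad = np.zeros(6)
idx = np.array([0, 1, 1, 4])   # several outputs touch the same input cell
vals = np.array([0.5, 1.0, 2.0, 3.0])
np.add.at(grad, idx, vals)     # unbuffered +=, analogous to atomicAdd in the kernel
print(grad)                    # [0.5 3.  0.  0.  3.  0. ]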
78 | // batch size 79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0); 80 | // data height 81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2); 82 | // data width 83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3); 84 | // Number of channels 85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1); 86 | 87 | auto grad_sz = bottom_grad.sizes(); 88 | int batch_size = grad_sz[0]; 89 | int data_height = grad_sz[2]; 90 | int data_width = grad_sz[3]; 91 | int num_channels = grad_sz[1]; 92 | 93 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 94 | RODAlignBackwardLaucher( 95 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 96 | data_width, num_channels, aligned_height, 97 | aligned_width, rois_flat, 98 | bottom_grad_flat, stream); 99 | 100 | return 1; 101 | } 102 | 103 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 104 | m.def("forward", &rod_align_forward_cuda, "rod_align forward"); 105 | m.def("backward", &rod_align_backward_cuda, "rod_align backward"); 106 | } 107 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef ROD_ALIGN_CUDA_H 2 | #define ROD_ALIGN_CUDA_H 3 | 4 | #include 5 | 6 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "rod_align_kernel.h" 3 | 4 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 5 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 6 | i += blockDim.x * gridDim.x) 7 | 8 | 9 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width, 10 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) { 11 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 12 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 13 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 14 | // (n, c, ph, pw) is an element in the aligned output 15 | // int n = index; 16 | // int pw = n % aligned_width; 17 | // n /= aligned_width; 18 | // int ph = n % aligned_height; 19 | // n /= aligned_height; 20 | // int c = n % channels; 21 | // n /= channels; 22 | 23 | int pw = index % aligned_width; 24 | int ph = (index / aligned_width) % aligned_height; 25 | int c = (index / aligned_width / aligned_height) % channels; 26 | int n = index / aligned_width / aligned_height / channels; 27 | 28 | // bottom_rois += n * 5; 29 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 34 | 35 | 36 | float h = (float)(ph) * bin_size_h; 37 | float w = (float)(pw) * bin_size_w; 38 | 39 | int hstart = fminf(floor(h), height - 2); 40 | int wstart = fminf(floor(w), width - 
2); 41 | 42 | int img_start = roi_batch_ind * channels * height * width; 43 | 44 | // bilinear interpolation 45 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){ 46 | top_data[index] = 0.; 47 | } else { 48 | float h_ratio = h - (float)(hstart); 49 | float w_ratio = w - (float)(wstart); 50 | int upleft = img_start + (c * height + hstart) * width + wstart; 51 | int upright = upleft + 1; 52 | int downleft = upleft + width; 53 | int downright = downleft + 1; 54 | 55 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 56 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 57 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 58 | + bottom_data[downright] * h_ratio * w_ratio; 59 | } 60 | } 61 | } 62 | 63 | 64 | int RODAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width, 65 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) { 66 | const int kThreadsPerBlock = 1024; 67 | const int output_size = num_rois * aligned_height * aligned_width * channels; 68 | cudaError_t err; 69 | 70 | 71 | RODAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 72 | output_size, bottom_data, spatial_scale, height, width, channels, 73 | aligned_height, aligned_width, bottom_rois, top_data); 74 | 75 | err = cudaGetLastError(); 76 | if(cudaSuccess != err) { 77 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 78 | exit( -1 ); 79 | } 80 | 81 | return 1; 82 | } 83 | 84 | 85 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width, 86 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) { 87 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 88 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 89 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 90 | 91 | // (n, c, ph, pw) is an element in the aligned output 92 | int pw = index % aligned_width; 93 | int ph = (index / aligned_width) % aligned_height; 94 | int c = (index / aligned_width / aligned_height) % channels; 95 | int n = index / aligned_width / aligned_height / channels; 96 | 97 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 98 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 99 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 100 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 101 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 102 | 103 | 104 | float h = (float)(ph) * bin_size_h; 105 | float w = (float)(pw) * bin_size_w; 106 | 107 | int hstart = fminf(floor(h), height - 2); 108 | int wstart = fminf(floor(w), width - 2); 109 | 110 | int img_start = roi_batch_ind * channels * height * width; 111 | 112 | // bilinear interpolation 113 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) { 114 | float h_ratio = h - (float)(hstart); 115 | float w_ratio = w - (float)(wstart); 116 | int upleft = img_start + (c * height + hstart) * width + wstart; 117 | int upright = upleft + 1; 118 | int downleft = upleft + width; 119 | int downright = downleft + 1; 120 | 121 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio)); 122 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. 
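Both launchers size the grid as ceil(output_size / kThreadsPerBlock), and the CUDA_1D_KERNEL_LOOP grid-stride loop lets that fixed grid cover any element count. The block arithmetic, spelled out in Python (the sizes are illustrative):

# launch configuration used by both launchers: ceil-divide the flat output size
kThreadsPerBlock = 1024
output_size = 2 * 256 * 9 * 9   # num_rois * channels * aligned_height * aligned_width
blocks = (output_size + kThreadsPerBlock - 1) // kThreadsPerBlock
print(blocks)                   # 41 blocks of 1024 threads cover 41472 elements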
- h_ratio) * w_ratio); 123 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio)); 124 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio); 125 | } 126 | } 127 | } 128 | 129 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width, 130 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) { 131 | const int kThreadsPerBlock = 1024; 132 | const int output_size = num_rois * aligned_height * aligned_width * channels; 133 | cudaError_t err; 134 | 135 | RODAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 136 | output_size, top_diff, spatial_scale, height, width, channels, 137 | aligned_height, aligned_width, bottom_diff, bottom_rois); 138 | 139 | err = cudaGetLastError(); 140 | if(cudaSuccess != err) { 141 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 142 | exit( -1 ); 143 | } 144 | 145 | return 1; 146 | } 147 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROD_ALIGN_KERNEL 2 | #define _ROD_ALIGN_KERNEL 3 | 4 | #include 5 | 6 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, 7 | const float spatial_scale, const int height, const int width, 8 | const int channels, const int aligned_height, const int aligned_width, 9 | const float* bottom_rois, float* top_data); 10 | 11 | int RODAlignForwardLaucher( 12 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 13 | const int width, const int channels, const int aligned_height, 14 | const int aligned_width, const float* bottom_rois, 15 | float* top_data, cudaStream_t stream); 16 | 17 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff, 18 | const float spatial_scale, const int height, const int width, 19 | const int channels, const int aligned_height, const int aligned_width, 20 | float* bottom_diff, const float* bottom_rois); 21 | 22 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 23 | const int height, const int width, const int channels, const int aligned_height, 24 | const int aligned_width, const float* bottom_rois, 25 | float* bottom_diff, cudaStream_t stream); 26 | 27 | #endif 28 | 29 | -------------------------------------------------------------------------------- /roi_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/roi_align/__init__.py -------------------------------------------------------------------------------- /roi_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/roi_align/functions/__init__.py -------------------------------------------------------------------------------- /roi_align/functions/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import roi_align_api 4 | 5 | class 
RoIAlignFunction(Function): 6 | @staticmethod 7 | def forward(ctx, features, rois, aligned_height, aligned_width, spatial_scale): 8 | batch_size, num_channels, data_height, data_width = features.size() 9 | ctx.save_for_backward(rois, 10 | torch.IntTensor([int(batch_size), 11 | int(num_channels), 12 | int(data_height), 13 | int(data_width), 14 | int(aligned_height), 15 | int(aligned_width)]), 16 | torch.FloatTensor([float(spatial_scale)])) 17 | 18 | num_rois = rois.size(0) 19 | 20 | output = features.new(num_rois, 21 | num_channels, 22 | int(aligned_height), 23 | int(aligned_width)).zero_() 24 | 25 | roi_align_api.forward(int(aligned_height), 26 | int(aligned_width), 27 | float(spatial_scale), 28 | features, 29 | rois, 30 | output) 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | rois, core_size, scale = ctx.saved_tensors 36 | 37 | batch_size, num_channels, data_height, data_width, aligned_height, aligned_width = core_size 38 | spatial_scale = scale[0] 39 | 40 | grad_input = rois.new(batch_size, 41 | num_channels, 42 | data_height, 43 | data_width).zero_() 44 | 45 | roi_align_api.backward(int(aligned_height), 46 | int(aligned_width), 47 | float(spatial_scale), 48 | grad_output, 49 | rois, 50 | grad_input) 51 | 52 | return grad_input, None, None, None, None 53 | -------------------------------------------------------------------------------- /roi_align/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling roi_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_75 is compatible with the following NV GPU cards, 7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070 Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000 Tesla T4 8 | # See more https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75 10 | 11 | cd ../ 12 | # Export CUDA_HOME. Build and install the library. 
13 | export CUDA_HOME=/usr/local/cuda-11.1 && python3 setup.py install 14 | -------------------------------------------------------------------------------- /roi_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/roi_align/modules/__init__.py -------------------------------------------------------------------------------- /roi_align/modules/roi_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.nn.functional import avg_pool2d, max_pool2d 3 | from ..functions.roi_align import RoIAlignFunction 4 | 5 | 6 | class RoIAlign(Module): 7 | def __init__(self, aligned_height, aligned_width, spatial_scale): 8 | super(RoIAlign, self).__init__() 9 | 10 | self.aligned_width = int(aligned_width) 11 | self.aligned_height = int(aligned_height) 12 | self.spatial_scale = float(spatial_scale) 13 | 14 | def forward(self, features, rois): 15 | return RoIAlignFunction.apply(features, 16 | rois, 17 | self.aligned_height, 18 | self.aligned_width, 19 | self.spatial_scale) 20 | 21 | class RoIAlignAvg(Module): 22 | def __init__(self, aligned_height, aligned_width, spatial_scale): 23 | super(RoIAlignAvg, self).__init__() 24 | 25 | self.aligned_width = int(aligned_width) 26 | self.aligned_height = int(aligned_height) 27 | self.spatial_scale = float(spatial_scale) 28 | 29 | def forward(self, features, rois): 30 | x = RoIAlignFunction.apply(features, 31 | rois, 32 | self.aligned_height+1, 33 | self.aligned_width+1, 34 | self.spatial_scale) 35 | return avg_pool2d(x, kernel_size=2, stride=1) 36 | 37 | class RoIAlignMax(Module): 38 | def __init__(self, aligned_height, aligned_width, spatial_scale): 39 | super(RoIAlignMax, self).__init__() 40 | 41 | self.aligned_width = int(aligned_width) 42 | self.aligned_height = int(aligned_height) 43 | self.spatial_scale = float(spatial_scale) 44 | 45 | def forward(self, features, rois): 46 | x = RoIAlignFunction.apply(features, 47 | rois, 48 | self.aligned_height+1, 49 | self.aligned_width+1, 50 | self.spatial_scale) 51 | return max_pool2d(x, kernel_size=2, stride=1) 52 | -------------------------------------------------------------------------------- /roi_align/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from pkg_resources import parse_version 5 | 6 | min_version = parse_version('1.0.0') 7 | current_version = parse_version(torch.__version__) 8 | 9 | 10 | if current_version < min_version: #PyTorch before 1.0 11 | from torch.utils.ffi import create_extension 12 | 13 | sources = ['src/roi_align.c'] 14 | headers = ['src/roi_align.h'] 15 | extra_objects = [] 16 | #sources = [] 17 | #headers = [] 18 | defines = [] 19 | with_cuda = False 20 | 21 | this_file = os.path.dirname(os.path.realpath(__file__)) 22 | print(this_file) 23 | 24 | if torch.cuda.is_available(): 25 | print('Including CUDA code.') 26 | sources += ['src/roi_align_cuda.c'] 27 | headers += ['src/roi_align_cuda.h'] 28 | defines += [('WITH_CUDA', None)] 29 | with_cuda = True 30 | 31 | extra_objects = ['src/roi_align_kernel.cu.o'] 32 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 33 | 34 | ffi = create_extension( 35 | '_ext.roi_align', 36 | headers=headers, 37 | sources=sources, 38 | define_macros=defines, 39 | relative_to=__file__, 
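RoIAlignAvg above (like RoDAlignAvg earlier) requests an (H+1, W+1) aligned output and then average-pools it with a 2x2 window at stride 1, which lands back on exactly (H, W) while blending each sample with its neighbors. The shape bookkeeping in isolation (dimensions are illustrative):

import torch
import torch.nn.functional as F

# stand-in for RoIAlignFunction output aligned to 9+1 = 10 on each side
x = torch.randn(2, 8, 10, 10)
y = F.avg_pool2d(x, kernel_size=2, stride=1)
print(y.shape)  # torch.Size([2, 8, 9, 9]): back to the requested aligned size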
40 | with_cuda=with_cuda,
41 | extra_objects=extra_objects
42 | )
43 | 
44 | if __name__ == '__main__':
45 | ffi.build()
46 | else: # PyTorch 1.0 or later
47 | from setuptools import setup
48 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
49 | 
50 | print('Including CUDA code.')
51 | current_dir = os.path.dirname(os.path.realpath(__file__))
52 | #cuda_include = '/usr/local/cuda-10.0/include'
53 | 
54 | # GPU version
55 | setup(
56 | name='roi_align_api',
57 | ext_modules=[
58 | CUDAExtension(
59 | name='roi_align_api',
60 | sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu'],
61 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True)
62 | )
63 | ],
64 | cmdclass={
65 | 'build_ext': BuildExtension
66 | })
67 | 
-------------------------------------------------------------------------------- /roi_align/src/roi_align.cpp: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <math.h>
3 | #include <stdio.h>
4 | #include "roi_align.h"
5 | 
6 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
7 | const int height, const int width, const int channels,
8 | const int aligned_height, const int aligned_width, const float * bottom_rois,
9 | float* top_data);
10 | 
11 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
12 | const int height, const int width, const int channels,
13 | const int aligned_height, const int aligned_width, const float * bottom_rois,
14 | float* bottom_diff);
15 | 
16 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale,
17 | torch::Tensor features, torch::Tensor rois, torch::Tensor output)
18 | {
19 | // Grab the input tensors as raw float pointers
20 | //float * data_flat = THFloatTensor_data(features);
21 | //float * rois_flat = THFloatTensor_data(rois);
22 | auto data_flat = features.data<float>();
23 | auto rois_flat = rois.data<float>();
24 | 
25 | //float * output_flat = THFloatTensor_data(output);
26 | auto output_flat = output.data<float>();
27 | 
28 | // Number of ROIs
29 | //int num_rois = THFloatTensor_size(rois, 0);
30 | //int size_rois = THFloatTensor_size(rois, 1);
31 | 
32 | auto rois_sz = rois.sizes();
33 | int num_rois = rois_sz[0];
34 | int size_rois = rois_sz[1];
35 | 
36 | if (size_rois != 5)
37 | {
38 | return 0;
39 | }
40 | 
41 | // data height
42 | //int data_height = THFloatTensor_size(features, 2);
43 | // data width
44 | //int data_width = THFloatTensor_size(features, 3);
45 | // Number of channels
46 | //int num_channels = THFloatTensor_size(features, 1);
47 | auto feat_sz = features.sizes();
48 | int data_height = feat_sz[2];
49 | int data_width = feat_sz[3];
50 | int num_channels = feat_sz[1];
51 | 
52 | // do ROIAlignForward
53 | ROIAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels,
54 | aligned_height, aligned_width, rois_flat, output_flat);
55 | 
56 | return 1;
57 | }
58 | 
59 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale,
60 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad)
61 | {
62 | // Grab the input tensors as raw float pointers
63 | //float * top_grad_flat = THFloatTensor_data(top_grad);
64 | //float * rois_flat = THFloatTensor_data(rois);
65 | 
66 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad);
67 | auto top_grad_flat = top_grad.data<float>();
68 | auto rois_flat = rois.data<float>();
69 | auto bottom_grad_flat = bottom_grad.data<float>();
70 | 
71 | // Number of ROIs
72 | //int num_rois = THFloatTensor_size(rois, 0);
73
| //int size_rois = THFloatTensor_size(rois, 1); 74 | auto rois_sz = rois.sizes(); 75 | int num_rois = rois_sz[0]; 76 | int size_rois = rois_sz[1]; 77 | if (size_rois != 5) 78 | { 79 | return 0; 80 | } 81 | 82 | // batch size 83 | // int batch_size = THFloatTensor_size(bottom_grad, 0); 84 | // data height 85 | //int data_height = THFloatTensor_size(bottom_grad, 2); 86 | // data width 87 | //int data_width = THFloatTensor_size(bottom_grad, 3); 88 | // Number of channels 89 | //int num_channels = THFloatTensor_size(bottom_grad, 1); 90 | 91 | auto grad_sz = bottom_grad.sizes(); 92 | int data_height = grad_sz[2]; 93 | int data_width = grad_sz[3]; 94 | int num_channels = grad_sz[1]; 95 | 96 | // do ROIAlignBackward 97 | ROIAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height, 98 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat); 99 | 100 | return 1; 101 | } 102 | 103 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 104 | const int height, const int width, const int channels, 105 | const int aligned_height, const int aligned_width, const float * bottom_rois, 106 | float* top_data) 107 | { 108 | const int output_size = num_rois * aligned_height * aligned_width * channels; 109 | 110 | int idx = 0; 111 | for (idx = 0; idx < output_size; ++idx) 112 | { 113 | // (n, c, ph, pw) is an element in the aligned output 114 | int pw = idx % aligned_width; 115 | int ph = (idx / aligned_width) % aligned_height; 116 | int c = (idx / aligned_width / aligned_height) % channels; 117 | int n = idx / aligned_width / aligned_height / channels; 118 | 119 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 120 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 121 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 122 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 123 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 124 | 125 | // Force malformed ROI to be 1x1 126 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 127 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 128 | float bin_size_h = roi_height / (aligned_height - 1.); 129 | float bin_size_w = roi_width / (aligned_width - 1.); 130 | 131 | float h = (float)(ph) * bin_size_h + roi_start_h; 132 | float w = (float)(pw) * bin_size_w + roi_start_w; 133 | 134 | int hstart = fminf(floor(h), height - 2); 135 | int wstart = fminf(floor(w), width - 2); 136 | 137 | int img_start = roi_batch_ind * channels * height * width; 138 | 139 | // bilinear interpolation 140 | if (h < 0 || h >= height || w < 0 || w >= width) 141 | { 142 | top_data[idx] = 0.; 143 | } 144 | else 145 | { 146 | float h_ratio = h - (float)(hstart); 147 | float w_ratio = w - (float)(wstart); 148 | int upleft = img_start + (c * height + hstart) * width + wstart; 149 | int upright = upleft + 1; 150 | int downleft = upleft + width; 151 | int downright = downleft + 1; 152 | 153 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 154 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 155 | + bottom_data[downleft] * h_ratio * (1. 
- w_ratio)
156 | + bottom_data[downright] * h_ratio * w_ratio;
157 | }
158 | }
159 | }
160 | 
161 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
162 | const int height, const int width, const int channels,
163 | const int aligned_height, const int aligned_width, const float * bottom_rois,
164 | float* bottom_diff)
165 | {
166 | const int output_size = num_rois * aligned_height * aligned_width * channels;
167 | 
168 | int idx = 0;
169 | for (idx = 0; idx < output_size; ++idx)
170 | {
171 | // (n, c, ph, pw) is an element in the aligned output
172 | int pw = idx % aligned_width;
173 | int ph = (idx / aligned_width) % aligned_height;
174 | int c = (idx / aligned_width / aligned_height) % channels;
175 | int n = idx / aligned_width / aligned_height / channels;
176 | 
177 | float roi_batch_ind = bottom_rois[n * 5 + 0];
178 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
179 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
180 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
181 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
182 | 
183 | // Force malformed ROI to be 1x1
184 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
185 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
186 | float bin_size_h = roi_height / (aligned_height - 1.);
187 | float bin_size_w = roi_width / (aligned_width - 1.);
188 | 
189 | float h = (float)(ph) * bin_size_h + roi_start_h;
190 | float w = (float)(pw) * bin_size_w + roi_start_w;
191 | 
192 | int hstart = fminf(floor(h), height - 2);
193 | int wstart = fminf(floor(w), width - 2);
194 | 
195 | int img_start = roi_batch_ind * channels * height * width;
196 | 
197 | // bilinear interpolation: only in-range samples receive gradient (matches the CUDA kernel)
198 | if (!(h < 0 || h >= height || w < 0 || w >= width))
199 | {
200 | float h_ratio = h - (float)(hstart);
201 | float w_ratio = w - (float)(wstart);
202 | int upleft = img_start + (c * height + hstart) * width + wstart;
203 | int upright = upleft + 1;
204 | int downleft = upleft + width;
205 | int downright = downleft + 1;
206 | 
207 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio);
208 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio;
209 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1.
- w_ratio);
210 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio;
211 | }
212 | }
213 | }
214 | 
215 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
216 | m.def("forward", &roi_align_forward, "roi_align forward");
217 | m.def("backward", &roi_align_backward, "roi_align backward");
218 | }
219 | 
-------------------------------------------------------------------------------- /roi_align/src/roi_align.h: --------------------------------------------------------------------------------
1 | #ifndef ROI_ALIGN_H
2 | #define ROI_ALIGN_H
3 | 
4 | #include <torch/extension.h>
5 | 
6 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale,
7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output);
8 | 
9 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale,
10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad);
11 | 
12 | #endif
13 | 
-------------------------------------------------------------------------------- /roi_align/src/roi_align_cuda.cpp: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <cuda_runtime.h>
4 | #include "roi_align_kernel.h"
5 | 
6 | 
7 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale,
8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output)
9 | {
10 | // Grab the input tensors as raw float pointers
11 | //float * data_flat = THCudaTensor_data(state, features);
12 | //float * rois_flat = THCudaTensor_data(state, rois);
13 | 
14 | //float * output_flat = THCudaTensor_data(state, output);
15 | 
16 | auto data_flat = features.data<float>();
17 | auto rois_flat = rois.data<float>();
18 | auto output_flat = output.data<float>();
19 | 
20 | // Number of ROIs
21 | //int num_rois = THCudaTensor_size(state, rois, 0);
22 | //int size_rois = THCudaTensor_size(state, rois, 1);
23 | auto rois_sz = rois.sizes();
24 | int num_rois = rois_sz[0];
25 | int size_rois = rois_sz[1];
26 | if (size_rois != 5)
27 | {
28 | return 0;
29 | }
30 | 
31 | // data height
32 | //int data_height = THCudaTensor_size(state, features, 2);
33 | // data width
34 | //int data_width = THCudaTensor_size(state, features, 3);
35 | // Number of channels
36 | //int num_channels = THCudaTensor_size(state, features, 1);
37 | auto feat_sz = features.sizes();
38 | int data_height = feat_sz[2];
39 | int data_width = feat_sz[3];
40 | int num_channels = feat_sz[1];
41 | 
42 | 
43 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
44 | 
45 | ROIAlignForwardLaucher(
46 | data_flat, spatial_scale, num_rois, data_height,
47 | data_width, num_channels, aligned_height,
48 | aligned_width, rois_flat,
49 | output_flat, stream);
50 | 
51 | return 1;
52 | }
53 | 
54 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale,
55 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad)
56 | {
57 | // Grab the input tensors as raw float pointers
58 | //float * top_grad_flat = THCudaTensor_data(state, top_grad);
59 | //float * rois_flat = THCudaTensor_data(state, rois);
60 | 
61 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad);
62 | auto top_grad_flat = top_grad.data<float>();
63 | auto rois_flat = rois.data<float>();
64 | auto bottom_grad_flat = bottom_grad.data<float>();
65 | 
66 | // Number of ROIs
67 | //int num_rois = THCudaTensor_size(state, rois, 0);
68 | //int size_rois = THCudaTensor_size(state, rois, 1);
69 | auto rois_sz = rois.sizes();
70 | int num_rois = rois_sz[0];
71 | int size_rois = rois_sz[1];
72 | 
73 | if (size_rois != 5)
74 | {
75 | return 0;
76 | }
77 | 
78 | // batch size
79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0);
80 | // data height
81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2);
82 | // data width
83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3);
84 | // Number of channels
85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1);
86 | auto grad_sz = bottom_grad.sizes();
87 | int batch_size = grad_sz[0];
88 | int data_height = grad_sz[2];
89 | int data_width = grad_sz[3];
90 | int num_channels = grad_sz[1];
91 | 
92 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
93 | ROIAlignBackwardLaucher(
94 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height,
95 | data_width, num_channels, aligned_height,
96 | aligned_width, rois_flat,
97 | bottom_grad_flat, stream);
98 | 
99 | return 1;
100 | }
101 | 
102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
103 | m.def("forward", &roi_align_forward_cuda, "roi_align forward");
104 | m.def("backward", &roi_align_backward_cuda, "roi_align backward");
105 | }
106 | 
-------------------------------------------------------------------------------- /roi_align/src/roi_align_cuda.h: --------------------------------------------------------------------------------
1 | #ifndef ROI_ALIGN_CUDA_H
2 | #define ROI_ALIGN_CUDA_H
3 | 
4 | #include <torch/extension.h>
5 | 
6 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale,
7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output);
8 | 
9 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale,
10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad);
11 | 
12 | #endif
13 | 
-------------------------------------------------------------------------------- /roi_align/src/roi_align_kernel.cu: --------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <math.h>
3 | #include <float.h>
4 | #include "roi_align_kernel.h"
5 | 
6 | #define CUDA_1D_KERNEL_LOOP(i, n) \
7 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
8 | i += blockDim.x * gridDim.x)
9 | 
10 | 
11 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width,
12 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) {
13 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
14 | // (n, c, ph, pw) is an element in the aligned output
15 | // int n = index;
16 | // int pw = n % aligned_width;
17 | // n /= aligned_width;
18 | // int ph = n % aligned_height;
19 | // n /= aligned_height;
20 | // int c = n % channels;
21 | // n /= channels;
22 | 
23 | int pw = index % aligned_width;
24 | int ph = (index / aligned_width) % aligned_height;
25 | int c = (index / aligned_width / aligned_height) % channels;
26 | int n = index / aligned_width / aligned_height / channels;
27 | 
28 | // bottom_rois += n * 5;
29 | float roi_batch_ind = bottom_rois[n * 5 + 0];
30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
34 | 
35 | // Force malformed ROIs to be 1x1
36 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
37 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
38 | float bin_size_h = roi_height / (aligned_height - 1.);
39 | float bin_size_w = roi_width / (aligned_width - 1.);
40 | 
41 | float h = (float)(ph) *
bin_size_h + roi_start_h; 42 | float w = (float)(pw) * bin_size_w + roi_start_w; 43 | 44 | int hstart = fminf(floor(h), height - 2); 45 | int wstart = fminf(floor(w), width - 2); 46 | 47 | int img_start = roi_batch_ind * channels * height * width; 48 | 49 | // bilinear interpolation 50 | if (h < 0 || h >= height || w < 0 || w >= width) { 51 | top_data[index] = 0.; 52 | } else { 53 | float h_ratio = h - (float)(hstart); 54 | float w_ratio = w - (float)(wstart); 55 | int upleft = img_start + (c * height + hstart) * width + wstart; 56 | int upright = upleft + 1; 57 | int downleft = upleft + width; 58 | int downright = downleft + 1; 59 | 60 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 61 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 62 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 63 | + bottom_data[downright] * h_ratio * w_ratio; 64 | } 65 | } 66 | } 67 | 68 | 69 | int ROIAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width, 70 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) { 71 | const int kThreadsPerBlock = 1024; 72 | const int output_size = num_rois * aligned_height * aligned_width * channels; 73 | cudaError_t err; 74 | 75 | 76 | ROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 77 | output_size, bottom_data, spatial_scale, height, width, channels, 78 | aligned_height, aligned_width, bottom_rois, top_data); 79 | 80 | err = cudaGetLastError(); 81 | if(cudaSuccess != err) { 82 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 83 | exit( -1 ); 84 | } 85 | 86 | return 1; 87 | } 88 | 89 | 90 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width, 91 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) { 92 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 93 | 94 | // (n, c, ph, pw) is an element in the aligned output 95 | int pw = index % aligned_width; 96 | int ph = (index / aligned_width) % aligned_height; 97 | int c = (index / aligned_width / aligned_height) % channels; 98 | int n = index / aligned_width / aligned_height / channels; 99 | 100 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 101 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 102 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 103 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 104 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 105 | /* int roi_start_w = round(bottom_rois[1] * spatial_scale); */ 106 | /* int roi_start_h = round(bottom_rois[2] * spatial_scale); */ 107 | /* int roi_end_w = round(bottom_rois[3] * spatial_scale); */ 108 | /* int roi_end_h = round(bottom_rois[4] * spatial_scale); */ 109 | 110 | // Force malformed ROIs to be 1x1 111 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 112 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 113 | float bin_size_h = roi_height / (aligned_height - 1.); 114 | float bin_size_w = roi_width / (aligned_width - 1.); 115 | 116 | float h = (float)(ph) * bin_size_h + roi_start_h; 117 | float w = (float)(pw) * bin_size_w + roi_start_w; 118 | 119 | int hstart = fminf(floor(h), height - 2); 120 | int wstart = fminf(floor(w), width - 2); 121 | 122 | int img_start = roi_batch_ind * 
channels * height * width; 123 | 124 | // bilinear interpolation 125 | if (!(h < 0 || h >= height || w < 0 || w >= width)) { 126 | float h_ratio = h - (float)(hstart); 127 | float w_ratio = w - (float)(wstart); 128 | int upleft = img_start + (c * height + hstart) * width + wstart; 129 | int upright = upleft + 1; 130 | int downleft = upleft + width; 131 | int downright = downleft + 1; 132 | 133 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio)); 134 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. - h_ratio) * w_ratio); 135 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio)); 136 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio); 137 | } 138 | } 139 | } 140 | 141 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width, 142 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) { 143 | const int kThreadsPerBlock = 1024; 144 | const int output_size = num_rois * aligned_height * aligned_width * channels; 145 | cudaError_t err; 146 | 147 | ROIAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 148 | output_size, top_diff, spatial_scale, height, width, channels, 149 | aligned_height, aligned_width, bottom_diff, bottom_rois); 150 | 151 | err = cudaGetLastError(); 152 | if(cudaSuccess != err) { 153 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 154 | exit( -1 ); 155 | } 156 | 157 | return 1; 158 | } 159 | -------------------------------------------------------------------------------- /roi_align/src/roi_align_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/roi_align/src/roi_align_kernel.cu.o -------------------------------------------------------------------------------- /roi_align/src/roi_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROI_ALIGN_KERNEL 2 | #define _ROI_ALIGN_KERNEL 3 | 4 | 5 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, 6 | const float spatial_scale, const int height, const int width, 7 | const int channels, const int aligned_height, const int aligned_width, 8 | const float* bottom_rois, float* top_data); 9 | 10 | int ROIAlignForwardLaucher( 11 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 12 | const int width, const int channels, const int aligned_height, 13 | const int aligned_width, const float* bottom_rois, 14 | float* top_data, cudaStream_t stream); 15 | 16 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff, 17 | const float spatial_scale, const int height, const int width, 18 | const int channels, const int aligned_height, const int aligned_width, 19 | float* bottom_diff, const float* bottom_rois); 20 | 21 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 22 | const int height, const int width, const int channels, const int aligned_height, 23 | const int aligned_width, const float* bottom_rois, 24 | float* bottom_diff, cudaStream_t stream); 25 | 26 | #endif 27 | 28 | 
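Taken together, the sources above compile (via the setup.py shown earlier) into a single extension, roi_align_api, that the Python wrappers in roi_align/modules/roi_align.py call into. Two conventions matter: RoIs must be a (K, 5) tensor of [batch_index, x1, y1, x2, y2] in input-image coordinates (hence the size_rois != 5 guard), and spatial_scale maps those coordinates onto the feature map (1.0 / 2**downsample for the backbones in smtModel.py). A minimal forward-pass sketch, assuming the extension has been built and a CUDA device is available; the shapes and the 1/4 scale are illustrative only, not values from the repo:

import torch
from roi_align.modules.roi_align import RoIAlignAvg

# two feature maps, 8 channels, at 1/4 of the input resolution
features = torch.randn(2, 8, 64, 64).cuda()
# one RoI per row: [batch_index, x1, y1, x2, y2] in input-image coordinates
rois = torch.tensor([[0., 16., 16., 128., 128.],
                     [1., 0., 0., 200., 100.]]).cuda()

align = RoIAlignAvg(aligned_height=8, aligned_width=8, spatial_scale=1.0 / 4)
out = align(features, rois)
print(out.shape)  # torch.Size([2, 8, 8, 8])

RoIAlignAvg asks the kernel for an (aligned_height+1) x (aligned_width+1) grid and then applies a stride-1 avg_pool2d, so neighbouring output bins share their boundary samples; RoIAlignMax does the same with max_pool2d.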
-------------------------------------------------------------------------------- /smtDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import cv2 4 | import math 5 | import numpy as np 6 | from generate_candidates import gen_boxes_multi 7 | 8 | MOS_MEAN = 2.95 9 | MOS_STD = 0.8 10 | RGB_MEAN = (0.485, 0.456, 0.406) 11 | RGB_STD = (0.229, 0.224, 0.225) 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', 15 | '.JPG', 16 | '.jpeg', 17 | '.JPEG', 18 | '.png', 19 | '.PNG', 20 | '.ppm', 21 | '.PPM', 22 | '.bmp', 23 | '.BMP', 24 | ] 25 | 26 | 27 | def is_image_file(filename): 28 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 29 | 30 | 31 | def find_expand(image, annotation, exp_prop=5): 32 | anno_w = math.floor(abs(float(annotation[3] - annotation[1]))) 33 | anno_h = math.floor(abs(float(annotation[2] - annotation[0]))) 34 | lx = min(annotation[0], annotation[2]) 35 | rx = max(annotation[0], annotation[2]) 36 | ly = min(annotation[1], annotation[3]) 37 | ry = max(annotation[1], annotation[3]) 38 | new_lx = max(0, int(lx - exp_prop * anno_h)) 39 | new_ly = max(0, int(ly - exp_prop * anno_w)) 40 | new_rx = min(image.shape[0], int(rx + exp_prop * anno_h)) 41 | new_ry = min(image.shape[1], int(ry + exp_prop * anno_w)) 42 | 43 | new_image = image[new_lx:new_rx, new_ly:new_ry].copy() 44 | 45 | new_anno_lx = lx - new_lx 46 | new_anno_ly = ly - new_ly 47 | new_anno_rx = rx - new_lx 48 | new_anno_ry = ry - new_ly 49 | 50 | return new_image, [new_anno_lx, new_anno_ly, new_anno_rx, new_anno_ry, annotation[4]] 51 | 52 | 53 | class TransformFunctionTest_RoE(object): 54 | 55 | def __call__(self, image, image_size, image_fp, usr_slogan, font_fp, is_devi, visimp_model, proc_fa_dir, ratio_list, 56 | text_spacing, exp_prop, grid_num, sali_coef, max_text_area_coef, min_text_area_coef, min_font_size, 57 | max_font_size, font_inc_unit): 58 | 59 | visimp_pred_dir = proc_fa_dir + 'visimp_pred/' 60 | visimp_pred_dir_ovl = visimp_pred_dir 61 | box_dict, img_visimp = gen_boxes_multi(img_name=image_fp, 62 | visimp_pred_dir=visimp_pred_dir, 63 | visimp_pred_dir_ovl=visimp_pred_dir_ovl, 64 | visimp_model=visimp_model, 65 | usr_slogan=usr_slogan, 66 | font_fp=font_fp, 67 | base_dat_dir=proc_fa_dir, 68 | is_devi=is_devi, 69 | ratio_list=ratio_list, 70 | text_spacing=text_spacing, 71 | grid_num=grid_num, 72 | sali_coef=sali_coef, 73 | max_text_area_coef=max_text_area_coef, 74 | min_text_area_coef=min_text_area_coef, 75 | min_font_size=min_font_size, 76 | max_font_size=max_font_size, 77 | font_inc_unit=font_inc_unit) 78 | 79 | box_list = box_dict['new_anno_list'] 80 | len_box_list = len(box_list) 81 | bboxes = [] 82 | for ik in range(len_box_list): 83 | bboxes.append([ 84 | box_list[ik][0]['xl'], box_list[ik][0]['yl'], box_list[ik][0]['xr'], box_list[ik][0]['yr'], 85 | box_list[ik][0]['tl_cnt'] 86 | ]) 87 | 88 | len_bboxes = len(bboxes) 89 | transformed_bboxes = {} 90 | transformed_bboxes['xmin'] = [] 91 | transformed_bboxes['ymin'] = [] 92 | transformed_bboxes['xmax'] = [] 93 | transformed_bboxes['ymax'] = [] 94 | source_bboxes = list() 95 | resized_images = [] 96 | mx_w = 0 97 | mx_h = 0 98 | sto_image = np.array(image) 99 | 100 | for i in range(len_bboxes): 101 | image = np.array(sto_image) 102 | # exp_prop: expanding coefficient of the text region 103 | image, tmp_bbox = find_expand(image=image, annotation=bboxes[i], exp_prop=exp_prop) 104 | 105 | scale = float(image_size) / float(min(image.shape[:2])) 106 | h = 
round(image.shape[0] * scale / 32.0) * 32 107 | w = round(image.shape[1] * scale / 32.0) * 32 108 | resized_image = cv2.resize(image, (int(w), int(h))) / 256.0 109 | # img = cv2.resize(img, (crop_size, crop_size), interpolation = cv2.INTER_AREA) 110 | rgb_mean = np.array(RGB_MEAN, dtype=np.float32) 111 | rgb_std = np.array(RGB_STD, dtype=np.float32) 112 | resized_image = resized_image.astype(np.float32) 113 | resized_image -= rgb_mean 114 | resized_image = resized_image / rgb_std 115 | if (resized_image.shape[0] > mx_h): 116 | mx_h = resized_image.shape[0] 117 | if (resized_image.shape[1] > mx_w): 118 | mx_w = resized_image.shape[1] 119 | 120 | scale_height = image.shape[0] / float(resized_image.shape[0]) 121 | scale_width = image.shape[1] / float(resized_image.shape[1]) 122 | resized_images.append(resized_image) 123 | 124 | # source_bboxes.append([round(bbox[0] * scale_height),round(bbox[1] * scale_width),round(bbox[2] * scale_height),round(bbox[3] * scale_width)]) 125 | source_bboxes.append( 126 | [round(bboxes[i][0]), 127 | round(bboxes[i][1]), 128 | round(bboxes[i][2]), 129 | round(bboxes[i][3]), bboxes[i][4]]) 130 | transformed_bboxes['xmin'].append(tmp_bbox[1] / scale_width) 131 | transformed_bboxes['ymin'].append(tmp_bbox[0] / scale_height) 132 | transformed_bboxes['xmax'].append(tmp_bbox[3] / scale_width) 133 | transformed_bboxes['ymax'].append(tmp_bbox[2] / scale_height) 134 | 135 | len_resized_images = len(resized_images) 136 | for i in range(len_resized_images): 137 | r_itm = resized_images[i].copy() 138 | pre_h = r_itm.shape[0] 139 | pre_w = r_itm.shape[1] 140 | r_itm = np.pad(r_itm, ((0, mx_h - pre_h), (0, mx_w - pre_w), (0, 0)), 'constant') 141 | r_itm = r_itm.transpose((2, 0, 1)) 142 | resized_images[i] = r_itm.copy() 143 | 144 | return resized_images, transformed_bboxes, source_bboxes, box_list 145 | 146 | 147 | class TransformFunctionTest_RoD(object): 148 | 149 | def __call__(self, image, image_size, image_fp, usr_slogan, font_fp, is_devi, visimp_model, proc_fa_dir, ratio_list, 150 | text_spacing, grid_num, sali_coef, max_text_area_coef, min_text_area_coef, min_font_size, 151 | max_font_size, font_inc_unit): 152 | 153 | visimp_pred_dir = proc_fa_dir + 'visimp_pred/' 154 | visimp_pred_dir_ovl = visimp_pred_dir 155 | box_dict, img_visimp = gen_boxes_multi(img_name=image_fp, 156 | visimp_pred_dir=visimp_pred_dir, 157 | visimp_pred_dir_ovl=visimp_pred_dir_ovl, 158 | visimp_model=visimp_model, 159 | usr_slogan=usr_slogan, 160 | font_fp=font_fp, 161 | base_dat_dir=proc_fa_dir, 162 | is_devi=is_devi, 163 | ratio_list=ratio_list, 164 | text_spacing=text_spacing, 165 | grid_num=grid_num, 166 | sali_coef=sali_coef, 167 | max_text_area_coef=max_text_area_coef, 168 | min_text_area_coef=min_text_area_coef, 169 | min_font_size=min_font_size, 170 | max_font_size=max_font_size, 171 | font_inc_unit=font_inc_unit) 172 | 173 | box_list = box_dict['new_anno_list'] 174 | len_box_list = len(box_list) 175 | bboxes = [] 176 | for ik in range(len_box_list): 177 | bboxes.append([ 178 | box_list[ik][0]['xl'], box_list[ik][0]['yl'], box_list[ik][0]['xr'], box_list[ik][0]['yr'], 179 | box_list[ik][0]['tl_cnt'] 180 | ]) 181 | 182 | transformed_bbox = {} 183 | transformed_bbox['xmin'] = [] 184 | transformed_bbox['ymin'] = [] 185 | transformed_bbox['xmax'] = [] 186 | transformed_bbox['ymax'] = [] 187 | source_bboxes = list() 188 | 189 | scale = float(image_size) / float(min(image.shape[:2])) 190 | h = round(image.shape[0] * scale / 32.0) * 32 191 | w = round(image.shape[1] * scale / 32.0) * 32 192 | 
resized_image = cv2.resize(image, (int(w), int(h))) / 256.0 193 | rgb_mean = np.array(RGB_MEAN, dtype=np.float32) 194 | rgb_std = np.array(RGB_STD, dtype=np.float32) 195 | resized_image = resized_image.astype(np.float32) 196 | resized_image -= rgb_mean 197 | resized_image = resized_image / rgb_std 198 | 199 | scale_height = image.shape[0] / float(resized_image.shape[0]) 200 | scale_width = image.shape[1] / float(resized_image.shape[1]) 201 | 202 | for bbox in bboxes: 203 | # source_bboxes.append([round(bbox[0] * scale_height),round(bbox[1] * scale_width),round(bbox[2] * scale_height),round(bbox[3] * scale_width)]) 204 | source_bboxes.append([round(bbox[0]), round(bbox[1]), round(bbox[2]), round(bbox[3]), bbox[4]]) 205 | transformed_bbox['xmin'].append(bbox[1] / scale_width) 206 | transformed_bbox['ymin'].append(bbox[0] / scale_height) 207 | transformed_bbox['xmax'].append(bbox[3] / scale_width) 208 | transformed_bbox['ymax'].append(bbox[2] / scale_height) 209 | 210 | resized_image = resized_image.transpose((2, 0, 1)) 211 | return resized_image, transformed_bbox, source_bboxes, box_list 212 | 213 | 214 | class setup_test_dataset(data.Dataset): 215 | 216 | def __init__(self, 217 | usr_slogan, 218 | font_fp, 219 | visimp_model, 220 | proc_fa_dir, 221 | is_devi=False, 222 | image_size=256.0, 223 | dataset_dir='testsetDir', 224 | model_type='RoD', 225 | ratio_list=[1, 1, 1, 1, 1], 226 | text_spacing=20, 227 | exp_prop=5, 228 | grid_num=120, 229 | sali_coef=2.6, 230 | max_text_area_coef=17, 231 | min_text_area_coef=7, 232 | min_font_size=10, 233 | max_font_size=500, 234 | font_inc_unit=5): 235 | self.image_size = float(image_size) 236 | self.dataset_dir = dataset_dir 237 | image_lists = os.listdir(self.dataset_dir) 238 | self._imgpath = list() 239 | self._annopath = list() 240 | for image in image_lists: 241 | if (is_image_file(image)): 242 | self._imgpath.append(os.path.join(self.dataset_dir, image)) 243 | 244 | self.model_type = model_type 245 | if (self.model_type == 'RoE'): 246 | self.transform = TransformFunctionTest_RoE() 247 | else: 248 | self.transform = TransformFunctionTest_RoD() 249 | self.usr_slogan = usr_slogan 250 | self.font_fp = font_fp 251 | self.is_devi = is_devi 252 | self.visimp_model = visimp_model 253 | self.proc_fa_dir = proc_fa_dir 254 | self.ratio_list = ratio_list 255 | self.text_spacing = text_spacing 256 | self.exp_prop = exp_prop 257 | self.grid_num = grid_num 258 | self.sali_coef = sali_coef 259 | self.max_text_area_coef = max_text_area_coef 260 | self.min_text_area_coef = min_text_area_coef 261 | self.min_font_size = min_font_size 262 | self.max_font_size = max_font_size 263 | self.font_inc_unit = font_inc_unit 264 | 265 | def __getitem__(self, idx): 266 | image = cv2.imread(self._imgpath[idx]) 267 | # to rgb 268 | image = image[:, :, (2, 1, 0)] 269 | 270 | if (self.model_type == 'RoE'): 271 | if self.transform: 272 | resized_images, transformed_bboxes, source_bboxes, box_list = self.transform( 273 | image, self.image_size, self._imgpath[idx], self.usr_slogan, self.font_fp, self.is_devi, 274 | self.visimp_model, self.proc_fa_dir, self.ratio_list, self.text_spacing, self.exp_prop, 275 | self.grid_num, self.sali_coef, self.max_text_area_coef, self.min_text_area_coef, self.min_font_size, 276 | self.max_font_size, self.font_inc_unit) 277 | 278 | sample = { 279 | 'imgpath': self._imgpath[idx], 280 | 'image': image, 281 | 'resized_images': resized_images, 282 | 'tbboxes': transformed_bboxes, 283 | 'sourceboxes': source_bboxes, 284 | 'box_list': box_list 285 | } 286 | 
else: 287 | if self.transform: 288 | resized_image, transformed_bbox, source_bboxes, box_list = self.transform( 289 | image, self.image_size, self._imgpath[idx], self.usr_slogan, self.font_fp, self.is_devi, 290 | self.visimp_model, self.proc_fa_dir, self.ratio_list, self.text_spacing, self.grid_num, 291 | self.sali_coef, self.max_text_area_coef, self.min_text_area_coef, self.min_font_size, 292 | self.max_font_size, self.font_inc_unit) 293 | 294 | sample = { 295 | 'imgpath': self._imgpath[idx], 296 | 'image': image, 297 | 'resized_images': resized_image, 298 | 'tbboxes': transformed_bbox, 299 | 'sourceboxes': source_bboxes, 300 | 'box_list': box_list 301 | } 302 | 303 | return sample 304 | 305 | def __len__(self): 306 | return len(self._imgpath) 307 | -------------------------------------------------------------------------------- /smtModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from roi_align.modules.roi_align import RoIAlignAvg, RoIAlign 5 | from rod_align.modules.rod_align import RoDAlignAvg, RoDAlign 6 | import torch.nn.init as init 7 | from ShuffleNetV2 import shufflenetv2 8 | from mobilenetv2 import MobileNetV2 9 | 10 | 11 | class vgg_base(nn.Module): 12 | 13 | def __init__(self, loadweights=True, downsample=4): 14 | super(vgg_base, self).__init__() 15 | 16 | vgg = models.vgg16(pretrained=True) 17 | 18 | if downsample == 4: 19 | self.feature = nn.Sequential(vgg.features[:-1]) 20 | elif downsample == 5: 21 | self.feature = nn.Sequential(vgg.features) 22 | 23 | self.feature3 = nn.Sequential(vgg.features[:23]) 24 | self.feature4 = nn.Sequential(vgg.features[23:30]) 25 | self.feature5 = nn.Sequential(vgg.features[30:]) 26 | 27 | def forward(self, x): 28 | f3 = self.feature3(x) 29 | f4 = self.feature4(f3) 30 | f5 = self.feature5(f4) 31 | return f3, f4, f5 32 | 33 | 34 | class resnet50_base(nn.Module): 35 | 36 | def __init__(self, loadweights=True, downsample=4): 37 | super(resnet50_base, self).__init__() 38 | 39 | resnet50 = models.resnet50(pretrained=True) 40 | 41 | self.feature3 = nn.Sequential(resnet50.conv1, resnet50.bn1, resnet50.relu, resnet50.maxpool, resnet50.layer1, 42 | resnet50.layer2) 43 | self.feature4 = nn.Sequential(resnet50.layer3) 44 | self.feature5 = nn.Sequential(resnet50.layer4) 45 | 46 | def forward(self, x): 47 | f3 = self.feature3(x) 48 | f4 = self.feature4(f3) 49 | f5 = self.feature5(f4) 50 | return f3, f4, f5 51 | 52 | 53 | class mobilenetv2_base(nn.Module): 54 | 55 | def __init__(self, loadweights=True, downsample=4, model_path='pretrained_model/mobilenetv2_1.0-0c6065bc.pth'): 56 | super(mobilenetv2_base, self).__init__() 57 | 58 | model = MobileNetV2(width_mult=1.0) 59 | 60 | if loadweights: 61 | model.load_state_dict(torch.load(model_path)) 62 | 63 | self.feature3 = nn.Sequential(model.features[:7]) 64 | self.feature4 = nn.Sequential(model.features[7:14]) 65 | self.feature5 = nn.Sequential(model.features[14:]) 66 | 67 | def forward(self, x): 68 | f3 = self.feature3(x) 69 | f4 = self.feature4(f3) 70 | f5 = self.feature5(f4) 71 | return f3, f4, f5 72 | 73 | 74 | class shufflenetv2_base(nn.Module): 75 | 76 | def __init__(self, 77 | loadweights=True, 78 | downsample=4, 79 | model_path='pretrained_model/shufflenetv2_x1_69.402_88.374.pth.tar'): 80 | super(shufflenetv2_base, self).__init__() 81 | 82 | model = shufflenetv2(width_mult=1.0) 83 | 84 | if loadweights: 85 | model.load_state_dict(torch.load(model_path)) 86 | 87 | self.feature3 = 
nn.Sequential(model.conv1, model.maxpool, model.features[:4]) 88 | self.feature4 = nn.Sequential(model.features[4:12]) 89 | self.feature5 = nn.Sequential(model.features[12:]) 90 | 91 | def forward(self, x): 92 | f3 = self.feature3(x) 93 | f4 = self.feature4(f3) 94 | f5 = self.feature5(f4) 95 | return f3, f4, f5 96 | 97 | 98 | def fc_layers(reddim=32, alignsize=8): 99 | conv1 = nn.Sequential(nn.Conv2d(reddim, 768, kernel_size=alignsize, padding=0), nn.BatchNorm2d(768), 100 | nn.ReLU(inplace=True)) 101 | conv2 = nn.Sequential(nn.Conv2d(768, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True)) 102 | dropout = nn.Dropout(p=0.5) 103 | conv3 = nn.Conv2d(128, 1, kernel_size=1) 104 | layers = nn.Sequential(conv1, conv2, dropout, conv3) 105 | return layers 106 | 107 | 108 | class smt_model_single_scale(nn.Module): 109 | 110 | def __init__(self, alignsize=8, reddim=8, loadweight=True, model=None, downsample=4): 111 | super(smt_model_single_scale, self).__init__() 112 | 113 | if model == 'shufflenetv2': 114 | self.Feat_ext = shufflenetv2_base(loadweight, downsample) 115 | if downsample == 4: 116 | self.DimRed = nn.Conv2d(232, reddim, kernel_size=1, padding=0) 117 | else: 118 | self.DimRed = nn.Conv2d(464, reddim, kernel_size=1, padding=0) 119 | elif model == 'mobilenetv2': 120 | self.Feat_ext = mobilenetv2_base(loadweight, downsample) 121 | if downsample == 4: 122 | self.DimRed = nn.Conv2d(96, reddim, kernel_size=1, padding=0) 123 | else: 124 | self.DimRed = nn.Conv2d(320, reddim, kernel_size=1, padding=0) 125 | elif model == 'vgg16': 126 | self.Feat_ext = vgg_base(loadweight, downsample) 127 | self.DimRed = nn.Conv2d(512, reddim, kernel_size=1, padding=0) 128 | elif model == 'resnet50': 129 | self.Feat_ext = resnet50_base(loadweight, downsample) 130 | self.DimRed = nn.Conv2d(1024, reddim, kernel_size=1, padding=0) 131 | 132 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0 / 2**downsample) 133 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0 / 2**downsample) 134 | self.FC_layers = fc_layers(reddim * 2, alignsize) 135 | 136 | def forward(self, im_data, boxes): 137 | f3, base_feat, f5 = self.Feat_ext(im_data) 138 | red_feat = self.DimRed(base_feat) 139 | RoI_feat = self.RoIAlign(red_feat, boxes) 140 | RoD_feat = self.RoDAlign(red_feat, boxes) 141 | final_feat = torch.cat((RoI_feat, RoD_feat), 1) 142 | prediction = self.FC_layers(final_feat) 143 | return prediction 144 | 145 | def _init_weights(self): 146 | print('Initializing weights...') 147 | self.DimRed.apply(weights_init) 148 | self.FC_layers.apply(weights_init) 149 | 150 | 151 | class smt_model_multi_scale_individual(nn.Module): 152 | 153 | def __init__(self, alignsize=8, reddim=32, loadweight=True, model=None, downsample=4): 154 | super(smt_model_multi_scale_individual, self).__init__() 155 | 156 | if model == 'shufflenetv2': 157 | self.Feat_ext1 = shufflenetv2_base(loadweight, downsample) 158 | self.Feat_ext2 = shufflenetv2_base(loadweight, downsample) 159 | self.Feat_ext3 = shufflenetv2_base(loadweight, downsample) 160 | self.DimRed = nn.Conv2d(232, reddim, kernel_size=1, padding=0) 161 | elif model == 'mobilenetv2': 162 | self.Feat_ext1 = mobilenetv2_base(loadweight, downsample) 163 | self.Feat_ext2 = mobilenetv2_base(loadweight, downsample) 164 | self.Feat_ext3 = mobilenetv2_base(loadweight, downsample) 165 | self.DimRed = nn.Conv2d(96, reddim, kernel_size=1, padding=0) 166 | elif model == 'vgg16': 167 | self.Feat_ext1 = vgg_base(loadweight, downsample) 168 | self.Feat_ext2 = vgg_base(loadweight, downsample) 169 | 
self.Feat_ext3 = vgg_base(loadweight, downsample)
170 | self.DimRed = nn.Conv2d(512, reddim, kernel_size=1, padding=0)
171 | 
172 | self.downsample2 = nn.UpsamplingBilinear2d(scale_factor=1.0 / 2.0)
173 | self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2.0)
174 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0 / 2**downsample)
175 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0 / 2**downsample)
176 | self.FC_layers = fc_layers(reddim * 2, alignsize)
177 | 
178 | def forward(self, im_data, boxes):
179 | _, base_feat, _ = self.Feat_ext1(im_data)  # backbones return (f3, f4, f5); keep the middle map as in smt_model_single_scale
180 | 
181 | up_im = self.upsample2(im_data)
182 | _, up_feat, _ = self.Feat_ext2(up_im)
183 | up_feat = self.downsample2(up_feat)
184 | 
185 | down_im = self.downsample2(im_data)
186 | _, down_feat, _ = self.Feat_ext3(down_im)
187 | down_feat = self.upsample2(down_feat)
188 | 
189 | # cat_feat = torch.cat((base_feat,up_feat,down_feat),1)
190 | cat_feat = 0.5 * base_feat + 0.35 * up_feat + 0.15 * down_feat
191 | red_feat = self.DimRed(cat_feat)
192 | RoI_feat = self.RoIAlign(red_feat, boxes)
193 | RoD_feat = self.RoDAlign(red_feat, boxes)
194 | final_feat = torch.cat((RoI_feat, RoD_feat), 1)
195 | prediction = self.FC_layers(final_feat)
196 | return prediction
197 | 
198 | def _init_weights(self):
199 | print('Initializing weights...')
200 | self.DimRed.apply(weights_init)
201 | self.FC_layers.apply(weights_init)
202 | 
203 | 
204 | class smt_model_multi_scale_shared(nn.Module):
205 | 
206 | def __init__(self, alignsize=8, reddim=32, loadweight=True, model=None, downsample=4):
207 | super(smt_model_multi_scale_shared, self).__init__()
208 | 
209 | if model == 'shufflenetv2':
210 | self.Feat_ext = shufflenetv2_base(loadweight, downsample)
211 | self.DimRed = nn.Conv2d(812, reddim, kernel_size=1, padding=0)
212 | elif model == 'mobilenetv2':
213 | self.Feat_ext = mobilenetv2_base(loadweight, downsample)
214 | self.DimRed = nn.Conv2d(448, reddim, kernel_size=1, padding=0)
215 | elif model == 'vgg16':
216 | self.Feat_ext = vgg_base(loadweight, downsample)
217 | self.DimRed = nn.Conv2d(1536, reddim, kernel_size=1, padding=0)
218 | elif model == 'resnet50':
219 | self.Feat_ext = resnet50_base(loadweight, downsample)
220 | self.DimRed = nn.Conv2d(3584, reddim, kernel_size=1, padding=0)
221 | 
222 | self.downsample2 = nn.UpsamplingBilinear2d(scale_factor=1.0 / 2.0)
223 | self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2.0)
224 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0 / 2**downsample)
225 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0 / 2**downsample)
226 | self.FC_layers = fc_layers(reddim * 2, alignsize)
227 | 
228 | def forward(self, im_data, boxes):
229 | f3, f4, f5 = self.Feat_ext(im_data)
230 | cat_feat = torch.cat((self.downsample2(f3), f4, 0.5 * self.upsample2(f5)), 1)
231 | 
232 | red_feat = self.DimRed(cat_feat)
233 | RoI_feat = self.RoIAlign(red_feat, boxes)
234 | RoD_feat = self.RoDAlign(red_feat, boxes)
235 | final_feat = torch.cat((RoI_feat, RoD_feat), 1)
236 | prediction = self.FC_layers(final_feat)
237 | return prediction
238 | 
239 | def _init_weights(self):
240 | print('Initializing weights...')
241 | self.DimRed.apply(weights_init)
242 | self.FC_layers.apply(weights_init)
243 | 
244 | 
245 | def xavier(param):
246 | init.xavier_uniform_(param)
247 | 
248 | 
249 | def weights_init(m):
250 | if isinstance(m, nn.Conv2d):
251 | xavier(m.weight.data)
252 | m.bias.data.zero_()
253 | 
254 | 
255 | def build_smt_model(scale='single', alignsize=8, reddim=32, loadweight=True, model=None, downsample=4):
256 | if scale == 'single':
257 | 
return smt_model_single_scale(alignsize, reddim, loadweight, model, downsample)
258 | elif scale == 'multi':
259 | return smt_model_multi_scale_shared(alignsize, reddim, loadweight, model, downsample)
260 | 
-------------------------------------------------------------------------------- /test_data/Fonts/verdanab.ttf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/Fonts/verdanab.ttf
-------------------------------------------------------------------------------- /test_data/SMT/00001.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/SMT/00001.jpg
-------------------------------------------------------------------------------- /test_data/SMT/00002.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/SMT/00002.png
-------------------------------------------------------------------------------- /test_data/SMT/00003.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/SMT/00003.jpg
-------------------------------------------------------------------------------- /test_data/SMT/00004.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/SMT/00004.jpg
-------------------------------------------------------------------------------- /test_data/SMT/00005.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/SMT/00005.jpg
-------------------------------------------------------------------------------- /test_data/SMT/00006.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intchous/SmartText/2a4132a09adb71a30c1f5aa5c0596e24d3ea608c/test_data/SMT/00006.jpg
-------------------------------------------------------------------------------- /test_opt.yml: --------------------------------------------------------------------------------
1 | name: SmartText
2 | model_type: RoE
3 | gpu_ids: [0]
4 | cuda: True
5 | batch_size: 1
6 | num_workers: 0
7 | dataset_name: SMT
8 | visimp_model: ./pretrained/gdi-basnet.pth
9 | smt_model: ./pretrained/SMT.pth
10 | input_dir: ./test_data/SMT/
11 | usr_slogan: "ICME 2020\n6-10 July, London"
12 | ratio_list: [1, 0.8] # size ratios of the text lines
13 | res_dir: ./test_result/SMT/
14 | font_fp: ./test_data/Fonts/verdanab.ttf
15 | candi_res: 3 # number of output candidates
16 | 
17 | exp_prop: 6 # expanding coefficient of the text region
18 | is_devi: False
19 | text_spacing: 20
20 | grid_num: 120
21 | sali_coef: 2.6 # the larger sali_coef is, the smaller the image area treated as important
22 | max_text_area_coef: 17
23 | min_text_area_coef: 7
24 | min_font_size: 10
25 | max_font_size: 500
26 | font_inc_unit: 5
27 | 
28 | contrast_threshold: 5
29 | 
-------------------------------------------------------------------------------- /util_cal.py:
--------------------------------------------------------------------------------
1 | # encoding:utf-8
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class candik(object):  # a candidate window: summed importance val at top-left cell (rx, cy)
8 | 
9 | def __init__(self, val, rx, cy):
10 | self.val = val
11 | self.rx = rx
12 | self.cy = cy
13 | 
14 | 
15 | def takeVal(candik):
16 | return candik.val
17 | 
18 | 
19 | def bb_intersection(boxA, boxB):
20 | xA = max(boxA[0], boxB[0])
21 | yA = max(boxA[1], boxB[1])
22 | xB = min(boxA[2], boxB[2])
23 | yB = min(boxA[3], boxB[3])
24 | return max(0, xB - xA) * max(0, yB - yA)
25 | 
26 | 
27 | def get_top_k_submatrix(matrix, kernel_size, k, desc=False):
28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29 | x = torch.FloatTensor(matrix).to(device)
30 | h, w = x.shape
31 | x = F.avg_pool2d(x.view(1, 1, h, w), kernel_size=kernel_size, stride=(1, 1))
32 | nh, nw = h - kernel_size[0] + 1, w - kernel_size[1] + 1
33 | x = x.view(nh, nw) * kernel_size[0] * kernel_size[1]
34 | x = x.cpu().numpy()
35 | idx = np.dstack(np.unravel_index(np.argsort(x.ravel()), (nh, nw)))[0]
36 | idx = idx[::-1] if desc else idx
37 | 
38 | top_k = []
39 | for px, py in idx:
40 | if len(top_k) >= k:
41 | break
42 | conflict = False
43 | for q in top_k:  # top_k holds candik objects, not tuples
44 | if bb_intersection((px, py, px + kernel_size[0], py + kernel_size[1]),
45 | (q.rx, q.cy, q.rx + kernel_size[0], q.cy + kernel_size[1])):
46 | conflict = True
47 | break
48 | if not conflict:
49 | # top_k.append((px, py, x[px][py]))
50 | top_k.append(candik(x[px][py], px, py))
51 | return top_k
52 | 
53 | 
54 | def cal_imp_conv(n, m, matrix, matrix_cal, matrix1D, INF):
55 | # repeatedly smooth matrix_cal with a 3x3 mean filter; out-of-range
56 | # neighbours contribute INF so border cells saturate quickly; a cell is
57 | # marked done in matrix1D once its value has risen by more than 0.5
58 | comp_flg = False
59 | cal_num = 0
60 | fcnt = 0
61 | while not comp_flg:
62 | fcnt += 1
63 | for i in range(n):
64 | for j in range(m):
65 | tmp_ave = matrix_cal[i][j]
66 | for di in (-1, 0, 1):
67 | for dj in (-1, 0, 1):
68 | if di == 0 and dj == 0:
69 | continue
70 | ni, nj = i + di, j + dj
71 | if 0 <= ni < n and 0 <= nj < m:
72 | tmp_ave += matrix_cal[ni][nj]
73 | else:
74 | tmp_ave += INF
75 | matrix_cal[i][j] = tmp_ave / 9.0
76 | if ((matrix1D[i * m + j] != -1) and (matrix_cal[i][j] - matrix[i][j]) > 0.5):
77 | matrix1D[i * m + j] = -1
78 | cal_num += 1
79 | 
80 | if (cal_num == n * m):
81 | comp_flg = True
82 | break
83 | 
84 | # print("fcnt =", fcnt)
85 | return matrix_cal
86 | 
--------------------------------------------------------------------------------
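util_cal.py ends the tour with the candidate-scoring primitive. get_top_k_submatrix turns F.avg_pool2d into a sliding-window sum over a 2-D importance grid (the pooled average is multiplied back by the window area), sorts every window position, and greedily keeps the k best positions whose windows do not overlap. With the default desc=False the lowest-importance windows come back first, i.e. the regions covering the least salient content, which is what text placement wants. A minimal sketch on a toy grid (values illustrative; run from the repository root):

import numpy as np
from util_cal import get_top_k_submatrix

imp = np.random.rand(12, 16).astype(np.float32)  # toy visual-importance grid
# the three non-overlapping 4x6 windows with the smallest summed importance
cands = get_top_k_submatrix(imp, kernel_size=(4, 6), k=3, desc=False)
for c in cands:
    print(c.rx, c.cy, c.val)  # window's top-left cell and its summed importance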