├── requirements.txt ├── networks ├── unet_model.py ├── unet_parts.py ├── extractor.py ├── cross_attn.py └── seg.py ├── README.md ├── predict.py └── model.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.0+cu117 2 | opencv-python==4.8.1.78 3 | numpy==1.26.0 4 | mmcv-full==1.7.1 5 | Pillow==9.4.0 -------------------------------------------------------------------------------- /networks/unet_model.py: -------------------------------------------------------------------------------- 1 | from networks.unet_parts import * 2 | 3 | 4 | class UNet(nn.Module): 5 | def __init__(self, n_channels, n_classes, bilinear=True): 6 | super(UNet, self).__init__() 7 | self.n_channels = n_channels 8 | self.n_classes = n_classes 9 | self.bilinear = bilinear 10 | 11 | self.inc = DoubleConv(n_channels, 64) 12 | self.down1 = Down(64, 128) 13 | self.down2 = Down(128, 256) 14 | self.down3 = Down(256, 512) 15 | factor = 2 if bilinear else 1 16 | self.down4 = Down(512, 1024 // factor) 17 | self.up1 = Up(1024, 512 // factor, bilinear) 18 | self.up2 = Up(512, 256 // factor, bilinear) 19 | self.up3 = Up(256, 128 // factor, bilinear) 20 | self.up4 = Up(128, 64, bilinear) 21 | self.outc = OutConv(64, n_classes) 22 | 23 | for param in self.parameters(): 24 | param.requires_grad = False 25 | 26 | def forward(self, x): 27 | x1 = self.inc(x) 28 | x2 = self.down1(x1) 29 | x3 = self.down2(x2) 30 | x4 = self.down3(x3) 31 | x5 = self.down4(x4) 32 | x = self.up1(x5, x4) 33 | x = self.up2(x, x3) 34 | x = self.up3(x, x2) 35 | x = self.up4(x, x1) 36 | logits = self.outc(x) 37 | return x, logits 38 | 39 | 40 | if __name__ == '__main__': 41 | x1 = torch.rand((2, 3, 224, 224)).cuda() 42 | net = UNet(n_channels=3, n_classes=1).cuda() 43 | print(net) 44 | map, pred_img = net(x1) 45 | n_p = sum(x.numel() for x in net.parameters()) # number parameters 46 | n_g = sum(x.numel() for x in net.parameters() if x.requires_grad) # number gradients 47 | print('Model Summary: %g parameters, %g gradients\n' % (n_p, n_g)) 48 | 49 | print("map: ", map.shape) 50 | print("pred_img: ", pred_img.shape) -------------------------------------------------------------------------------- /networks/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, 
bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Document_Image_Dewarping 2 | 3 | The code for "[Foreground and Text-lines Aware Document Image Rectification](https://openaccess.thecvf.com/content/ICCV2023/papers/Li_Foreground_and_Text-lines_Aware_Document_Image_Rectification_ICCV_2023_paper.pdf)", ICCV, 2023. 4 | 5 | ## Training Dataset 6 | We use the Doc3D dataset for training. You can download the dataset on 7 | [DewarpNet](https://github.com/cvlab-stonybrook/DewarpNet) or [doc3D-dataset](https://github.com/fh2019ustc/doc3D-dataset). 8 | 9 | ## Evaluation Dataset 10 | We evaluate on two datasets [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) and [DIR300](https://github.com/fh2019ustc/DocGeoNet). 11 | 12 | ## Inference 13 | Please download the pre-trained model from 14 | [Google Drive](https://drive.google.com/drive/folders/1UWL7wWSCcyhHuWLSKQRI9g2_cp0M0aD-?usp=sharing) 15 | or [Baidu Cloud](https://pan.baidu.com/s/1JhEznQEjaVplPQww0CNbHA?pwd=p5yp). Then execute: 16 | 17 | `python predict.py --model_path /MODEL/PATH --img_path /BENCHMARK/DIR --save_path /SAVE/PATH` 18 | 19 | ## Evaluation 20 | 21 | We follow the evaluation environment and code in [DocUNet](https://www3.cs.stonybrook.edu/~cvl/docunet.html) 22 | and [DocGeoNet](https://github.com/fh2019ustc/DocGeoNet). 23 | 24 | For CER and ED metrics evaluation: 25 | 26 | ```text 27 | Tesseract==5.0.1.20220118 (Windows) 28 | pytesseract==0.3.8 29 | ``` 30 | 31 | The dewarped images can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1PHyeZZF88-KzkeuV8YiKHJ_z9lC-3C9o) 32 | or [Baidu Cloud](https://pan.baidu.com/s/1Lq9tRbOM4nV-pQ9sbVfbww?pwd=y41i). 
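As a rough illustration of how CER and ED are typically computed with the Tesseract setup above (OCR the dewarped image, then compare it against the ground-truth transcription), a minimal sketch follows. It is only an assumption of that pipeline: the file names (`gt.txt`, `infer/sample.png`) and the helper functions are hypothetical and are not the official DocUNet/DocGeoNet evaluation scripts, which we follow for the reported numbers.

```python
# Hedged sketch: CER / ED via Tesseract OCR and Levenshtein distance.
# Paths and helper names are illustrative only, not part of this repository.
import pytesseract
from PIL import Image


def ocr_text(img_path):
    # Run Tesseract OCR on a (dewarped) document image.
    return pytesseract.image_to_string(Image.open(img_path))


def edit_distance(ref, hyp):
    # Plain Levenshtein distance (ED) between two strings, single-row DP.
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                          # deletion
                        dp[j - 1] + 1,                      # insertion
                        prev + (ref[i - 1] != hyp[j - 1]))  # substitution
            prev = cur
    return dp[n]


ref = open('gt.txt', encoding='utf-8').read()   # ground-truth transcription
hyp = ocr_text('infer/sample.png')              # OCR of the dewarped image
ed = edit_distance(ref, hyp)
cer = ed / max(len(ref), 1)                     # character error rate
print('ED: %d, CER: %.4f' % (ed, cer))
```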
33 | ## Acknowledgement 34 | Our methods and codes are inspired by many existing works, to which we would like to express special thanks to: 35 | 36 | [DocUNet: Document Image Unwarping via A Stacked U-Net](https://www3.cs.stonybrook.edu/~cvl/content/papers/2018/Ma_CVPR18.pdf) 37 | 38 | [DewarpNet: Single-Image Document Unwarping With Stacked 3D and 2D 39 | Regression Networks](https://www3.cs.stonybrook.edu/~cvl/projects/dewarpnet/storage/paper.pdf) 40 | 41 | [DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction](https://arxiv.org/pdf/2110.12942.pdf) 42 | 43 | [Revisiting Document Image Dewarping by Grid Regularization](https://openaccess.thecvf.com/content/CVPR2022/papers/Jiang_Revisiting_Document_Image_Dewarping_by_Grid_Regularization_CVPR_2022_paper.pdf) 44 | 45 | [Geometric Representation Learning for Document Image Rectification](https://arxiv.org/pdf/2210.08161.pdf) 46 | 47 | 48 | ## Citation 49 | If our methods and code are helpful to you, please refer to the following BibTeX format for citation: 50 | ``` 51 | @inproceedings{li2023foreground, 52 | title={Foreground and Text-lines Aware Document Image Rectification}, 53 | author={Li, Heng and Wu, Xiangping and Chen, Qingcai and Xiang, Qianjin}, 54 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 55 | pages={19574--19583}, 56 | year={2023} 57 | } 58 | ``` 59 | 60 | 61 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code references: 3 | # >>> https://github.com/cvlab-stonybrook/DewarpNet 4 | # >>> https://github.com/fh2019ustc/DocGeoNet 5 | """ 6 | import argparse 7 | import time 8 | 9 | import cv2 10 | import glob 11 | import numpy as np 12 | import os 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from PIL import Image 18 | from model import DewarpTextlineMaskGuide 19 | 20 | 21 | def str2bool(v): 22 | if isinstance(v, bool): 23 | return v 24 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 25 | return True 26 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 27 | return False 28 | else: 29 | raise argparse.ArgumentTypeError('Boolean value expected.') 30 | 31 | 32 | def get_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--input_size', type=int, default=224, help='image size') 35 | parser.add_argument('--model_path', type=str, default='pretrained_models/30.pt', help='model path') 36 | parser.add_argument('--img_path', type=str, default='dataset/Dewarp/DocUNet_dataset/crop', 37 | help='image path or path to folder containing images') 38 | 39 | parser.add_argument('--save_path', type=str, default='infer/', help='save path') 40 | 41 | return parser.parse_args() 42 | 43 | 44 | def predict(img_path, save_path, filename, recti_model): 45 | assert os.path.exists(img_path), 'Incorrect Image Path' 46 | assert os.path.exists(save_path), 'Incorrect Save Path' 47 | 48 | img_size = parser.input_size 49 | 50 | img = np.array(Image.open(img_path))[:, :, :3] / 255. 51 | img_h, img_w, _ = img.shape 52 | input_img = cv2.resize(img, (img_size, img_size)) 53 | 54 | with torch.no_grad(): 55 | recti_model.eval() 56 | input_ = torch.from_numpy(input_img).permute(2, 0, 1).cuda() 57 | input_ = input_.unsqueeze(0) 58 | start = time.time() 59 | 60 | bm = recti_model(input_.float()) 61 | bm = (2 * (bm / 223.) 
- 1) * 0.99 62 | ps_time = time.time() - start 63 | 64 | bm = bm.detach().cpu() 65 | bm0 = cv2.resize(bm[0, 0].numpy(), (img_w, img_h)) # x flow 66 | bm1 = cv2.resize(bm[0, 1].numpy(), (img_w, img_h)) # y flow 67 | bm0 = cv2.blur(bm0, (3, 3)) 68 | bm1 = cv2.blur(bm1, (3, 3)) 69 | lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0).float() # h * w * 2 70 | 71 | out = F.grid_sample(torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True) 72 | img_geo = ((out[0] * 255.).permute(1, 2, 0).numpy()).astype(np.uint8) 73 | 74 | cv2.imwrite(filename, img_geo[:, :, ::-1]) # save 75 | 76 | return ps_time 77 | 78 | 79 | if __name__ == '__main__': 80 | parser = get_args() 81 | 82 | recti_model = DewarpTextlineMaskGuide(image_size=parser.input_size) 83 | recti_model = torch.nn.DataParallel(recti_model) 84 | state_dict = torch.load(parser.model_path, map_location='cpu') 85 | 86 | recti_model.load_state_dict(state_dict) 87 | recti_model.cuda() 88 | print(f'model loaded') 89 | 90 | img_path = parser.img_path 91 | save_path = parser.save_path 92 | total_time = 0.0 93 | 94 | start = time.time() 95 | img_num = 0.0 96 | for file in glob.glob(img_path + "/*"): # img_names: # 97 | print("file: ", file) 98 | filename = (save_path + "/" + file[file.rindex("/") + 1:file.rindex(".")] + ".png") 99 | 100 | total_time += predict(file, save_path, filename, recti_model) 101 | print("Written ", file[file.rindex("/") + 1:file.rindex(".")]) 102 | img_num += 1 103 | print('FPS: %.1f' % (1.0 / (total_time / img_num))) 104 | 105 | -------------------------------------------------------------------------------- /networks/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | if self.downsample is not None: 53 | x = self.downsample(x) 54 | 55 | return self.relu(x + y) 56 | 57 | 58 | class BasicEncoder(nn.Module): 59 | def __init__(self, in_channels, 
output_dim=128, norm_fn='batch', return_maps=False): 60 | super(BasicEncoder, self).__init__() 61 | self.norm_fn = norm_fn 62 | 63 | if self.norm_fn == 'group': 64 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 65 | 66 | elif self.norm_fn == 'batch': 67 | self.norm1 = nn.BatchNorm2d(64) 68 | 69 | elif self.norm_fn == 'instance': 70 | self.norm1 = nn.InstanceNorm2d(64) 71 | 72 | elif self.norm_fn == 'none': 73 | self.norm1 = nn.Sequential() 74 | 75 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3) 76 | self.relu1 = nn.ReLU(inplace=True) 77 | 78 | self.in_planes = 64 79 | self.layer1 = self._make_layer(64, stride=1) 80 | self.layer2 = self._make_layer(128, stride=2) 81 | self.layer3 = self._make_layer(256, stride=2) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 87 | if m.weight is not None: 88 | nn.init.constant_(m.weight, 1) 89 | if m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, dim, stride=1): 93 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 94 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 95 | layers = (layer1, layer2) 96 | 97 | self.in_planes = dim 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | x = self.conv1(x) 102 | x = self.norm1(x) 103 | x = self.relu1(x) 104 | 105 | x = self.layer1(x) 106 | x = self.layer2(x) 107 | x = self.layer3(x) 108 | 109 | return x 110 | 111 | 112 | if __name__ == '__main__': 113 | x = torch.randn((2, 3, 224, 224)).cuda() 114 | net = BasicEncoder(in_channels=3).cuda() 115 | 116 | print(net) 117 | pred = net(x) 118 | n_p = sum(x.numel() for x in net.parameters()) # number parameters 119 | n_g = sum(x.numel() for x in net.parameters() if x.requires_grad) # number gradients 120 | print('Model Summary: %g parameters, %g gradients\n' % (n_p, n_g)) 121 | 122 | print("pred: ", pred.shape) 123 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from networks.unet_model import UNet 6 | from networks.cross_attn import CrossEncoder, Decoder 7 | from networks.extractor import BasicEncoder 8 | from networks.seg import U2NETP 9 | 10 | 11 | class CAM_Module(nn.Module): 12 | # Reference: https://github.com/yearing1017/DANet_PyTorch/blob/master/DAN_ResNet/attention.py 13 | """ Channel attention module""" 14 | 15 | def __init__(self): 16 | super(CAM_Module, self).__init__() 17 | self.gamma = nn.Parameter(torch.zeros(1)) 18 | self.softmax = nn.Softmax(dim=-1) 19 | 20 | def forward(self, x): 21 | """ 22 | inputs : 23 | x : input mask maps( B X C X H X W) 24 | returns : 25 | out : attention value 26 | attention: B X C X C 27 | """ 28 | m_batchsize, C, height, width = x.size() 29 | proj_query = x.view(m_batchsize, C, -1) 30 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 31 | energy = torch.bmm(proj_query, proj_key) 32 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy 33 | attention = self.softmax(energy_new) 34 | 35 | proj_value = x.view(m_batchsize, C, -1) 36 | 37 | out = torch.bmm(attention, proj_value) 38 | 39 | out = out.view(m_batchsize, C, height, width) 40 | 41 | out = self.gamma * out + 
x 42 | return out 43 | 44 | 45 | class FlowHead(nn.Module): 46 | def __init__(self, input_dim=128, hidden_dim=256): 47 | super(FlowHead, self).__init__() 48 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 49 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 50 | self.relu = nn.ReLU(inplace=True) 51 | 52 | def forward(self, x): 53 | return self.conv2(self.relu(self.conv1(x))) 54 | 55 | 56 | class UpdateBlock(nn.Module): 57 | def __init__(self, hidden_dim=128, scale=8): 58 | super(UpdateBlock, self).__init__() 59 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 60 | self.mask = nn.Sequential( 61 | nn.Conv2d(hidden_dim, 64, 3, padding=1), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(64, scale * scale * 9, 1, padding=0)) 64 | 65 | def forward(self, imgf, coords1): 66 | mask = 0.25 * self.mask(imgf) # scale mask to balence gradients 67 | dflow = self.flow_head(imgf) 68 | coords1 = coords1 + dflow 69 | 70 | return mask, coords1 71 | 72 | 73 | def coords_grid(batch, ht, wd, gap=1): 74 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 75 | coords = torch.stack(coords[::-1], dim=0).float() 76 | coords = coords[:, ::gap, ::gap] 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def reload_model(model, path=""): 81 | if not bool(path): 82 | return model 83 | else: 84 | model_dict = model.state_dict() 85 | pretrained_dict = torch.load(path, map_location='cuda:0') 86 | # print(len(pretrained_dict.keys())) 87 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict} 88 | # print(len(pretrained_dict.keys())) 89 | model_dict.update(pretrained_dict) 90 | model.load_state_dict(model_dict) 91 | 92 | return model 93 | 94 | 95 | class Seg(nn.Module): 96 | def __init__(self): 97 | super(Seg, self).__init__() 98 | self.msk = U2NETP(3, 1) 99 | 100 | def forward(self, x): 101 | d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = self.msk(x) 102 | return d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d 103 | 104 | 105 | class DewarpTextlineMaskCrossAtten(nn.Module): 106 | def __init__(self, image_size=224, hdim=256): 107 | super(DewarpTextlineMaskCrossAtten, self).__init__() 108 | self.image_size = image_size 109 | self.hdim = hdim 110 | self.n_head = 8 111 | self.d_v = self.hdim // self.n_head 112 | self.d_k = self.hdim // self.n_head 113 | self.basic_net = BasicEncoder(in_channels=3, output_dim=self.hdim) 114 | 115 | self.encoder = CrossEncoder(n_layers=12, n_head=self.n_head, d_model=self.hdim, d_k=self.d_k, d_v=self.d_v, 116 | d_inner=2048, n_position=self.image_size // 8) 117 | self.decoder = Decoder(n_layers=6, n_head=self.n_head, d_model=self.hdim * 2, d_k=self.d_k * 2, 118 | d_v=self.d_v * 2, 119 | d_inner=2048, n_position=self.image_size // 8) 120 | 121 | self.cam_1 = CAM_Module() 122 | self.cam_2 = CAM_Module() 123 | self.cam_3 = CAM_Module() 124 | 125 | self.conv3x3_1 = nn.Sequential( 126 | nn.Conv2d(in_channels=64 * 6, out_channels=self.hdim, kernel_size=3, stride=1, padding=1), 127 | nn.BatchNorm2d(self.hdim) 128 | ) 129 | 130 | self.conv3x3_2 = nn.Sequential( 131 | nn.Conv2d(in_channels=64, out_channels=self.hdim, kernel_size=3, stride=1, padding=1), 132 | nn.BatchNorm2d(self.hdim) 133 | ) 134 | 135 | self.update_block = UpdateBlock(self.hdim * 2) 136 | 137 | def _upsample(self, x, size): 138 | _, _, H, W = size 139 | return F.upsample(x, size=(H, W), mode='bilinear') # , align_corners=False) 140 | 141 | def initialize_flow(self, img): 142 | N, C, H, W = img.shape 143 | coodslar = coords_grid(N, H, W).to(img.device) 144 | coords0 = coords_grid(N, 
H // 8, W // 8).to(img.device) 145 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 146 | 147 | return coodslar, coords0, coords1 148 | 149 | def upsample_flow(self, flow, mask): 150 | N, _, H, W = flow.shape 151 | mask = mask.view(N, 1, 9, 8, 8, H, W) 152 | mask = torch.softmax(mask, dim=2) 153 | 154 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 155 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 156 | 157 | up_flow = torch.sum(mask * up_flow, dim=2) 158 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 159 | 160 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 161 | 162 | def forward(self, image, fore_mask, textline_mask): 163 | fmap = self.basic_net(image) 164 | fmap = self.cam_1(fmap) 165 | 166 | fore_mask = self.conv3x3_1(fore_mask) 167 | fore_mask = self.cam_2(fore_mask) 168 | 169 | textline_mask = self.conv3x3_2(textline_mask) 170 | textline_mask = F.interpolate(textline_mask, size=fmap.shape[2:], mode='bilinear', align_corners=False) 171 | textline_mask = self.cam_3(textline_mask) 172 | 173 | n, c, h, w = fore_mask.size() 174 | output_mask = self.encoder(fmap, fore_mask) 175 | output_mask = output_mask.transpose(1, 2).contiguous().view(n, c, h, w) 176 | 177 | n, c, h, w = textline_mask.size() 178 | output_textline_mask = self.encoder(fmap, textline_mask) 179 | output_textline_mask = output_textline_mask.transpose(1, 2).contiguous().view(n, c, h, w) 180 | 181 | output = torch.cat((output_mask, output_textline_mask), dim=1) 182 | 183 | output = self.decoder(output) 184 | outmap = output.transpose(1, 2).contiguous().view(n, c * 2, h, w) 185 | 186 | coodslar, coords0, coords1 = self.initialize_flow(image) 187 | coords1 = coords1.detach() 188 | 189 | mask, coords1 = self.update_block(outmap, coords1) 190 | flow_up = self.upsample_flow(coords1 - coords0, mask) 191 | bm_up = coodslar + flow_up 192 | 193 | return bm_up 194 | 195 | 196 | class DewarpTextlineMaskGuide(nn.Module): 197 | def __init__(self, image_size=256): 198 | super(DewarpTextlineMaskGuide, self).__init__() 199 | self.hdim = 256 200 | self.dewarp_net = DewarpTextlineMaskCrossAtten(image_size=image_size, hdim=self.hdim) 201 | 202 | self.initialize_weights_() 203 | self.seg = Seg() 204 | self.unet = UNet(n_channels=3, n_classes=1) 205 | 206 | def initialize_weights_(self): 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | torch.nn.init.xavier_normal_(m.weight, gain=0.2) 210 | if isinstance(m, nn.ConvTranspose2d): 211 | assert m.kernel_size[0] == m.kernel_size[1] 212 | torch.nn.init.xavier_normal_(m.weight, gain=0.2) 213 | if isinstance(m, nn.Linear): 214 | # we use xavier_uniform following official JAX ViT: 215 | torch.nn.init.xavier_uniform_(m.weight) 216 | if isinstance(m, nn.Linear) and m.bias is not None: 217 | nn.init.constant_(m.bias, 0) 218 | elif isinstance(m, nn.LayerNorm): 219 | nn.init.constant_(m.bias, 0) 220 | nn.init.constant_(m.weight, 1.0) 221 | 222 | def forward(self, image): 223 | d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = self.seg(image) 224 | hx6 = F.interpolate(hx6, scale_factor=4, mode='bilinear', align_corners=False) 225 | hx5d = F.interpolate(hx5d, scale_factor=2, mode='bilinear', align_corners=False) 226 | hx4d = F.interpolate(hx4d, scale_factor=1, mode='bilinear', align_corners=False) 227 | hx3d = F.interpolate(hx3d, scale_factor=0.5, mode='bilinear', align_corners=False) 228 | hx2d = F.interpolate(hx2d, scale_factor=0.25, mode='bilinear', align_corners=False) 229 | hx1d = F.interpolate(hx1d, scale_factor=0.125, mode='bilinear', align_corners=False) 230 | 231 | seg_map_all = 
torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1) 232 | textline_map, textline_mask = self.unet(image) 233 | 234 | bm_up = self.dewarp_net(image, seg_map_all, textline_map) 235 | 236 | return bm_up 237 | 238 | 239 | if __name__ == '__main__': 240 | x = torch.randn((2, 3, 224, 224)).cuda() 241 | 242 | net = DewarpTextlineMaskGuide(image_size=x.shape[-1]).cuda() 243 | 244 | checkpoint_path = 'pretrained_models/30.pt' 245 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 246 | print(checkpoint) 247 | 248 | print(net) 249 | pred = net(x) 250 | n_p = sum(x.numel() for x in net.parameters()) # number parameters 251 | n_g = sum(x.numel() for x in net.parameters() if x.requires_grad) # number gradients 252 | print('Model Summary: %g parameters, %g gradients\n' % (n_p, n_g)) 253 | 254 | print("pred: ", pred.shape) 255 | 256 | -------------------------------------------------------------------------------- /networks/cross_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from mmcv.runner import BaseModule 9 | from mmcv.cnn import ConvModule 10 | import torch.nn.functional as F 11 | 12 | 13 | class LocalityAwareFeedforward(nn.Module): 14 | """Locality-aware feedforward layer in SATRN, see `SATRN. 15 | `_ 16 | """ 17 | 18 | def __init__(self, 19 | d_in, 20 | d_hid, 21 | dropout=0.1, 22 | ): 23 | super().__init__() 24 | self.conv1 = ConvModule( 25 | d_in, 26 | d_hid, 27 | kernel_size=1, 28 | padding=0, 29 | bias=False, 30 | norm_cfg=dict(type='BN'), 31 | act_cfg=dict(type='ReLU')) 32 | 33 | self.depthwise_conv = ConvModule( 34 | d_hid, 35 | d_hid, 36 | kernel_size=3, 37 | padding=1, 38 | bias=False, 39 | groups=d_hid, 40 | norm_cfg=dict(type='BN'), 41 | act_cfg=dict(type='ReLU')) 42 | 43 | self.conv2 = ConvModule( 44 | d_hid, 45 | d_in, 46 | kernel_size=1, 47 | padding=0, 48 | bias=False, 49 | norm_cfg=dict(type='BN'), 50 | act_cfg=dict(type='ReLU')) 51 | 52 | def forward(self, x): 53 | x = self.conv1(x) 54 | x = self.depthwise_conv(x) 55 | x = self.conv2(x) 56 | 57 | return x 58 | 59 | 60 | class ScaledDotProductAttention(nn.Module): 61 | """Scaled Dot-Product Attention Module. This code is adopted from 62 | https://github.com/jadore801120/attention-is-all-you-need-pytorch. 63 | Args: 64 | temperature (float): The scale factor for softmax input. 65 | attn_dropout (float): Dropout layer on attn_output_weights. 66 | """ 67 | 68 | def __init__(self, temperature, attn_dropout=0.1): 69 | super().__init__() 70 | self.temperature = temperature 71 | self.dropout = nn.Dropout(attn_dropout) 72 | 73 | def forward(self, q, k, v, mask=None): 74 | 75 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 76 | 77 | if mask is not None: 78 | attn = attn.masked_fill(mask == 0, float('-inf')) 79 | 80 | attn = self.dropout(F.softmax(attn, dim=-1)) 81 | output = torch.matmul(attn, v) 82 | 83 | return output, attn 84 | 85 | 86 | class Adaptive2DPositionalEncoding(BaseModule): 87 | """Implement Adaptive 2D positional encoder for SATRN, see 88 | `SATRN `_ 89 | Modified from https://github.com/Media-Smart/vedastr 90 | Licensed under the Apache License, Version 2.0 (the "License"); 91 | Args: 92 | d_hid (int): Dimensions of hidden layer. 93 | n_height (int): Max height of the 2D feature output. 94 | n_width (int): Max width of the 2D feature output. 95 | dropout (int): Size of hidden layers of the model. 
96 | """ 97 | 98 | def __init__(self, 99 | d_hid=512, 100 | n_height=100, 101 | n_width=100, 102 | dropout=0.1, 103 | init_cfg=[dict(type='Xavier', layer='Conv2d')]): 104 | super().__init__(init_cfg=init_cfg) 105 | 106 | h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid) 107 | h_position_encoder = h_position_encoder.transpose(0, 1) 108 | h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1) 109 | 110 | w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid) 111 | w_position_encoder = w_position_encoder.transpose(0, 1) 112 | w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width) 113 | 114 | self.register_buffer('h_position_encoder', h_position_encoder) 115 | self.register_buffer('w_position_encoder', w_position_encoder) 116 | 117 | self.h_scale = self.scale_factor_generate(d_hid) 118 | self.w_scale = self.scale_factor_generate(d_hid) 119 | self.pool = nn.AdaptiveAvgPool2d(1) 120 | self.dropout = nn.Dropout(p=dropout) 121 | 122 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 123 | """Sinusoid position encoding table.""" 124 | denominator = torch.Tensor([ 125 | 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) 126 | for hid_j in range(d_hid) 127 | ]) 128 | denominator = denominator.view(1, -1) 129 | pos_tensor = torch.arange(n_position).unsqueeze(-1).float() 130 | sinusoid_table = pos_tensor * denominator 131 | sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) 132 | sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) 133 | 134 | return sinusoid_table 135 | 136 | def scale_factor_generate(self, d_hid): 137 | scale_factor = nn.Sequential( 138 | nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True), 139 | nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid()) 140 | 141 | return scale_factor 142 | 143 | def forward(self, x): 144 | b, c, h, w = x.size() 145 | 146 | avg_pool = self.pool(x) 147 | 148 | h_pos_encoding = \ 149 | self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :] 150 | w_pos_encoding = \ 151 | self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w] 152 | 153 | out = x + h_pos_encoding + w_pos_encoding 154 | 155 | out = self.dropout(out) 156 | 157 | return out 158 | 159 | 160 | class MultiHeadAttention(nn.Module): 161 | """Multi-Head Attention module. 162 | Args: 163 | n_head (int): The number of heads in the 164 | multiheadattention models (default=8). 165 | d_model (int): The number of expected features 166 | in the decoder inputs (default=512). 167 | d_k (int): Total number of features in key. 168 | d_v (int): Total number of features in value. 169 | dropout (float): Dropout layer on attn_output_weights. 170 | qkv_bias (bool): Add bias in projection layer. Default: False. 
171 | """ 172 | 173 | def __init__(self, 174 | n_head=8, 175 | d_model=512, 176 | d_k=64, 177 | d_v=64, 178 | dropout=0.1, 179 | qkv_bias=False): 180 | super().__init__() 181 | self.n_head = n_head 182 | self.d_k = d_k 183 | self.d_v = d_v 184 | 185 | self.dim_k = n_head * d_k 186 | self.dim_v = n_head * d_v 187 | 188 | self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) 189 | self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) 190 | self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias) 191 | 192 | self.attention = ScaledDotProductAttention(d_k**0.5, dropout) 193 | 194 | self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias) 195 | self.proj_drop = nn.Dropout(dropout) 196 | 197 | def forward(self, q, k, v, mask=None): 198 | batch_size, len_q, _ = q.size() 199 | _, len_k, _ = k.size() 200 | 201 | q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k) 202 | k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k) 203 | v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v) 204 | 205 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 206 | 207 | if mask is not None: 208 | if mask.dim() == 3: 209 | mask = mask.unsqueeze(1) 210 | elif mask.dim() == 2: 211 | mask = mask.unsqueeze(1).unsqueeze(1) 212 | 213 | attn_out, _ = self.attention(q, k, v, mask=mask) 214 | 215 | attn_out = attn_out.transpose(1, 2).contiguous().view( 216 | batch_size, len_q, self.dim_v) 217 | 218 | attn_out = self.fc(attn_out) 219 | attn_out = self.proj_drop(attn_out) 220 | 221 | return attn_out 222 | 223 | 224 | class CrossattnLayer(nn.Module): 225 | """""" 226 | 227 | def __init__(self, 228 | d_model=512, 229 | d_inner=512, 230 | n_head=8, 231 | d_k=64, 232 | d_v=64, 233 | dropout=0.1, 234 | qkv_bias=False): 235 | super().__init__() 236 | self.norm1 = nn.LayerNorm(d_model) 237 | self.attn = MultiHeadAttention( 238 | n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout) 239 | self.cross_attn = MultiHeadAttention( 240 | n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout) 241 | self.norm2 = nn.LayerNorm(d_model) 242 | self.norm3 = nn.LayerNorm(d_model) 243 | self.feed_forward = LocalityAwareFeedforward( 244 | d_model, d_inner, dropout=dropout) 245 | 246 | def forward(self, x, cross, h, w, mask=None): 247 | n, hw, c = x.size() 248 | residual = x 249 | x = self.norm1(x) 250 | x = residual + self.attn(x, x, x, mask) 251 | residual = x 252 | x = self.norm2(x) 253 | x = residual + self.cross_attn(cross, x, x, mask) 254 | residual = x 255 | x = self.norm3(x) 256 | x = x.transpose(1, 2).contiguous().view(n, c, h, w) 257 | x = self.feed_forward(x) 258 | x = x.view(n, c, hw).transpose(1, 2) 259 | x = residual + x 260 | return x 261 | 262 | 263 | class CrossEncoder(nn.Module): 264 | """Implement encoder for SATRN, see `SATRN. 265 | `_. 266 | Args: 267 | n_layers (int): Number of attention layers. 268 | n_head (int): Number of parallel attention heads. 269 | d_k (int): Dimension of the key vector. 270 | d_v (int): Dimension of the value vector. 271 | d_model (int): Dimension :math:`D_m` of the input from previous model. 272 | n_position (int): Length of the positional encoding vector. Must be 273 | greater than ``max_seq_len``. 274 | d_inner (int): Hidden dimension of feedforward layers. 275 | dropout (float): Dropout rate. 276 | init_cfg (dict or list[dict], optional): Initialization configs. 
277 | """ 278 | 279 | def __init__(self, 280 | n_layers=12, 281 | n_head=8, 282 | d_k=64, 283 | d_v=64, 284 | d_model=512, 285 | n_position=100, 286 | d_inner=256, 287 | dropout=0.1): 288 | super().__init__() 289 | self.d_model = d_model 290 | self.position_enc = Adaptive2DPositionalEncoding( 291 | d_hid=d_model, 292 | n_height=n_position, 293 | n_width=n_position, 294 | dropout=dropout) 295 | self.position_enc_cross = Adaptive2DPositionalEncoding( 296 | d_hid=d_model, 297 | n_height=n_position, 298 | n_width=n_position, 299 | dropout=dropout) 300 | 301 | self.layer_stack = nn.ModuleList([ 302 | CrossattnLayer( 303 | d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 304 | for _ in range(n_layers) 305 | ]) 306 | self.layer_norm = nn.LayerNorm(d_model) 307 | 308 | def forward(self, feat, cross_feat, img_metas=None): 309 | """ 310 | Args: 311 | feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`. 312 | img_metas (dict): A dict that contains meta information of input 313 | images. Preferably with the key ``valid_ratio``. 314 | Returns: 315 | Tensor: A tensor of shape :math:`(N, T, D_m)`. 316 | """ 317 | valid_ratios = [1.0 for _ in range(feat.size(0))] 318 | if img_metas is not None: 319 | valid_ratios = [ 320 | img_meta.get('valid_ratio', 1.0) for img_meta in img_metas 321 | ] 322 | feat = self.position_enc(feat) 323 | cross_feat = self.position_enc_cross(cross_feat) 324 | 325 | n, c, h, w = feat.size() 326 | mask = feat.new_zeros((n, h, w)) 327 | for i, valid_ratio in enumerate(valid_ratios): 328 | valid_width = min(w, math.ceil(w * valid_ratio)) 329 | mask[i, :, :valid_width] = 1 330 | mask = mask.view(n, h * w) 331 | feat = feat.view(n, c, h * w) 332 | cross_feat = cross_feat.view(n, c, h * w) 333 | 334 | output = feat.permute(0, 2, 1).contiguous() 335 | cross = cross_feat.permute(0, 2, 1).contiguous() 336 | for enc_layer in self.layer_stack: 337 | output = enc_layer(output, cross, h, w, mask) 338 | output = self.layer_norm(output) 339 | 340 | return output 341 | 342 | 343 | class DecoderLayer(nn.Module): 344 | """Implement encoder for SATRN, see `SATRN. 345 | `_. 346 | Args: 347 | n_layers (int): Number of attention layers. 348 | n_head (int): Number of parallel attention heads. 349 | d_k (int): Dimension of the key vector. 350 | d_v (int): Dimension of the value vector. 351 | d_model (int): Dimension :math:`D_m` of the input from previous model. 352 | n_position (int): Length of the positional encoding vector. Must be 353 | greater than ``max_seq_len``. 354 | d_inner (int): Hidden dimension of feedforward layers. 355 | dropout (float): Dropout rate. 356 | init_cfg (dict or list[dict], optional): Initialization configs. 357 | """ 358 | 359 | def __init__(self, 360 | d_model=512, 361 | d_inner=256, 362 | n_head=8, 363 | d_k=64, 364 | d_v=64, 365 | n_position=100, 366 | dropout=0.1, 367 | qkv_bias=False): 368 | super().__init__() 369 | self.d_model = d_model 370 | self.norm1 = nn.LayerNorm(d_model) 371 | self.attn = MultiHeadAttention( 372 | n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout) 373 | self.norm2 = nn.LayerNorm(d_model) 374 | self.feed_forward = LocalityAwareFeedforward( 375 | d_model, d_inner, dropout=dropout) 376 | 377 | def forward(self, x, h, w, mask=None): 378 | """ 379 | Args: 380 | feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`. 381 | img_metas (dict): A dict that contains meta information of input 382 | images. Preferably with the key ``valid_ratio``. 383 | Returns: 384 | Tensor: A tensor of shape :math:`(N, T, D_m)`. 
385 | """ 386 | n, hw, c = x.size() 387 | residual = x 388 | x = self.norm1(x) 389 | x = residual + self.attn(x, x, x, mask) 390 | residual = x 391 | x = self.norm2(x) 392 | x = x.transpose(1, 2).contiguous().view(n, c, h, w) 393 | x = self.feed_forward(x) 394 | x = x.view(n, c, hw).transpose(1, 2) 395 | x = residual + x 396 | return x 397 | 398 | 399 | class Decoder(nn.Module): 400 | """Implement encoder for SATRN, see `SATRN. 401 | `_. 402 | Args: 403 | n_layers (int): Number of attention layers. 404 | n_head (int): Number of parallel attention heads. 405 | d_k (int): Dimension of the key vector. 406 | d_v (int): Dimension of the value vector. 407 | d_model (int): Dimension :math:`D_m` of the input from previous model. 408 | n_position (int): Length of the positional encoding vector. Must be 409 | greater than ``max_seq_len``. 410 | d_inner (int): Hidden dimension of feedforward layers. 411 | dropout (float): Dropout rate. 412 | init_cfg (dict or list[dict], optional): Initialization configs. 413 | """ 414 | 415 | def __init__(self, 416 | n_layers=4, 417 | n_head=8, 418 | d_k=64, 419 | d_v=64, 420 | d_model=512, 421 | n_position=100, 422 | d_inner=256, 423 | dropout=0.1, 424 | qkv_bias=False): 425 | super().__init__() 426 | self.d_model = d_model 427 | 428 | self.position_dec = Adaptive2DPositionalEncoding( 429 | d_hid=d_model, 430 | n_height=n_position, 431 | n_width=n_position, 432 | dropout=dropout) 433 | 434 | self.layer_stack = nn.ModuleList([ 435 | DecoderLayer( 436 | d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 437 | for _ in range(n_layers) 438 | ]) 439 | 440 | self.layer_norm = nn.LayerNorm(d_model) 441 | 442 | def forward(self, feat): 443 | valid_ratios = [1.0 for _ in range(feat.size(0))] 444 | feat = self.position_dec(feat) 445 | n, c, h, w = feat.size() 446 | 447 | mask = feat.new_zeros((n, h, w)) 448 | for i, valid_ratio in enumerate(valid_ratios): 449 | valid_width = min(w, math.ceil(w * valid_ratio)) 450 | mask[i, :, :valid_width] = 1 451 | mask = mask.view(n, h * w) 452 | 453 | feat = feat.view(n, c, h * w) 454 | output = feat.permute(0, 2, 1).contiguous() 455 | for dec_layer in self.layer_stack: 456 | output = dec_layer(output, h, w, mask) 457 | output = self.layer_norm(output) 458 | return output 459 | 460 | 461 | if __name__ == '__main__': 462 | x1 = torch.randn((2, 512, 28, 28)).cuda() 463 | x2 = torch.randn((2, 512, 28, 28)).cuda() 464 | 465 | n, c, h, w = x1.size() 466 | 467 | encoder = CrossEncoder(n_layers=8, n_position=28).cuda() 468 | decoder = Decoder(n_layers=4, n_position=28).cuda() 469 | 470 | print(encoder) 471 | output = encoder(x1, x2) 472 | output = output.transpose(1, 2).contiguous().view(n, c, h, w) 473 | output = decoder(output) 474 | 475 | # loss = net(x) 476 | n_p = sum(x.numel() for x in encoder.parameters()) # number parameters 477 | n_g = sum(x.numel() for x in encoder.parameters() if x.requires_grad) # number gradients 478 | print('Model Summary: %g parameters, %g gradients\n' % (n_p, n_g)) 479 | 480 | print("output: ", output.shape) -------------------------------------------------------------------------------- /networks/seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class sobel_net(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.conv_opx = nn.Conv2d(1, 1, 3, bias=False) 11 | self.conv_opy = nn.Conv2d(1, 1, 3, bias=False) 12 | sobel_kernelx = np.array([[-1, 0, 
1], [-2, 0, 2], [-1, 0, 1]], dtype='float32').reshape((1, 1, 3, 3)) 13 | sobel_kernely = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype='float32').reshape((1, 1, 3, 3)) 14 | self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx) 15 | self.conv_opy.weight.data = torch.from_numpy(sobel_kernely) 16 | 17 | # for p in self.parameters(): 18 | # p.requires_grad = False 19 | 20 | def forward(self, im): # input rgb 21 | x = (0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]).unsqueeze(1) # rgb2gray 22 | gradx = self.conv_opx(x) 23 | grady = self.conv_opy(x) 24 | 25 | x = (gradx ** 2 + grady ** 2) ** 0.5 26 | x = (x - x.min()) / (x.max() - x.min()) 27 | x = F.pad(x, (1, 1, 1, 1)) 28 | 29 | x = torch.cat([im, x], dim=1) 30 | return x 31 | 32 | 33 | class REBNCONV(nn.Module): 34 | def __init__(self, in_ch=3, out_ch=3, dirate=1): 35 | super(REBNCONV, self).__init__() 36 | 37 | self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate) 38 | self.bn_s1 = nn.BatchNorm2d(out_ch) 39 | self.relu_s1 = nn.ReLU(inplace=True) 40 | 41 | def forward(self, x): 42 | hx = x 43 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 44 | 45 | return xout 46 | 47 | 48 | ## upsample tensor 'src' to have the same spatial size with tensor 'tar' 49 | def _upsample_like(src, tar): 50 | src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False) 51 | 52 | return src 53 | 54 | 55 | ### RSU-7 ### 56 | class RSU7(nn.Module): # UNet07DRES(nn.Module): 57 | 58 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 59 | super(RSU7, self).__init__() 60 | 61 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 62 | 63 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 64 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 65 | 66 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 67 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 68 | 69 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 70 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 71 | 72 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 73 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 74 | 75 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 76 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 77 | 78 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) 79 | 80 | self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) 81 | 82 | self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 83 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 84 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 85 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 86 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 87 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 88 | 89 | def forward(self, x): 90 | hx = x 91 | hxin = self.rebnconvin(hx) 92 | 93 | hx1 = self.rebnconv1(hxin) 94 | hx = self.pool1(hx1) 95 | 96 | hx2 = self.rebnconv2(hx) 97 | hx = self.pool2(hx2) 98 | 99 | hx3 = self.rebnconv3(hx) 100 | hx = self.pool3(hx3) 101 | 102 | hx4 = self.rebnconv4(hx) 103 | hx = self.pool4(hx4) 104 | 105 | hx5 = self.rebnconv5(hx) 106 | hx = self.pool5(hx5) 107 | 108 | hx6 = self.rebnconv6(hx) 109 | 110 | hx7 = self.rebnconv7(hx6) 111 | 112 | hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) 113 | hx6dup = _upsample_like(hx6d, hx5) 114 | 115 | hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) 116 | hx5dup = _upsample_like(hx5d, hx4) 117 | 118 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 119 | hx4dup = _upsample_like(hx4d, 
hx3) 120 | 121 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 122 | hx3dup = _upsample_like(hx3d, hx2) 123 | 124 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 125 | hx2dup = _upsample_like(hx2d, hx1) 126 | 127 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 128 | 129 | return hx1d + hxin 130 | 131 | 132 | ### RSU-6 ### 133 | class RSU6(nn.Module): # UNet06DRES(nn.Module): 134 | 135 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 136 | super(RSU6, self).__init__() 137 | 138 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 139 | 140 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 141 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 142 | 143 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 144 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 145 | 146 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 147 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 148 | 149 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 150 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 151 | 152 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 153 | 154 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) 155 | 156 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 157 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 158 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 159 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 160 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 161 | 162 | def forward(self, x): 163 | hx = x 164 | 165 | hxin = self.rebnconvin(hx) 166 | 167 | hx1 = self.rebnconv1(hxin) 168 | hx = self.pool1(hx1) 169 | 170 | hx2 = self.rebnconv2(hx) 171 | hx = self.pool2(hx2) 172 | 173 | hx3 = self.rebnconv3(hx) 174 | hx = self.pool3(hx3) 175 | 176 | hx4 = self.rebnconv4(hx) 177 | hx = self.pool4(hx4) 178 | 179 | hx5 = self.rebnconv5(hx) 180 | 181 | hx6 = self.rebnconv6(hx5) 182 | 183 | hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) 184 | hx5dup = _upsample_like(hx5d, hx4) 185 | 186 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 187 | hx4dup = _upsample_like(hx4d, hx3) 188 | 189 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 190 | hx3dup = _upsample_like(hx3d, hx2) 191 | 192 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 193 | hx2dup = _upsample_like(hx2d, hx1) 194 | 195 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 196 | 197 | return hx1d + hxin 198 | 199 | 200 | ### RSU-5 ### 201 | class RSU5(nn.Module): # UNet05DRES(nn.Module): 202 | 203 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 204 | super(RSU5, self).__init__() 205 | 206 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 207 | 208 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 209 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 210 | 211 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 212 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 213 | 214 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 215 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 216 | 217 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 218 | 219 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) 220 | 221 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 222 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 223 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 224 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 225 | 226 | def forward(self, x): 227 | hx = x 228 | 229 | hxin = self.rebnconvin(hx) 230 | 231 | hx1 = 
self.rebnconv1(hxin) 232 | hx = self.pool1(hx1) 233 | 234 | hx2 = self.rebnconv2(hx) 235 | hx = self.pool2(hx2) 236 | 237 | hx3 = self.rebnconv3(hx) 238 | hx = self.pool3(hx3) 239 | 240 | hx4 = self.rebnconv4(hx) 241 | 242 | hx5 = self.rebnconv5(hx4) 243 | 244 | hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) 245 | hx4dup = _upsample_like(hx4d, hx3) 246 | 247 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 248 | hx3dup = _upsample_like(hx3d, hx2) 249 | 250 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 251 | hx2dup = _upsample_like(hx2d, hx1) 252 | 253 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 254 | 255 | return hx1d + hxin 256 | 257 | 258 | ### RSU-4 ### 259 | class RSU4(nn.Module): # UNet04DRES(nn.Module): 260 | 261 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 262 | super(RSU4, self).__init__() 263 | 264 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 265 | 266 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 267 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 268 | 269 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 270 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 271 | 272 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 273 | 274 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) 275 | 276 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 277 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 278 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 279 | 280 | def forward(self, x): 281 | hx = x 282 | 283 | hxin = self.rebnconvin(hx) 284 | 285 | hx1 = self.rebnconv1(hxin) 286 | hx = self.pool1(hx1) 287 | 288 | hx2 = self.rebnconv2(hx) 289 | hx = self.pool2(hx2) 290 | 291 | hx3 = self.rebnconv3(hx) 292 | 293 | hx4 = self.rebnconv4(hx3) 294 | 295 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 296 | hx3dup = _upsample_like(hx3d, hx2) 297 | 298 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 299 | hx2dup = _upsample_like(hx2d, hx1) 300 | 301 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 302 | 303 | return hx1d + hxin 304 | 305 | 306 | ### RSU-4F ### 307 | class RSU4F(nn.Module): # UNet04FRES(nn.Module): 308 | 309 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 310 | super(RSU4F, self).__init__() 311 | 312 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 313 | 314 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 315 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) 316 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) 317 | 318 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) 319 | 320 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) 321 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) 322 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 323 | 324 | def forward(self, x): 325 | hx = x 326 | 327 | hxin = self.rebnconvin(hx) 328 | 329 | hx1 = self.rebnconv1(hxin) 330 | hx2 = self.rebnconv2(hx1) 331 | hx3 = self.rebnconv3(hx2) 332 | 333 | hx4 = self.rebnconv4(hx3) 334 | 335 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 336 | hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) 337 | hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) 338 | 339 | return hx1d + hxin 340 | 341 | 342 | ##### U^2-Net #### 343 | class U2NET(nn.Module): 344 | 345 | def __init__(self, in_ch=3, out_ch=1): 346 | super(U2NET, self).__init__() 347 | self.edge = sobel_net() 348 | 349 | self.stage1 = RSU7(in_ch, 32, 64) 350 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 351 | 352 | self.stage2 = RSU6(64, 32, 128) 353 | self.pool23 = 
nn.MaxPool2d(2, stride=2, ceil_mode=True) 354 | 355 | self.stage3 = RSU5(128, 64, 256) 356 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 357 | 358 | self.stage4 = RSU4(256, 128, 512) 359 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 360 | 361 | self.stage5 = RSU4F(512, 256, 512) 362 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 363 | 364 | self.stage6 = RSU4F(512, 256, 512) 365 | 366 | # decoder 367 | self.stage5d = RSU4F(1024, 256, 512) 368 | self.stage4d = RSU4(1024, 128, 256) 369 | self.stage3d = RSU5(512, 64, 128) 370 | self.stage2d = RSU6(256, 32, 64) 371 | self.stage1d = RSU7(128, 16, 64) 372 | 373 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 374 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 375 | self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) 376 | self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) 377 | self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) 378 | self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) 379 | 380 | self.outconv = nn.Conv2d(6, out_ch, 1) 381 | 382 | def forward(self, x): 383 | x = self.edge(x) 384 | hx = x 385 | 386 | # stage 1 387 | hx1 = self.stage1(hx) 388 | hx = self.pool12(hx1) 389 | 390 | # stage 2 391 | hx2 = self.stage2(hx) 392 | hx = self.pool23(hx2) 393 | 394 | # stage 3 395 | hx3 = self.stage3(hx) 396 | hx = self.pool34(hx3) 397 | 398 | # stage 4 399 | hx4 = self.stage4(hx) 400 | hx = self.pool45(hx4) 401 | 402 | # stage 5 403 | hx5 = self.stage5(hx) 404 | hx = self.pool56(hx5) 405 | 406 | # stage 6 407 | hx6 = self.stage6(hx) 408 | hx6up = _upsample_like(hx6, hx5) 409 | 410 | # -------------------- decoder -------------------- 411 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 412 | hx5dup = _upsample_like(hx5d, hx4) 413 | 414 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 415 | hx4dup = _upsample_like(hx4d, hx3) 416 | 417 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 418 | hx3dup = _upsample_like(hx3d, hx2) 419 | 420 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 421 | hx2dup = _upsample_like(hx2d, hx1) 422 | 423 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 424 | 425 | # side output 426 | d1 = self.side1(hx1d) 427 | 428 | d2 = self.side2(hx2d) 429 | d2 = _upsample_like(d2, d1) 430 | 431 | d3 = self.side3(hx3d) 432 | d3 = _upsample_like(d3, d1) 433 | 434 | d4 = self.side4(hx4d) 435 | d4 = _upsample_like(d4, d1) 436 | 437 | d5 = self.side5(hx5d) 438 | d5 = _upsample_like(d5, d1) 439 | 440 | d6 = self.side6(hx6) 441 | d6 = _upsample_like(d6, d1) 442 | 443 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 444 | 445 | return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid( 446 | d4), torch.sigmoid(d5), torch.sigmoid(d6) 447 | 448 | 449 | ### U^2-Net small ### 450 | class U2NETP(nn.Module): 451 | 452 | def __init__(self, in_ch=3, out_ch=1): 453 | super(U2NETP, self).__init__() 454 | 455 | self.stage1 = RSU7(in_ch, 16, 64) 456 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 457 | 458 | self.stage2 = RSU6(64, 16, 64) 459 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 460 | 461 | self.stage3 = RSU5(64, 16, 64) 462 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 463 | 464 | self.stage4 = RSU4(64, 16, 64) 465 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 466 | 467 | self.stage5 = RSU4F(64, 16, 64) 468 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 469 | 470 | self.stage6 = RSU4F(64, 16, 64) 471 | 472 | # decoder 473 | self.stage5d = RSU4F(128, 16, 64) 474 | self.stage4d = RSU4(128, 16, 
64) 475 | self.stage3d = RSU5(128, 16, 64) 476 | self.stage2d = RSU6(128, 16, 64) 477 | self.stage1d = RSU7(128, 16, 64) 478 | 479 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 480 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 481 | self.side3 = nn.Conv2d(64, out_ch, 3, padding=1) 482 | self.side4 = nn.Conv2d(64, out_ch, 3, padding=1) 483 | self.side5 = nn.Conv2d(64, out_ch, 3, padding=1) 484 | self.side6 = nn.Conv2d(64, out_ch, 3, padding=1) 485 | 486 | self.outconv = nn.Conv2d(6, out_ch, 1) 487 | 488 | # don't need the gradients, just want the features 489 | for param in self.parameters(): 490 | param.requires_grad = False 491 | 492 | def forward(self, x): 493 | hx = x 494 | 495 | # stage 1 496 | hx1 = self.stage1(hx) 497 | hx = self.pool12(hx1) 498 | 499 | # stage 2 500 | hx2 = self.stage2(hx) 501 | hx = self.pool23(hx2) 502 | 503 | # stage 3 504 | hx3 = self.stage3(hx) 505 | hx = self.pool34(hx3) 506 | 507 | # stage 4 508 | hx4 = self.stage4(hx) 509 | hx = self.pool45(hx4) 510 | 511 | # stage 5 512 | hx5 = self.stage5(hx) 513 | hx = self.pool56(hx5) 514 | 515 | # stage 6 516 | features = [] 517 | hx6 = self.stage6(hx) 518 | 519 | hx6up = _upsample_like(hx6, hx5) 520 | 521 | # decoder 522 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 523 | hx5dup = _upsample_like(hx5d, hx4) 524 | 525 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 526 | hx4dup = _upsample_like(hx4d, hx3) 527 | 528 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 529 | hx3dup = _upsample_like(hx3d, hx2) 530 | 531 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 532 | hx2dup = _upsample_like(hx2d, hx1) 533 | 534 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 535 | 536 | features.append(hx2dup) 537 | features.append(hx3dup) 538 | features.append(hx4dup) 539 | features.append(hx5dup) 540 | features.append(hx6up) 541 | 542 | # side output 543 | d1 = self.side1(hx1d) 544 | 545 | d2_ = self.side2(hx2d) 546 | d2 = _upsample_like(d2_, d1) 547 | 548 | d3_ = self.side3(hx3d) 549 | d3 = _upsample_like(d3_, d1) 550 | 551 | d4_ = self.side4(hx4d) 552 | d4 = _upsample_like(d4_, d1) 553 | 554 | d5_ = self.side5(hx5d) 555 | d5 = _upsample_like(d5_, d1) 556 | 557 | d6_ = self.side6(hx6) 558 | d6 = _upsample_like(d6_, d1) 559 | 560 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 561 | 562 | return torch.sigmoid(d0), hx6, hx5d, hx4d, hx3d, hx2d, hx1d 563 | 564 | 565 | def get_parameter_number(net): 566 | total_num = sum(p.numel() for p in net.parameters()) 567 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 568 | return {'Total': total_num, 'Trainable': trainable_num} 569 | 570 | 571 | if __name__ == '__main__': 572 | net = U2NETP(3, 1).cuda() 573 | print(get_parameter_number(net)) # 69090500 加attention后69442032 574 | with torch.no_grad(): 575 | inputs = torch.zeros(1, 3, 224, 224).cuda() 576 | outs = net(inputs) 577 | print(outs[0].shape) # torch.Size([2, 3, 256, 256]) torch.Size([2, 2, 256, 256]) 578 | 579 | --------------------------------------------------------------------------------