├── .gitignore ├── DocShadow ├── models │ ├── __init__.py │ ├── backbone.py │ ├── blocks.py │ ├── blocks_.py │ └── model.py └── utils │ ├── __init__.py │ └── utils.py ├── LICENSE ├── README.md ├── assets ├── latency.png └── sample.jpg ├── eval.py ├── export.py ├── export_coreml.py ├── infer.py ├── onnx_runner ├── __init__.py └── docshadow.py ├── requirements-onnx.txt ├── requirements.txt └── weights ├── .gitkeep └── download.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.onnx 2 | *.pth 3 | *.engine 4 | *.profile 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | 167 | # VSCode 168 | .vscode* 169 | -------------------------------------------------------------------------------- /DocShadow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DocShadow 2 | -------------------------------------------------------------------------------- /DocShadow/models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNormFunction(torch.autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, x, weight, bias, eps): 10 | ctx.eps = eps 11 | N, C, H, W = x.size() 12 | mu = x.mean(1, keepdim=True) 13 | var = (x - mu).pow(2).mean(1, keepdim=True) 14 | y = (x - mu) / (var + eps).sqrt() 15 | ctx.save_for_backward(y, var, weight) 16 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 17 | return y 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | eps = ctx.eps 22 | 23 | N, C, H, W = grad_output.size() 24 | y, var, weight = ctx.saved_variables 25 | g = grad_output * weight.view(1, C, 1, 1) 26 | mean_g = g.mean(dim=1, keepdim=True) 27 | 28 | mean_gy = (g * y).mean(dim=1, keepdim=True) 29 | gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 30 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 31 | dim=0), None 32 | 33 | 34 | class LayerNorm2d(nn.Module): 35 | 36 | def __init__(self, channels, eps=1e-6): 37 | super(LayerNorm2d, self).__init__() 38 | self.register_parameter('weight', nn.Parameter(torch.ones(channels))) 39 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) 40 | self.eps = eps 41 | 42 | def forward(self, x): 43 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 44 | 45 | 46 | class SimpleGate(nn.Module): 47 | def forward(self, x): 48 | x1, x2 = x.chunk(2, dim=1) 49 | return x1 * x2 50 | 51 | 52 | # ------------ Deep Feature Extraction Block ----------------- 53 | class DFEBlock(nn.Module): 54 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): 55 | super().__init__() 56 | dw_channel = c * DW_Expand 57 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, 58 | bias=True) 59 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, 60 | groups=dw_channel, 61 | bias=True) 62 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 63 | groups=1, bias=True) 64 | 65 | # Simplified Channel Attention 66 | self.sca = nn.Sequential( 67 | nn.AdaptiveAvgPool2d(1), 68 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 69 | groups=1, bias=True), 70 | ) 71 | 72 | # SimpleGate 73 | self.sg = SimpleGate() 74 | 75 | ffn_channel = FFN_Expand * c 76 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, 77 | bias=True) 78 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 79 | groups=1, bias=True) 80 | 81 | self.norm1 = LayerNorm2d(c) 82 | self.norm2 = LayerNorm2d(c) 83 | 84 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 85 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() 86 | 87 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 88 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 89 | 90 | def forward(self, inp): 91 | x = inp 92 | 93 | x = self.norm1(x) 94 | 95 | x = self.conv1(x) 96 | x = self.conv2(x) 97 | x = self.sg(x) 98 | x = x * self.sca(x) 99 | x = self.conv3(x) 100 | 101 | x = self.dropout1(x) 102 | 103 | y = inp + x * self.beta 104 | 105 | x = self.conv4(self.norm2(y)) 106 | x = self.sg(x) 107 | x = self.conv5(x) 108 | 109 | x = self.dropout2(x) 110 | 111 | return y + x * self.gamma 112 | 113 | 114 | class DFE(nn.Module): 115 | 116 | def __init__(self, img_channel=3, width=32, middle_blk_num=12, enc_blk_nums=[2, 2, 4, 8], dec_blk_nums=[2, 2, 2, 2]): 117 | super().__init__() 118 | 119 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, 120 | groups=1, 121 | bias=True) 122 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, 123 | groups=1, 124 | bias=True) 125 | 126 | self.encoders = nn.ModuleList() 127 | self.decoders = nn.ModuleList() 128 | self.middle_blks = nn.ModuleList() 129 | self.ups = nn.ModuleList() 130 | self.downs = nn.ModuleList() 131 | 132 | chan = width 133 | for num in enc_blk_nums: 134 | self.encoders.append( 135 | nn.Sequential( 136 | *[DFEBlock(chan) for _ in range(num)] 137 | ) 138 | ) 139 | self.downs.append( 140 | nn.Conv2d(chan, 2 * chan, 2, 2) 141 | ) 142 | chan = chan * 2 143 | 144 | self.middle_blks = \ 145 | nn.Sequential( 146 | *[DFEBlock(chan) for _ in range(middle_blk_num)] 147 | ) 148 | 149 | for num in dec_blk_nums: 150 | self.ups.append( 151 | nn.Sequential( 152 | nn.Conv2d(chan, chan * 2, 1, bias=False), 153 | nn.PixelShuffle(2) 154 | ) 155 | ) 156 | chan = chan // 2 157 | self.decoders.append( 158 | nn.Sequential( 159 | *[DFEBlock(chan) for _ in range(num)] 160 | ) 161 | ) 162 | 163 | self.padder_size = 2 ** len(self.encoders) 164 | 165 | def forward(self, inp): 166 | B, C, H, W = inp.shape 167 | inp = self.check_image_size(inp) 168 | 169 | x = self.intro(inp) 170 | 171 | encs = [] 172 | 173 | for encoder, down in zip(self.encoders, self.downs): 174 | x = encoder(x) 175 | encs.append(x) 176 | x = down(x) 177 | 178 | x = self.middle_blks(x) 179 | 180 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 181 | x = up(x) 182 | x = x + enc_skip 183 | x = decoder(x) 184 | 185 | x = self.ending(x) 186 | x = x + inp 187 | 188 | return x[:, :, :H, :W] 189 | 190 | def check_image_size(self, x): 191 | _, _, h, w = x.size() 192 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 193 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 194 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 195 | return x 196 | -------------------------------------------------------------------------------- /DocShadow/models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, bias=True, groups=1, norm='in', 8 | nonlinear='relu'): 9 | super(ConvLayer, self).__init__() 10 | reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 11 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, 
kernel_size, stride, groups=groups, bias=bias, 13 | dilation=dilation) 14 | self.norm = norm 15 | self.nonlinear = nonlinear 16 | 17 | if norm == 'bn': 18 | self.normalization = nn.BatchNorm2d(out_channels) 19 | elif norm == 'in': 20 | self.normalization = nn.InstanceNorm2d(out_channels, affine=False) 21 | else: 22 | self.normalization = None 23 | 24 | if nonlinear == 'relu': 25 | self.activation = nn.ReLU(inplace=True) 26 | elif nonlinear == 'leakyrelu': 27 | self.activation = nn.LeakyReLU(0.2) 28 | elif nonlinear == 'PReLU': 29 | self.activation = nn.PReLU() 30 | else: 31 | self.activation = None 32 | 33 | def forward(self, x): 34 | out = self.conv2d(self.reflection_pad(x)) 35 | if self.normalization is not None: 36 | out = self.normalization(out) 37 | if self.activation is not None: 38 | out = self.activation(out) 39 | 40 | return out 41 | 42 | 43 | class Aggreation(nn.Module): 44 | def __init__(self, in_channels, out_channels, kernel_size=3): 45 | super(Aggreation, self).__init__() 46 | self.attention = SelfAttention(in_channels, k=8, nonlinear='relu') 47 | self.conv = ConvLayer(in_channels, out_channels, kernel_size=kernel_size, stride=1, dilation=1, 48 | nonlinear='leakyrelu', 49 | norm=None) 50 | 51 | def forward(self, x): 52 | return self.conv(self.attention(x)) 53 | 54 | 55 | class SelfAttention(nn.Module): 56 | def __init__(self, channels, k, nonlinear='relu'): 57 | super(SelfAttention, self).__init__() 58 | self.channels = channels 59 | self.k = k 60 | self.nonlinear = nonlinear 61 | 62 | self.linear1 = nn.Linear(channels, channels // k) 63 | self.linear2 = nn.Linear(channels // k, channels) 64 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 65 | 66 | if nonlinear == 'relu': 67 | self.activation = nn.ReLU(inplace=True) 68 | elif nonlinear == 'leakyrelu': 69 | self.activation = nn.LeakyReLU(0.2) 70 | elif nonlinear == 'PReLU': 71 | self.activation = nn.PReLU() 72 | else: 73 | raise ValueError 74 | 75 | def attention(self, x): 76 | N, C, H, W = x.size() 77 | out = torch.flatten(self.global_pooling(x), 1) 78 | out = self.activation(self.linear1(out)) 79 | out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1) 80 | 81 | return out.mul(x) 82 | 83 | def forward(self, x): 84 | return self.attention(x) 85 | 86 | 87 | class SPP(nn.Module): 88 | def __init__(self, in_channels, out_channels, num_layers=4, interpolation_type='bilinear'): 89 | super(SPP, self).__init__() 90 | self.conv = nn.ModuleList() 91 | self.num_layers = num_layers 92 | self.interpolation_type = interpolation_type 93 | 94 | for _ in range(self.num_layers): 95 | self.conv.append( 96 | ConvLayer(in_channels, in_channels, kernel_size=1, stride=1, dilation=1, nonlinear='leakyrelu', 97 | norm=None)) 98 | 99 | self.fusion = ConvLayer((in_channels * (self.num_layers + 1)), out_channels, kernel_size=3, stride=1, 100 | norm='False', nonlinear='leakyrelu') 101 | 102 | def forward(self, x): 103 | 104 | N, C, H, W = x.size() 105 | out = [] 106 | 107 | for level in range(self.num_layers): 108 | out.append(F.interpolate(self.conv[level]( 109 | F.avg_pool2d(x, kernel_size=2 * 2 ** (level + 1), stride=2 * 2 ** (level + 1), 110 | padding=2 * 2 ** (level + 1) % 2)), size=(H, W), mode=self.interpolation_type)) 111 | 112 | out.append(x) 113 | 114 | return self.fusion(torch.cat(out, dim=1)) 115 | -------------------------------------------------------------------------------- /DocShadow/models/blocks_.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 
import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, bias=True, groups=1, norm='in', 8 | nonlinear='relu'): 9 | super(ConvLayer, self).__init__() 10 | reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 11 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias, 13 | dilation=dilation) 14 | self.norm = norm 15 | self.nonlinear = nonlinear 16 | 17 | if norm == 'bn': 18 | self.normalization = nn.BatchNorm2d(out_channels) 19 | elif norm == 'in': 20 | self.normalization = nn.InstanceNorm2d(out_channels, affine=False) 21 | else: 22 | self.normalization = None 23 | 24 | if nonlinear == 'relu': 25 | self.activation = nn.ReLU(inplace=True) 26 | elif nonlinear == 'leakyrelu': 27 | self.activation = nn.LeakyReLU(0.2) 28 | elif nonlinear == 'PReLU': 29 | self.activation = nn.PReLU() 30 | else: 31 | self.activation = None 32 | 33 | def forward(self, x): 34 | out = self.conv2d(self.reflection_pad(x)) 35 | if self.normalization is not None: 36 | out = self.normalization(out) 37 | if self.activation is not None: 38 | out = self.activation(out) 39 | 40 | return out 41 | 42 | 43 | class Aggreation(nn.Module): 44 | def __init__(self, in_channels, out_channels, kernel_size=3): 45 | super(Aggreation, self).__init__() 46 | self.attention = SelfAttention(in_channels, k=8, nonlinear='relu') 47 | self.conv = ConvLayer(in_channels, out_channels, kernel_size=kernel_size, stride=1, dilation=1, 48 | nonlinear='leakyrelu', 49 | norm=None) 50 | 51 | def forward(self, x): 52 | return self.conv(self.attention(x)) 53 | 54 | 55 | class SelfAttention(nn.Module): 56 | def __init__(self, channels, k, nonlinear='relu'): 57 | super(SelfAttention, self).__init__() 58 | self.channels = channels 59 | self.k = k 60 | self.nonlinear = nonlinear 61 | 62 | self.linear1 = nn.Linear(channels, channels // k) 63 | self.linear2 = nn.Linear(channels // k, channels) 64 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 65 | 66 | if nonlinear == 'relu': 67 | self.activation = nn.ReLU(inplace=True) 68 | elif nonlinear == 'leakyrelu': 69 | self.activation = nn.LeakyReLU(0.2) 70 | elif nonlinear == 'PReLU': 71 | self.activation = nn.PReLU() 72 | else: 73 | raise ValueError 74 | 75 | def attention(self, x): 76 | N, C, H, W = x.size() 77 | out = torch.flatten(self.global_pooling(x), 1) 78 | out = self.activation(self.linear1(out)) 79 | out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1) 80 | 81 | return out.mul(x) 82 | 83 | def forward(self, x): 84 | return self.attention(x) 85 | 86 | 87 | class SPP(nn.Module): 88 | def __init__(self, in_channels, out_channels, num_layers=4, interpolation_type='bilinear'): 89 | super(SPP, self).__init__() 90 | self.conv = nn.ModuleList() 91 | self.num_layers = num_layers 92 | self.interpolation_type = interpolation_type 93 | 94 | for _ in range(self.num_layers): 95 | self.conv.append( 96 | ConvLayer(in_channels, in_channels, kernel_size=1, stride=1, dilation=1, nonlinear='leakyrelu', 97 | norm=None)) 98 | 99 | self.fusion = ConvLayer((in_channels * (self.num_layers + 1)), out_channels, kernel_size=3, stride=1, 100 | norm='False', nonlinear='leakyrelu') 101 | 102 | def forward(self, x): 103 | 104 | N, C, H, W = x.size() 105 | out = [] 106 | 107 | for level in range(self.num_layers): 108 | out.append(F.interpolate(self.conv[level]( 109 | F.avg_pool2d(x, kernel_size=2 * 2 ** 
(level + 1), stride=2 * 2 ** (level + 1), 110 | padding=2 * 2 ** (level + 1) % 2)), size=(H, W), mode=self.interpolation_type)) 111 | 112 | out.append(x) 113 | 114 | return self.fusion(torch.cat(out, dim=1)) 115 | -------------------------------------------------------------------------------- /DocShadow/models/model.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import nn 7 | 8 | from .backbone import DFE 9 | from .blocks import SPP, Aggreation, ConvLayer 10 | 11 | 12 | class ResidualBlock(nn.Module): 13 | def __init__(self, in_features): 14 | super(ResidualBlock, self).__init__() 15 | 16 | self.block = nn.Sequential( 17 | nn.Conv2d(in_features, in_features, 3, padding=1), 18 | nn.LeakyReLU(), 19 | nn.Conv2d(in_features, in_features, 3, padding=1), 20 | ) 21 | 22 | def forward(self, x): 23 | return x + self.block(x) 24 | 25 | 26 | def gauss_kernel(channels=3): 27 | kernel = torch.tensor( 28 | [ 29 | [1.0, 4.0, 6.0, 4.0, 1], 30 | [4.0, 16.0, 24.0, 16.0, 4.0], 31 | [6.0, 24.0, 36.0, 24.0, 6.0], 32 | [4.0, 16.0, 24.0, 16.0, 4.0], 33 | [1.0, 4.0, 6.0, 4.0, 1.0], 34 | ] 35 | ) 36 | kernel /= 256.0 37 | kernel = kernel.repeat(channels, 1, 1, 1) 38 | return kernel 39 | 40 | 41 | class LapPyramidConv(nn.Module): 42 | def __init__(self, num_high=4): 43 | super(LapPyramidConv, self).__init__() 44 | 45 | self.num_high = num_high 46 | self.kernel = gauss_kernel() 47 | 48 | def downsample(self, x): 49 | return x[:, :, ::2, ::2] 50 | 51 | def upsample(self, x): 52 | cc = torch.cat( 53 | [ 54 | x, 55 | torch.zeros( 56 | x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device 57 | ), 58 | ], 59 | dim=3, 60 | ) 61 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) 62 | cc = cc.permute(0, 1, 3, 2) 63 | cc = torch.cat( 64 | [ 65 | cc, 66 | torch.zeros( 67 | x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device 68 | ), 69 | ], 70 | dim=3, 71 | ) 72 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) 73 | x_up = cc.permute(0, 1, 3, 2) 74 | return self.conv_gauss(x_up, 4 * self.kernel) 75 | 76 | def conv_gauss(self, img, kernel): 77 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect") 78 | out = torch.nn.functional.conv2d( 79 | img, kernel.to(img.device), groups=img.shape[1] 80 | ) 81 | return out 82 | 83 | def pyramid_decom(self, img): 84 | current = img 85 | pyr = [] 86 | for _ in range(self.num_high): 87 | filtered = self.conv_gauss(current, self.kernel) 88 | down = self.downsample(filtered) 89 | up = self.upsample(down) 90 | # if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]: 91 | up = nn.functional.interpolate( 92 | up, size=(current.shape[2], current.shape[3]) 93 | ) 94 | diff = current - up 95 | pyr.append(diff) 96 | current = down 97 | pyr.append(current) 98 | return pyr 99 | 100 | def pyramid_recons(self, pyr): 101 | image = pyr[-1] 102 | for level in reversed(pyr[:-1]): 103 | up = self.upsample(image) 104 | # if up.shape[2] != level.shape[2] or up.shape[3] != level.shape[3]: 105 | up = nn.functional.interpolate(up, size=(level.shape[2], level.shape[3])) 106 | image = up + level 107 | return image 108 | 109 | 110 | # ----------- Texture Recovery Module ------------------- 111 | class TRM(nn.Module): 112 | def __init__(self, num_residual_blocks, num_high=3): 113 | super(TRM, self).__init__() 114 | 115 | self.num_high = num_high 116 | 117 | blocks = 
[nn.Conv2d(9, 64, 3, padding=1), nn.LeakyReLU()] 118 | 119 | for _ in range(num_residual_blocks): 120 | blocks += [ResidualBlock(64)] 121 | 122 | blocks += [nn.Conv2d(64, 3, 3, padding=1)] 123 | 124 | self.model = nn.Sequential(*blocks) 125 | 126 | channels = 3 127 | # Stage1 128 | self.block1_1 = ConvLayer( 129 | in_channels=channels, 130 | out_channels=channels, 131 | kernel_size=3, 132 | stride=1, 133 | dilation=2, 134 | norm=None, 135 | nonlinear="leakyrelu", 136 | ) 137 | self.block1_2 = ConvLayer( 138 | in_channels=channels, 139 | out_channels=channels, 140 | kernel_size=3, 141 | stride=1, 142 | dilation=4, 143 | norm=None, 144 | nonlinear="leakyrelu", 145 | ) 146 | self.aggreation1_rgb = Aggreation( 147 | in_channels=channels * 3, out_channels=channels 148 | ) 149 | # Stage2 150 | self.block2_1 = ConvLayer( 151 | in_channels=channels, 152 | out_channels=channels, 153 | kernel_size=3, 154 | stride=1, 155 | dilation=8, 156 | norm=None, 157 | nonlinear="leakyrelu", 158 | ) 159 | self.block2_2 = ConvLayer( 160 | in_channels=channels, 161 | out_channels=channels, 162 | kernel_size=3, 163 | stride=1, 164 | dilation=16, 165 | norm=None, 166 | nonlinear="leakyrelu", 167 | ) 168 | self.aggreation2_rgb = Aggreation( 169 | in_channels=channels * 3, out_channels=channels 170 | ) 171 | # Stage3 172 | self.block3_1 = ConvLayer( 173 | in_channels=channels, 174 | out_channels=channels, 175 | kernel_size=3, 176 | stride=1, 177 | dilation=32, 178 | norm=None, 179 | nonlinear="leakyrelu", 180 | ) 181 | self.block3_2 = ConvLayer( 182 | in_channels=channels, 183 | out_channels=channels, 184 | kernel_size=3, 185 | stride=1, 186 | dilation=64, 187 | norm=None, 188 | nonlinear="leakyrelu", 189 | ) 190 | self.aggreation3_rgb = Aggreation( 191 | in_channels=channels * 3, out_channels=channels 192 | ) 193 | # Stage3 194 | self.spp_img = SPP( 195 | in_channels=channels, 196 | out_channels=channels, 197 | num_layers=4, 198 | interpolation_type="bicubic", 199 | ) 200 | self.block4_1 = nn.Conv2d( 201 | in_channels=channels, out_channels=3, kernel_size=1, stride=1 202 | ) 203 | 204 | def forward(self, x, pyr_original, fake_low): 205 | pyr_result = [fake_low] 206 | mask = self.model(x) 207 | 208 | mask = nn.functional.interpolate( 209 | mask, size=(pyr_original[-2].shape[2], pyr_original[-2].shape[3]) 210 | ) 211 | result_highfreq = torch.mul(pyr_original[-2], mask) + pyr_original[-2] 212 | out1_1 = self.block1_1(result_highfreq) 213 | out1_2 = self.block1_2(out1_1) 214 | agg1_rgb = self.aggreation1_rgb( 215 | torch.cat((result_highfreq, out1_1, out1_2), dim=1) 216 | ) 217 | pyr_result.append(agg1_rgb) 218 | 219 | mask = nn.functional.interpolate( 220 | mask, size=(pyr_original[-3].shape[2], pyr_original[-3].shape[3]) 221 | ) 222 | result_highfreq = torch.mul(pyr_original[-3], mask) + pyr_original[-3] 223 | out2_1 = self.block2_1(result_highfreq) 224 | out2_2 = self.block2_2(out2_1) 225 | agg2_rgb = self.aggreation2_rgb( 226 | torch.cat((result_highfreq, out2_1, out2_2), dim=1) 227 | ) 228 | 229 | out3_1 = self.block3_1(agg2_rgb) 230 | out3_2 = self.block3_2(out3_1) 231 | agg3_rgb = self.aggreation3_rgb(torch.cat((agg2_rgb, out3_1, out3_2), dim=1)) 232 | 233 | spp_rgb = self.spp_img(agg3_rgb) 234 | out_rgb = self.block4_1(spp_rgb) 235 | 236 | pyr_result.append(out_rgb) 237 | pyr_result.reverse() 238 | 239 | return pyr_result 240 | 241 | 242 | def to_3d(x): 243 | return rearrange(x, "b c h w -> b (h w) c") 244 | 245 | 246 | def to_4d(x, h, w): 247 | return rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 248 | 249 | 
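# Note: to_3d()/to_4d() above convert between the (B, C, H, W) feature-map layout
# and the (B, H*W, C) token layout used by the LayerNorm variants below, which
# normalize over the channel dimension at each spatial position.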
250 | class BiasFree_LayerNorm(nn.Module): 251 | def __init__(self, normalized_shape): 252 | super(BiasFree_LayerNorm, self).__init__() 253 | if isinstance(normalized_shape, numbers.Integral): 254 | normalized_shape = (normalized_shape,) 255 | normalized_shape = torch.Size(normalized_shape) 256 | 257 | assert len(normalized_shape) == 1 258 | 259 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 260 | self.normalized_shape = normalized_shape 261 | 262 | def forward(self, x): 263 | sigma = x.var(-1, keepdim=True, unbiased=False) 264 | return x / torch.sqrt(sigma + 1e-5) * self.weight 265 | 266 | 267 | class WithBias_LayerNorm(nn.Module): 268 | def __init__(self, normalized_shape): 269 | super(WithBias_LayerNorm, self).__init__() 270 | if isinstance(normalized_shape, numbers.Integral): 271 | normalized_shape = (normalized_shape,) 272 | normalized_shape = torch.Size(normalized_shape) 273 | 274 | assert len(normalized_shape) == 1 275 | 276 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 277 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 278 | self.normalized_shape = normalized_shape 279 | 280 | def forward(self, x): 281 | mu = x.mean(-1, keepdim=True) 282 | sigma = x.var(-1, keepdim=True, unbiased=False) 283 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 284 | 285 | 286 | class LayerNorm(nn.Module): 287 | def __init__(self, dim, LayerNorm_type): 288 | super(LayerNorm, self).__init__() 289 | if LayerNorm_type == "BiasFree": 290 | self.body = BiasFree_LayerNorm(dim) 291 | else: 292 | self.body = WithBias_LayerNorm(dim) 293 | 294 | def forward(self, x): 295 | h, w = x.shape[-2:] 296 | return to_4d(self.body(to_3d(x)), h, w) 297 | 298 | 299 | # Axis-based Multi-head Self-Attention 300 | 301 | 302 | class NextAttentionImplZ(nn.Module): 303 | def __init__(self, num_dims, num_heads, bias) -> None: 304 | super().__init__() 305 | self.num_dims = num_dims 306 | self.num_heads = num_heads 307 | self.q1 = nn.Conv2d(num_dims, num_dims * 3, kernel_size=1, bias=bias) 308 | self.q2 = nn.Conv2d( 309 | num_dims * 3, 310 | num_dims * 3, 311 | kernel_size=3, 312 | padding=1, 313 | groups=num_dims * 3, 314 | bias=bias, 315 | ) 316 | self.q3 = nn.Conv2d( 317 | num_dims * 3, 318 | num_dims * 3, 319 | kernel_size=3, 320 | padding=1, 321 | groups=num_dims * 3, 322 | bias=bias, 323 | ) 324 | 325 | self.fac = nn.Parameter(torch.ones(1)) 326 | self.fin = nn.Conv2d(num_dims, num_dims, kernel_size=1, bias=bias) 327 | return 328 | 329 | def forward(self, x): 330 | # x: [n, c, h, w] 331 | n, c, h, w = x.size() 332 | n_heads, dim_head = self.num_heads, c // self.num_heads 333 | reshape = lambda x: rearrange( 334 | x, "n (nh dh) h w -> (n nh h) w dh", nh=n_heads, dh=dim_head 335 | ) 336 | 337 | qkv = self.q3(self.q2(self.q1(x))) 338 | q, k, v = map(reshape, qkv.chunk(3, dim=1)) 339 | q = F.normalize(q, dim=-1) 340 | k = F.normalize(k, dim=-1) 341 | 342 | # fac = dim_head ** -0.5 343 | res = k.transpose(-2, -1) 344 | res = torch.matmul(q, res) * self.fac 345 | res = torch.softmax(res, dim=-1) 346 | 347 | res = torch.matmul(res, v) 348 | res = rearrange( 349 | res, "(n nh h) w dh -> n (nh dh) h w", nh=n_heads, dh=dim_head, n=n, h=h 350 | ) 351 | res = self.fin(res) 352 | 353 | return res 354 | 355 | 356 | # Axis-based Multi-head Self-Attention (row and col attention) 357 | class NextAttentionZ(nn.Module): 358 | def __init__(self, num_dims, num_heads=1, bias=True) -> None: 359 | super().__init__() 360 | assert num_dims % num_heads == 0 361 | self.num_dims = num_dims 362 | 
self.num_heads = num_heads 363 | self.row_att = NextAttentionImplZ(num_dims, num_heads, bias) 364 | self.col_att = NextAttentionImplZ(num_dims, num_heads, bias) 365 | return 366 | 367 | def forward(self, x: torch.Tensor): 368 | assert len(x.size()) == 4 369 | 370 | x = self.row_att(x) 371 | x = x.transpose(-2, -1) 372 | x = self.col_att(x) 373 | x = x.transpose(-2, -1) 374 | 375 | return x 376 | 377 | 378 | # Dual Gated Feed-Forward Networ 379 | class FeedForward(nn.Module): 380 | def __init__(self, dim, ffn_expansion_factor, bias): 381 | super(FeedForward, self).__init__() 382 | 383 | hidden_features = int(dim * ffn_expansion_factor) 384 | 385 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 386 | 387 | self.dwconv = nn.Conv2d( 388 | hidden_features * 2, 389 | hidden_features * 2, 390 | kernel_size=3, 391 | stride=1, 392 | padding=1, 393 | groups=hidden_features * 2, 394 | bias=bias, 395 | ) 396 | 397 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 398 | 399 | def forward(self, x): 400 | x = self.project_in(x) 401 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 402 | x = F.gelu(x2) * x1 + F.gelu(x1) * x2 403 | x = self.project_out(x) 404 | return x 405 | 406 | 407 | # Axis-based Transformer Block 408 | class TransformerBlock(nn.Module): 409 | def __init__( 410 | self, 411 | dim, 412 | num_heads=1, 413 | ffn_expansion_factor=2.66, 414 | bias=True, 415 | LayerNorm_type="WithBias", 416 | ): 417 | super(TransformerBlock, self).__init__() 418 | 419 | self.norm1 = LayerNorm(dim, LayerNorm_type) 420 | self.attn = NextAttentionZ(dim, num_heads) 421 | self.norm2 = LayerNorm(dim, LayerNorm_type) 422 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 423 | 424 | def forward(self, x): 425 | x = x + self.attn(self.norm1(x)) 426 | x = x + self.ffn(self.norm2(x)) 427 | return x 428 | 429 | 430 | ########################################################################## 431 | # Overlapped image patch embedding with 3x3 Conv 432 | class OverlapPatchEmbed(nn.Module): 433 | def __init__(self, in_c=3, embed_dim=48, bias=False): 434 | super(OverlapPatchEmbed, self).__init__() 435 | 436 | self.proj = nn.Conv2d( 437 | in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias 438 | ) 439 | 440 | def forward(self, x): 441 | x = self.proj(x) 442 | 443 | return x 444 | 445 | 446 | ########################################################################## 447 | # Resizing modules 448 | class Downsample(nn.Module): 449 | def __init__(self, n_feat): 450 | super(Downsample, self).__init__() 451 | 452 | self.body = nn.Sequential( 453 | nn.Conv2d( 454 | n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False 455 | ), 456 | nn.PixelUnshuffle(2), 457 | ) 458 | 459 | def forward(self, x): 460 | return self.body(x) 461 | 462 | 463 | class Upsample(nn.Module): 464 | def __init__(self, n_feat): 465 | super(Upsample, self).__init__() 466 | 467 | self.body = nn.Sequential( 468 | nn.Conv2d( 469 | n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False 470 | ), 471 | nn.PixelShuffle(2), 472 | ) 473 | 474 | def forward(self, x): 475 | return self.body(x) 476 | 477 | 478 | # Tri-layer Attention Alignment Block 479 | class TAA(nn.Module): 480 | """Layer attention module""" 481 | 482 | def __init__(self, in_dim, bias=True): 483 | super(TAA, self).__init__() 484 | self.chanel_in = in_dim 485 | 486 | self.temperature = nn.Parameter(torch.ones(1)) 487 | 488 | self.qkv = nn.Conv2d( 489 | self.chanel_in, self.chanel_in * 3, kernel_size=1, bias=bias 
490 | ) 491 | self.qkv_dwconv = nn.Conv2d( 492 | self.chanel_in * 3, 493 | self.chanel_in * 3, 494 | kernel_size=3, 495 | stride=1, 496 | padding=1, 497 | groups=self.chanel_in * 3, 498 | bias=bias, 499 | ) 500 | self.project_out = nn.Conv2d( 501 | self.chanel_in, self.chanel_in, kernel_size=1, bias=bias 502 | ) 503 | 504 | def forward(self, x): 505 | """ 506 | inputs : 507 | x : input feature maps( B X N X C X H X W) 508 | returns : 509 | out : attention value + input feature 510 | attention: B X N X N 511 | """ 512 | m_batchsize, N, C, height, width = x.size() 513 | 514 | x_input = x.view(m_batchsize, N * C, height, width) 515 | qkv = self.qkv_dwconv(self.qkv(x_input)) 516 | q, k, v = qkv.chunk(3, dim=1) 517 | q = q.view(m_batchsize, N, -1) 518 | k = k.view(m_batchsize, N, -1) 519 | v = v.view(m_batchsize, N, -1) 520 | 521 | q = torch.nn.functional.normalize(q, dim=-1) 522 | k = torch.nn.functional.normalize(k, dim=-1) 523 | 524 | attn = (q @ k.transpose(-2, -1)) * self.temperature 525 | attn = attn.softmax(dim=-1) 526 | 527 | out_1 = attn @ v 528 | out_1 = out_1.view(m_batchsize, -1, height, width) 529 | 530 | out_1 = self.project_out(out_1) 531 | out_1 = out_1.view(m_batchsize, N, C, height, width) 532 | 533 | out = out_1 + x 534 | out = out.view(m_batchsize, -1, height, width) 535 | return out 536 | 537 | 538 | ########################################################################## 539 | # ---------- Dimension-Aware Transformer Block ----------------------- 540 | class Backbone(nn.Module): 541 | def __init__( 542 | self, 543 | inp_channels=3, 544 | out_channels=3, 545 | dim=3, 546 | num_blocks=[1, 2, 4, 8], 547 | num_refinement_blocks=2, 548 | heads=[1, 2, 4, 8], 549 | ffn_expansion_factor=2.66, 550 | bias=False, 551 | LayerNorm_type="WithBias", 552 | attention=True, 553 | ): 554 | super(Backbone, self).__init__() 555 | 556 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim) 557 | 558 | self.encoder_1 = nn.Sequential( 559 | *[ 560 | TransformerBlock( 561 | dim=dim, 562 | num_heads=heads[0], 563 | ffn_expansion_factor=ffn_expansion_factor, 564 | bias=bias, 565 | LayerNorm_type=LayerNorm_type, 566 | ) 567 | for _ in range(num_blocks[0]) 568 | ] 569 | ) 570 | 571 | self.encoder_2 = nn.Sequential( 572 | *[ 573 | TransformerBlock( 574 | dim=int(dim), 575 | num_heads=heads[0], 576 | ffn_expansion_factor=ffn_expansion_factor, 577 | bias=bias, 578 | LayerNorm_type=LayerNorm_type, 579 | ) 580 | for _ in range(num_blocks[0]) 581 | ] 582 | ) 583 | 584 | self.encoder_3 = nn.Sequential( 585 | *[ 586 | TransformerBlock( 587 | dim=int(dim), 588 | num_heads=heads[0], 589 | ffn_expansion_factor=ffn_expansion_factor, 590 | bias=bias, 591 | LayerNorm_type=LayerNorm_type, 592 | ) 593 | for _ in range(num_blocks[0]) 594 | ] 595 | ) 596 | 597 | self.layer_fussion = TAA(in_dim=int(dim * 3)) 598 | self.conv_fuss = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias) 599 | 600 | self.latent = nn.Sequential( 601 | *[ 602 | TransformerBlock( 603 | dim=int(dim), 604 | num_heads=heads[0], 605 | ffn_expansion_factor=ffn_expansion_factor, 606 | bias=bias, 607 | LayerNorm_type=LayerNorm_type, 608 | ) 609 | for _ in range(num_blocks[0]) 610 | ] 611 | ) 612 | 613 | self.trans_low = DFE() 614 | 615 | self.coefficient_1_0 = nn.Parameter( 616 | torch.ones((2, int(int(dim)))), requires_grad=attention 617 | ) 618 | 619 | self.refinement_1 = nn.Sequential( 620 | *[ 621 | TransformerBlock( 622 | dim=int(dim), 623 | num_heads=heads[0], 624 | ffn_expansion_factor=ffn_expansion_factor, 625 | bias=bias, 626 | 
LayerNorm_type=LayerNorm_type, 627 | ) 628 | for _ in range(num_refinement_blocks) 629 | ] 630 | ) 631 | self.refinement_2 = nn.Sequential( 632 | *[ 633 | TransformerBlock( 634 | dim=int(dim), 635 | num_heads=heads[0], 636 | ffn_expansion_factor=ffn_expansion_factor, 637 | bias=bias, 638 | LayerNorm_type=LayerNorm_type, 639 | ) 640 | for _ in range(num_refinement_blocks) 641 | ] 642 | ) 643 | self.refinement_3 = nn.Sequential( 644 | *[ 645 | TransformerBlock( 646 | dim=int(dim), 647 | num_heads=heads[0], 648 | ffn_expansion_factor=ffn_expansion_factor, 649 | bias=bias, 650 | LayerNorm_type=LayerNorm_type, 651 | ) 652 | for _ in range(num_refinement_blocks) 653 | ] 654 | ) 655 | 656 | self.layer_fussion_2 = TAA(in_dim=int(dim * 3)) 657 | self.conv_fuss_2 = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias) 658 | 659 | self.output = nn.Conv2d( 660 | int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias 661 | ) 662 | 663 | def forward(self, inp): 664 | inp_enc_encoder1 = self.patch_embed(inp) 665 | out_enc_encoder1 = self.encoder_1(inp_enc_encoder1) 666 | out_enc_encoder2 = self.encoder_2(out_enc_encoder1) 667 | out_enc_encoder3 = self.encoder_3(out_enc_encoder2) 668 | 669 | inp_fusion_123 = torch.cat( 670 | [ 671 | out_enc_encoder1.unsqueeze(1), 672 | out_enc_encoder2.unsqueeze(1), 673 | out_enc_encoder3.unsqueeze(1), 674 | ], 675 | dim=1, 676 | ) 677 | 678 | out_fusion_123 = self.layer_fussion(inp_fusion_123) 679 | out_fusion_123 = self.conv_fuss(out_fusion_123) 680 | 681 | out_enc = self.trans_low(out_fusion_123) 682 | 683 | out_fusion_123 = self.latent(out_fusion_123) 684 | 685 | out = ( 686 | self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123 687 | + self.coefficient_1_0[1, :][None, :, None, None] * out_enc 688 | ) 689 | 690 | out_1 = self.refinement_1(out) 691 | out_2 = self.refinement_2(out_1) 692 | out_3 = self.refinement_3(out_2) 693 | 694 | inp_fusion = torch.cat( 695 | [out_1.unsqueeze(1), out_2.unsqueeze(1), out_3.unsqueeze(1)], dim=1 696 | ) 697 | out_fusion_123 = self.layer_fussion_2(inp_fusion) 698 | out = self.conv_fuss_2(out_fusion_123) 699 | result = self.output(out) 700 | 701 | return result 702 | 703 | 704 | class DocShadow(nn.Module): 705 | def __init__(self, depth=2): 706 | super().__init__() 707 | self.backbone = Backbone() 708 | self.lap_pyramid = LapPyramidConv(depth) 709 | self.trans_high = TRM(3, num_high=depth) 710 | 711 | def forward(self, inp): 712 | pyr_inp = self.lap_pyramid.pyramid_decom(img=inp) 713 | out_low = self.backbone(pyr_inp[-1]) 714 | 715 | inp_up = nn.functional.interpolate( 716 | pyr_inp[-1], size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3]) 717 | ) 718 | out_up = nn.functional.interpolate( 719 | out_low, size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3]) 720 | ) 721 | high_with_low = torch.cat([pyr_inp[-2], inp_up, out_up], 1) 722 | 723 | pyr_inp_trans = self.trans_high(high_with_low, pyr_inp, out_low) 724 | 725 | result = self.lap_pyramid.pyramid_recons(pyr_inp_trans) 726 | 727 | return result 728 | -------------------------------------------------------------------------------- /DocShadow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import load_checkpoint 2 | -------------------------------------------------------------------------------- /DocShadow/utils/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 
| MODEL_URLS = { 7 | "sd7k": "https://drive.google.com/uc?export=download&confirm=t&id=1a1dGcSFYB1ocR5KohSTjXIQcpQ0w05V7", 8 | "jung": "https://drive.google.com/uc?export=download&confirm=t&id=19i0ms_5Cv2tOE6SmL7vyms_gCdZCZQDt", 9 | "kligler": "https://drive.google.com/uc?export=download&confirm=t&id=1JEmtyGeyhCNdZ9_yhhEYJSeFTgQr5XYw", 10 | } 11 | 12 | 13 | def load_checkpoint(model: torch.nn.Module, weights: str, device) -> None: 14 | # Check if local path 15 | if Path(weights).exists(): 16 | checkpoint = torch.load(weights, map_location=str(device)) 17 | else: 18 | # Download 19 | assert ( 20 | weights.lower() in MODEL_URLS.keys() 21 | ), f"DocShadow has only been trained on {MODEL_URLS.keys()}" 22 | checkpoint = torch.hub.load_state_dict_from_url( 23 | MODEL_URLS[weights.lower()], 24 | file_name=f"{weights.lower()}.pth", 25 | map_location=str(device), 26 | ) 27 | 28 | new_state_dict = OrderedDict() 29 | for key, value in checkpoint["state_dict"].items(): 30 | if key.startswith("module"): 31 | name = key[7:] 32 | else: 33 | name = key 34 | new_state_dict[name] = value 35 | 36 | model.load_state_dict(new_state_dict) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nick Chen 4 | Copyright (c) 2023 Fabio Milentiansen Sim 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [](/LICENSE) 2 | [](https://onnx.ai/) 3 | [](https://developer.nvidia.com/tensorrt) 4 | [](https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/stargazers) 5 | [](https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases) 6 | 7 | # DocShadow-ONNX-TensorRT 8 | Open Neural Network Exchange (ONNX) compatible implementation of [DocShadow: High-Resolution Document Shadow Removal via A Large-scale Real-world Dataset and A Frequency-aware Shadow Erasing Net](https://github.com/CXH-Research/DocShadow-SD7K). Supports TensorRT 🚀. 9 | 10 |
Latency comparison (see `assets/latency.png`): DocShadow ONNX TensorRT provides up to a 2x speedup over PyTorch.
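The bundled PyTorch modules under `DocShadow/` can also be run directly. The sketch below is illustrative rather than part of the repo's CLI: the 256x256 resize and the `[0, 1]` RGB scaling mirror the inference commands further down and are assumptions; see `infer.py` for the exact preprocessing.

```python
import numpy as np
import torch
from PIL import Image

from DocShadow.models import DocShadow
from DocShadow.utils import load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DocShadow().eval().to(device)
load_checkpoint(model, "sd7k", device)  # "sd7k", "jung", "kligler", or a local .pth path

# Assumption: RGB image scaled to [0, 1], NCHW layout.
img = np.asarray(
    Image.open("assets/sample.jpg").convert("RGB").resize((256, 256)), dtype=np.float32
) / 255.0
x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)  # (1, 3, 256, 256)

with torch.no_grad():
    y = model(x)  # shadow-removed image, same shape as x
```

To produce the ONNX model used by the commands below, run the export script: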
21 | python export.py \ 22 | --weights sd7k \ 23 | --dynamic_img_size --dynamic_batch 24 |25 |
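A quick way to sanity-check the exported graph (a small sketch using the `onnx` package; the path matches the one passed to the inference commands below):

```python
import onnx

model = onnx.load("weights/docshadow_sd7k.onnx")
onnx.checker.check_model(model)  # raises if the graph is malformed
print([i.name for i in model.graph.input], [o.name for o in model.graph.output])
```

End-to-end inference on an image is handled by `infer.py`: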
56 | python infer.py \ 57 | --img_path assets/sample.jpg \ 58 | --img_size 256 256 \ 59 | --onnx_path weights/docshadow_sd7k.onnx \ 60 | --viz 61 |62 |
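`infer.py` and `onnx_runner/` wrap the ONNX Runtime session for you. If you want to drive the exported model yourself, a minimal sketch is shown below; input/output names are read from the session rather than assumed, and the preprocessing repeats the `[0, 1]` RGB assumption from the PyTorch sketch above.

```python
import numpy as np
import onnxruntime as ort
from PIL import Image

sess = ort.InferenceSession("weights/docshadow_sd7k.onnx", providers=["CPUExecutionProvider"])
inp_name = sess.get_inputs()[0].name

img = np.asarray(
    Image.open("assets/sample.jpg").convert("RGB").resize((256, 256)), dtype=np.float32
) / 255.0
x = img.transpose(2, 0, 1)[None].astype(np.float32)  # (1, 3, 256, 256)

result = sess.run(None, {inp_name: x})[0]  # first graph output, assumed to be the restored image
out = (np.clip(result[0].transpose(1, 2, 0), 0.0, 1.0) * 255.0).astype(np.uint8)
Image.fromarray(out).save("result.jpg")
```

TensorRT execution can be requested through the `--trt` flag: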
73 | CUDA_MODULE_LOADING=LAZY && python infer.py \ 74 | --img_path assets/sample.jpg \ 75 | --onnx_path weights/docshadow_sd7k.onnx \ 76 | --img_size 256 256 \ 77 | --trt \ 78 | --viz 79 |80 |
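The `--trt` flag presumably maps to ONNX Runtime's TensorRT execution provider (see `infer.py` for the actual wiring). Independently of the repo's scripts, TensorRT can be requested directly when creating a session; this requires a GPU build of ONNX Runtime with TensorRT support, and the engine-cache options below are assumptions. The first run is slow because the engine (`*.engine`, `*.profile`) is built and cached.

```python
import onnxruntime as ort

providers = [
    ("TensorrtExecutionProvider", {"trt_engine_cache_enable": True, "trt_engine_cache_path": "weights"}),
    "CUDAExecutionProvider",  # fallback for nodes TensorRT cannot place
    "CPUExecutionProvider",
]
sess = ort.InferenceSession("weights/docshadow_sd7k.onnx", providers=providers)
```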