├── .gitignore ├── DocShadow ├── models │ ├── __init__.py │ ├── backbone.py │ ├── blocks.py │ ├── blocks_.py │ └── model.py └── utils │ ├── __init__.py │ └── utils.py ├── LICENSE ├── README.md ├── assets ├── latency.png └── sample.jpg ├── eval.py ├── export.py ├── export_coreml.py ├── infer.py ├── onnx_runner ├── __init__.py └── docshadow.py ├── requirements-onnx.txt ├── requirements.txt └── weights ├── .gitkeep └── download.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.onnx 2 | *.pth 3 | *.engine 4 | *.profile 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | 167 | # VSCode 168 | .vscode* 169 | -------------------------------------------------------------------------------- /DocShadow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DocShadow 2 | -------------------------------------------------------------------------------- /DocShadow/models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNormFunction(torch.autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, x, weight, bias, eps): 10 | ctx.eps = eps 11 | N, C, H, W = x.size() 12 | mu = x.mean(1, keepdim=True) 13 | var = (x - mu).pow(2).mean(1, keepdim=True) 14 | y = (x - mu) / (var + eps).sqrt() 15 | ctx.save_for_backward(y, var, weight) 16 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 17 | return y 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | eps = ctx.eps 22 | 23 | N, C, H, W = grad_output.size() 24 | y, var, weight = ctx.saved_variables 25 | g = grad_output * weight.view(1, C, 1, 1) 26 | mean_g = g.mean(dim=1, keepdim=True) 27 | 28 | mean_gy = (g * y).mean(dim=1, keepdim=True) 29 | gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 30 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 31 | dim=0), None 32 | 33 | 34 | class LayerNorm2d(nn.Module): 35 | 36 | def __init__(self, channels, eps=1e-6): 37 | super(LayerNorm2d, self).__init__() 38 | self.register_parameter('weight', nn.Parameter(torch.ones(channels))) 39 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) 40 | self.eps = eps 41 | 42 | def forward(self, x): 43 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 44 | 45 | 46 | class SimpleGate(nn.Module): 47 | def forward(self, x): 48 | x1, x2 = x.chunk(2, dim=1) 49 | return x1 * x2 50 | 51 | 52 | # ------------ Deep Feature Extraction Block ----------------- 53 | class DFEBlock(nn.Module): 54 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): 55 | super().__init__() 56 | dw_channel = c * DW_Expand 57 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, 58 | bias=True) 59 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, 60 | groups=dw_channel, 61 | bias=True) 62 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 63 | groups=1, bias=True) 64 | 65 | # Simplified Channel Attention 66 | self.sca = nn.Sequential( 67 | nn.AdaptiveAvgPool2d(1), 68 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 69 | groups=1, bias=True), 70 | ) 71 | 72 | # SimpleGate 73 | self.sg = SimpleGate() 74 | 75 | ffn_channel = FFN_Expand * c 76 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, 77 | bias=True) 78 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 79 | groups=1, bias=True) 80 | 81 | self.norm1 = LayerNorm2d(c) 82 | self.norm2 = LayerNorm2d(c) 83 | 84 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 85 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() 86 | 87 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 88 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 89 | 90 | def forward(self, inp): 91 | x = inp 92 | 93 | x = self.norm1(x) 94 | 95 | x = self.conv1(x) 96 | x = self.conv2(x) 97 | x = self.sg(x) 98 | x = x * self.sca(x) 99 | x = self.conv3(x) 100 | 101 | x = self.dropout1(x) 102 | 103 | y = inp + x * self.beta 104 | 105 | x = self.conv4(self.norm2(y)) 106 | x = self.sg(x) 107 | x = self.conv5(x) 108 | 109 | x = self.dropout2(x) 110 | 111 | return y + x * self.gamma 112 | 113 | 114 | class DFE(nn.Module): 115 | 116 | def __init__(self, img_channel=3, width=32, middle_blk_num=12, enc_blk_nums=[2, 2, 4, 8], dec_blk_nums=[2, 2, 2, 2]): 117 | super().__init__() 118 | 119 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, 120 | groups=1, 121 | bias=True) 122 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, 123 | groups=1, 124 | bias=True) 125 | 126 | self.encoders = nn.ModuleList() 127 | self.decoders = nn.ModuleList() 128 | self.middle_blks = nn.ModuleList() 129 | self.ups = nn.ModuleList() 130 | self.downs = nn.ModuleList() 131 | 132 | chan = width 133 | for num in enc_blk_nums: 134 | self.encoders.append( 135 | nn.Sequential( 136 | *[DFEBlock(chan) for _ in range(num)] 137 | ) 138 | ) 139 | self.downs.append( 140 | nn.Conv2d(chan, 2 * chan, 2, 2) 141 | ) 142 | chan = chan * 2 143 | 144 | self.middle_blks = \ 145 | nn.Sequential( 146 | *[DFEBlock(chan) for _ in range(middle_blk_num)] 147 | ) 148 | 149 | for num in dec_blk_nums: 150 | self.ups.append( 151 | nn.Sequential( 152 | nn.Conv2d(chan, chan * 2, 1, bias=False), 153 | nn.PixelShuffle(2) 154 | ) 155 | ) 156 | chan = chan // 2 157 | self.decoders.append( 158 | nn.Sequential( 159 | *[DFEBlock(chan) for _ in range(num)] 160 | ) 161 | ) 162 | 163 | self.padder_size = 2 ** len(self.encoders) 164 | 165 | def forward(self, inp): 166 | B, C, H, W = inp.shape 167 | inp = self.check_image_size(inp) 168 | 169 | x = self.intro(inp) 170 | 171 | encs = [] 172 | 173 | for encoder, down in zip(self.encoders, self.downs): 174 | x = encoder(x) 175 | encs.append(x) 176 | x = down(x) 177 | 178 | x = self.middle_blks(x) 179 | 180 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 181 | x = up(x) 182 | x = x + enc_skip 183 | x = decoder(x) 184 | 185 | x = self.ending(x) 186 | x = x + inp 187 | 188 | return x[:, :, :H, :W] 189 | 190 | def check_image_size(self, x): 191 | _, _, h, w = x.size() 192 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 193 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 194 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 195 | return x 196 | -------------------------------------------------------------------------------- /DocShadow/models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, bias=True, groups=1, norm='in', 8 | nonlinear='relu'): 9 | super(ConvLayer, self).__init__() 10 | reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 11 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, 
kernel_size, stride, groups=groups, bias=bias, 13 | dilation=dilation) 14 | self.norm = norm 15 | self.nonlinear = nonlinear 16 | 17 | if norm == 'bn': 18 | self.normalization = nn.BatchNorm2d(out_channels) 19 | elif norm == 'in': 20 | self.normalization = nn.InstanceNorm2d(out_channels, affine=False) 21 | else: 22 | self.normalization = None 23 | 24 | if nonlinear == 'relu': 25 | self.activation = nn.ReLU(inplace=True) 26 | elif nonlinear == 'leakyrelu': 27 | self.activation = nn.LeakyReLU(0.2) 28 | elif nonlinear == 'PReLU': 29 | self.activation = nn.PReLU() 30 | else: 31 | self.activation = None 32 | 33 | def forward(self, x): 34 | out = self.conv2d(self.reflection_pad(x)) 35 | if self.normalization is not None: 36 | out = self.normalization(out) 37 | if self.activation is not None: 38 | out = self.activation(out) 39 | 40 | return out 41 | 42 | 43 | class Aggreation(nn.Module): 44 | def __init__(self, in_channels, out_channels, kernel_size=3): 45 | super(Aggreation, self).__init__() 46 | self.attention = SelfAttention(in_channels, k=8, nonlinear='relu') 47 | self.conv = ConvLayer(in_channels, out_channels, kernel_size=kernel_size, stride=1, dilation=1, 48 | nonlinear='leakyrelu', 49 | norm=None) 50 | 51 | def forward(self, x): 52 | return self.conv(self.attention(x)) 53 | 54 | 55 | class SelfAttention(nn.Module): 56 | def __init__(self, channels, k, nonlinear='relu'): 57 | super(SelfAttention, self).__init__() 58 | self.channels = channels 59 | self.k = k 60 | self.nonlinear = nonlinear 61 | 62 | self.linear1 = nn.Linear(channels, channels // k) 63 | self.linear2 = nn.Linear(channels // k, channels) 64 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 65 | 66 | if nonlinear == 'relu': 67 | self.activation = nn.ReLU(inplace=True) 68 | elif nonlinear == 'leakyrelu': 69 | self.activation = nn.LeakyReLU(0.2) 70 | elif nonlinear == 'PReLU': 71 | self.activation = nn.PReLU() 72 | else: 73 | raise ValueError 74 | 75 | def attention(self, x): 76 | N, C, H, W = x.size() 77 | out = torch.flatten(self.global_pooling(x), 1) 78 | out = self.activation(self.linear1(out)) 79 | out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1) 80 | 81 | return out.mul(x) 82 | 83 | def forward(self, x): 84 | return self.attention(x) 85 | 86 | 87 | class SPP(nn.Module): 88 | def __init__(self, in_channels, out_channels, num_layers=4, interpolation_type='bilinear'): 89 | super(SPP, self).__init__() 90 | self.conv = nn.ModuleList() 91 | self.num_layers = num_layers 92 | self.interpolation_type = interpolation_type 93 | 94 | for _ in range(self.num_layers): 95 | self.conv.append( 96 | ConvLayer(in_channels, in_channels, kernel_size=1, stride=1, dilation=1, nonlinear='leakyrelu', 97 | norm=None)) 98 | 99 | self.fusion = ConvLayer((in_channels * (self.num_layers + 1)), out_channels, kernel_size=3, stride=1, 100 | norm='False', nonlinear='leakyrelu') 101 | 102 | def forward(self, x): 103 | 104 | N, C, H, W = x.size() 105 | out = [] 106 | 107 | for level in range(self.num_layers): 108 | out.append(F.interpolate(self.conv[level]( 109 | F.avg_pool2d(x, kernel_size=2 * 2 ** (level + 1), stride=2 * 2 ** (level + 1), 110 | padding=2 * 2 ** (level + 1) % 2)), size=(H, W), mode=self.interpolation_type)) 111 | 112 | out.append(x) 113 | 114 | return self.fusion(torch.cat(out, dim=1)) 115 | -------------------------------------------------------------------------------- /DocShadow/models/blocks_.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 
import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, bias=True, groups=1, norm='in', 8 | nonlinear='relu'): 9 | super(ConvLayer, self).__init__() 10 | reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 11 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias, 13 | dilation=dilation) 14 | self.norm = norm 15 | self.nonlinear = nonlinear 16 | 17 | if norm == 'bn': 18 | self.normalization = nn.BatchNorm2d(out_channels) 19 | elif norm == 'in': 20 | self.normalization = nn.InstanceNorm2d(out_channels, affine=False) 21 | else: 22 | self.normalization = None 23 | 24 | if nonlinear == 'relu': 25 | self.activation = nn.ReLU(inplace=True) 26 | elif nonlinear == 'leakyrelu': 27 | self.activation = nn.LeakyReLU(0.2) 28 | elif nonlinear == 'PReLU': 29 | self.activation = nn.PReLU() 30 | else: 31 | self.activation = None 32 | 33 | def forward(self, x): 34 | out = self.conv2d(self.reflection_pad(x)) 35 | if self.normalization is not None: 36 | out = self.normalization(out) 37 | if self.activation is not None: 38 | out = self.activation(out) 39 | 40 | return out 41 | 42 | 43 | class Aggreation(nn.Module): 44 | def __init__(self, in_channels, out_channels, kernel_size=3): 45 | super(Aggreation, self).__init__() 46 | self.attention = SelfAttention(in_channels, k=8, nonlinear='relu') 47 | self.conv = ConvLayer(in_channels, out_channels, kernel_size=kernel_size, stride=1, dilation=1, 48 | nonlinear='leakyrelu', 49 | norm=None) 50 | 51 | def forward(self, x): 52 | return self.conv(self.attention(x)) 53 | 54 | 55 | class SelfAttention(nn.Module): 56 | def __init__(self, channels, k, nonlinear='relu'): 57 | super(SelfAttention, self).__init__() 58 | self.channels = channels 59 | self.k = k 60 | self.nonlinear = nonlinear 61 | 62 | self.linear1 = nn.Linear(channels, channels // k) 63 | self.linear2 = nn.Linear(channels // k, channels) 64 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 65 | 66 | if nonlinear == 'relu': 67 | self.activation = nn.ReLU(inplace=True) 68 | elif nonlinear == 'leakyrelu': 69 | self.activation = nn.LeakyReLU(0.2) 70 | elif nonlinear == 'PReLU': 71 | self.activation = nn.PReLU() 72 | else: 73 | raise ValueError 74 | 75 | def attention(self, x): 76 | N, C, H, W = x.size() 77 | out = torch.flatten(self.global_pooling(x), 1) 78 | out = self.activation(self.linear1(out)) 79 | out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1) 80 | 81 | return out.mul(x) 82 | 83 | def forward(self, x): 84 | return self.attention(x) 85 | 86 | 87 | class SPP(nn.Module): 88 | def __init__(self, in_channels, out_channels, num_layers=4, interpolation_type='bilinear'): 89 | super(SPP, self).__init__() 90 | self.conv = nn.ModuleList() 91 | self.num_layers = num_layers 92 | self.interpolation_type = interpolation_type 93 | 94 | for _ in range(self.num_layers): 95 | self.conv.append( 96 | ConvLayer(in_channels, in_channels, kernel_size=1, stride=1, dilation=1, nonlinear='leakyrelu', 97 | norm=None)) 98 | 99 | self.fusion = ConvLayer((in_channels * (self.num_layers + 1)), out_channels, kernel_size=3, stride=1, 100 | norm='False', nonlinear='leakyrelu') 101 | 102 | def forward(self, x): 103 | 104 | N, C, H, W = x.size() 105 | out = [] 106 | 107 | for level in range(self.num_layers): 108 | out.append(F.interpolate(self.conv[level]( 109 | F.avg_pool2d(x, kernel_size=2 * 2 ** 
(level + 1), stride=2 * 2 ** (level + 1), 110 | padding=2 * 2 ** (level + 1) % 2)), size=(H, W), mode=self.interpolation_type)) 111 | 112 | out.append(x) 113 | 114 | return self.fusion(torch.cat(out, dim=1)) 115 | -------------------------------------------------------------------------------- /DocShadow/models/model.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import nn 7 | 8 | from .backbone import DFE 9 | from .blocks import SPP, Aggreation, ConvLayer 10 | 11 | 12 | class ResidualBlock(nn.Module): 13 | def __init__(self, in_features): 14 | super(ResidualBlock, self).__init__() 15 | 16 | self.block = nn.Sequential( 17 | nn.Conv2d(in_features, in_features, 3, padding=1), 18 | nn.LeakyReLU(), 19 | nn.Conv2d(in_features, in_features, 3, padding=1), 20 | ) 21 | 22 | def forward(self, x): 23 | return x + self.block(x) 24 | 25 | 26 | def gauss_kernel(channels=3): 27 | kernel = torch.tensor( 28 | [ 29 | [1.0, 4.0, 6.0, 4.0, 1], 30 | [4.0, 16.0, 24.0, 16.0, 4.0], 31 | [6.0, 24.0, 36.0, 24.0, 6.0], 32 | [4.0, 16.0, 24.0, 16.0, 4.0], 33 | [1.0, 4.0, 6.0, 4.0, 1.0], 34 | ] 35 | ) 36 | kernel /= 256.0 37 | kernel = kernel.repeat(channels, 1, 1, 1) 38 | return kernel 39 | 40 | 41 | class LapPyramidConv(nn.Module): 42 | def __init__(self, num_high=4): 43 | super(LapPyramidConv, self).__init__() 44 | 45 | self.num_high = num_high 46 | self.kernel = gauss_kernel() 47 | 48 | def downsample(self, x): 49 | return x[:, :, ::2, ::2] 50 | 51 | def upsample(self, x): 52 | cc = torch.cat( 53 | [ 54 | x, 55 | torch.zeros( 56 | x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device 57 | ), 58 | ], 59 | dim=3, 60 | ) 61 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) 62 | cc = cc.permute(0, 1, 3, 2) 63 | cc = torch.cat( 64 | [ 65 | cc, 66 | torch.zeros( 67 | x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device 68 | ), 69 | ], 70 | dim=3, 71 | ) 72 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) 73 | x_up = cc.permute(0, 1, 3, 2) 74 | return self.conv_gauss(x_up, 4 * self.kernel) 75 | 76 | def conv_gauss(self, img, kernel): 77 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect") 78 | out = torch.nn.functional.conv2d( 79 | img, kernel.to(img.device), groups=img.shape[1] 80 | ) 81 | return out 82 | 83 | def pyramid_decom(self, img): 84 | current = img 85 | pyr = [] 86 | for _ in range(self.num_high): 87 | filtered = self.conv_gauss(current, self.kernel) 88 | down = self.downsample(filtered) 89 | up = self.upsample(down) 90 | # if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]: 91 | up = nn.functional.interpolate( 92 | up, size=(current.shape[2], current.shape[3]) 93 | ) 94 | diff = current - up 95 | pyr.append(diff) 96 | current = down 97 | pyr.append(current) 98 | return pyr 99 | 100 | def pyramid_recons(self, pyr): 101 | image = pyr[-1] 102 | for level in reversed(pyr[:-1]): 103 | up = self.upsample(image) 104 | # if up.shape[2] != level.shape[2] or up.shape[3] != level.shape[3]: 105 | up = nn.functional.interpolate(up, size=(level.shape[2], level.shape[3])) 106 | image = up + level 107 | return image 108 | 109 | 110 | # ----------- Texture Recovery Module ------------------- 111 | class TRM(nn.Module): 112 | def __init__(self, num_residual_blocks, num_high=3): 113 | super(TRM, self).__init__() 114 | 115 | self.num_high = num_high 116 | 117 | blocks = 
[nn.Conv2d(9, 64, 3, padding=1), nn.LeakyReLU()] 118 | 119 | for _ in range(num_residual_blocks): 120 | blocks += [ResidualBlock(64)] 121 | 122 | blocks += [nn.Conv2d(64, 3, 3, padding=1)] 123 | 124 | self.model = nn.Sequential(*blocks) 125 | 126 | channels = 3 127 | # Stage1 128 | self.block1_1 = ConvLayer( 129 | in_channels=channels, 130 | out_channels=channels, 131 | kernel_size=3, 132 | stride=1, 133 | dilation=2, 134 | norm=None, 135 | nonlinear="leakyrelu", 136 | ) 137 | self.block1_2 = ConvLayer( 138 | in_channels=channels, 139 | out_channels=channels, 140 | kernel_size=3, 141 | stride=1, 142 | dilation=4, 143 | norm=None, 144 | nonlinear="leakyrelu", 145 | ) 146 | self.aggreation1_rgb = Aggreation( 147 | in_channels=channels * 3, out_channels=channels 148 | ) 149 | # Stage2 150 | self.block2_1 = ConvLayer( 151 | in_channels=channels, 152 | out_channels=channels, 153 | kernel_size=3, 154 | stride=1, 155 | dilation=8, 156 | norm=None, 157 | nonlinear="leakyrelu", 158 | ) 159 | self.block2_2 = ConvLayer( 160 | in_channels=channels, 161 | out_channels=channels, 162 | kernel_size=3, 163 | stride=1, 164 | dilation=16, 165 | norm=None, 166 | nonlinear="leakyrelu", 167 | ) 168 | self.aggreation2_rgb = Aggreation( 169 | in_channels=channels * 3, out_channels=channels 170 | ) 171 | # Stage3 172 | self.block3_1 = ConvLayer( 173 | in_channels=channels, 174 | out_channels=channels, 175 | kernel_size=3, 176 | stride=1, 177 | dilation=32, 178 | norm=None, 179 | nonlinear="leakyrelu", 180 | ) 181 | self.block3_2 = ConvLayer( 182 | in_channels=channels, 183 | out_channels=channels, 184 | kernel_size=3, 185 | stride=1, 186 | dilation=64, 187 | norm=None, 188 | nonlinear="leakyrelu", 189 | ) 190 | self.aggreation3_rgb = Aggreation( 191 | in_channels=channels * 3, out_channels=channels 192 | ) 193 | # Stage3 194 | self.spp_img = SPP( 195 | in_channels=channels, 196 | out_channels=channels, 197 | num_layers=4, 198 | interpolation_type="bicubic", 199 | ) 200 | self.block4_1 = nn.Conv2d( 201 | in_channels=channels, out_channels=3, kernel_size=1, stride=1 202 | ) 203 | 204 | def forward(self, x, pyr_original, fake_low): 205 | pyr_result = [fake_low] 206 | mask = self.model(x) 207 | 208 | mask = nn.functional.interpolate( 209 | mask, size=(pyr_original[-2].shape[2], pyr_original[-2].shape[3]) 210 | ) 211 | result_highfreq = torch.mul(pyr_original[-2], mask) + pyr_original[-2] 212 | out1_1 = self.block1_1(result_highfreq) 213 | out1_2 = self.block1_2(out1_1) 214 | agg1_rgb = self.aggreation1_rgb( 215 | torch.cat((result_highfreq, out1_1, out1_2), dim=1) 216 | ) 217 | pyr_result.append(agg1_rgb) 218 | 219 | mask = nn.functional.interpolate( 220 | mask, size=(pyr_original[-3].shape[2], pyr_original[-3].shape[3]) 221 | ) 222 | result_highfreq = torch.mul(pyr_original[-3], mask) + pyr_original[-3] 223 | out2_1 = self.block2_1(result_highfreq) 224 | out2_2 = self.block2_2(out2_1) 225 | agg2_rgb = self.aggreation2_rgb( 226 | torch.cat((result_highfreq, out2_1, out2_2), dim=1) 227 | ) 228 | 229 | out3_1 = self.block3_1(agg2_rgb) 230 | out3_2 = self.block3_2(out3_1) 231 | agg3_rgb = self.aggreation3_rgb(torch.cat((agg2_rgb, out3_1, out3_2), dim=1)) 232 | 233 | spp_rgb = self.spp_img(agg3_rgb) 234 | out_rgb = self.block4_1(spp_rgb) 235 | 236 | pyr_result.append(out_rgb) 237 | pyr_result.reverse() 238 | 239 | return pyr_result 240 | 241 | 242 | def to_3d(x): 243 | return rearrange(x, "b c h w -> b (h w) c") 244 | 245 | 246 | def to_4d(x, h, w): 247 | return rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 248 | 249 | 
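# The two LayerNorm variants below operate on the (b, h*w, c) layout produced by to_3d():
# BiasFree_LayerNorm only rescales by the channel-wise standard deviation, while
# WithBias_LayerNorm also subtracts the mean and adds a learnable bias.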
250 | class BiasFree_LayerNorm(nn.Module): 251 | def __init__(self, normalized_shape): 252 | super(BiasFree_LayerNorm, self).__init__() 253 | if isinstance(normalized_shape, numbers.Integral): 254 | normalized_shape = (normalized_shape,) 255 | normalized_shape = torch.Size(normalized_shape) 256 | 257 | assert len(normalized_shape) == 1 258 | 259 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 260 | self.normalized_shape = normalized_shape 261 | 262 | def forward(self, x): 263 | sigma = x.var(-1, keepdim=True, unbiased=False) 264 | return x / torch.sqrt(sigma + 1e-5) * self.weight 265 | 266 | 267 | class WithBias_LayerNorm(nn.Module): 268 | def __init__(self, normalized_shape): 269 | super(WithBias_LayerNorm, self).__init__() 270 | if isinstance(normalized_shape, numbers.Integral): 271 | normalized_shape = (normalized_shape,) 272 | normalized_shape = torch.Size(normalized_shape) 273 | 274 | assert len(normalized_shape) == 1 275 | 276 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 277 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 278 | self.normalized_shape = normalized_shape 279 | 280 | def forward(self, x): 281 | mu = x.mean(-1, keepdim=True) 282 | sigma = x.var(-1, keepdim=True, unbiased=False) 283 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 284 | 285 | 286 | class LayerNorm(nn.Module): 287 | def __init__(self, dim, LayerNorm_type): 288 | super(LayerNorm, self).__init__() 289 | if LayerNorm_type == "BiasFree": 290 | self.body = BiasFree_LayerNorm(dim) 291 | else: 292 | self.body = WithBias_LayerNorm(dim) 293 | 294 | def forward(self, x): 295 | h, w = x.shape[-2:] 296 | return to_4d(self.body(to_3d(x)), h, w) 297 | 298 | 299 | # Axis-based Multi-head Self-Attention 300 | 301 | 302 | class NextAttentionImplZ(nn.Module): 303 | def __init__(self, num_dims, num_heads, bias) -> None: 304 | super().__init__() 305 | self.num_dims = num_dims 306 | self.num_heads = num_heads 307 | self.q1 = nn.Conv2d(num_dims, num_dims * 3, kernel_size=1, bias=bias) 308 | self.q2 = nn.Conv2d( 309 | num_dims * 3, 310 | num_dims * 3, 311 | kernel_size=3, 312 | padding=1, 313 | groups=num_dims * 3, 314 | bias=bias, 315 | ) 316 | self.q3 = nn.Conv2d( 317 | num_dims * 3, 318 | num_dims * 3, 319 | kernel_size=3, 320 | padding=1, 321 | groups=num_dims * 3, 322 | bias=bias, 323 | ) 324 | 325 | self.fac = nn.Parameter(torch.ones(1)) 326 | self.fin = nn.Conv2d(num_dims, num_dims, kernel_size=1, bias=bias) 327 | return 328 | 329 | def forward(self, x): 330 | # x: [n, c, h, w] 331 | n, c, h, w = x.size() 332 | n_heads, dim_head = self.num_heads, c // self.num_heads 333 | reshape = lambda x: rearrange( 334 | x, "n (nh dh) h w -> (n nh h) w dh", nh=n_heads, dh=dim_head 335 | ) 336 | 337 | qkv = self.q3(self.q2(self.q1(x))) 338 | q, k, v = map(reshape, qkv.chunk(3, dim=1)) 339 | q = F.normalize(q, dim=-1) 340 | k = F.normalize(k, dim=-1) 341 | 342 | # fac = dim_head ** -0.5 343 | res = k.transpose(-2, -1) 344 | res = torch.matmul(q, res) * self.fac 345 | res = torch.softmax(res, dim=-1) 346 | 347 | res = torch.matmul(res, v) 348 | res = rearrange( 349 | res, "(n nh h) w dh -> n (nh dh) h w", nh=n_heads, dh=dim_head, n=n, h=h 350 | ) 351 | res = self.fin(res) 352 | 353 | return res 354 | 355 | 356 | # Axis-based Multi-head Self-Attention (row and col attention) 357 | class NextAttentionZ(nn.Module): 358 | def __init__(self, num_dims, num_heads=1, bias=True) -> None: 359 | super().__init__() 360 | assert num_dims % num_heads == 0 361 | self.num_dims = num_dims 362 | 
self.num_heads = num_heads 363 | self.row_att = NextAttentionImplZ(num_dims, num_heads, bias) 364 | self.col_att = NextAttentionImplZ(num_dims, num_heads, bias) 365 | return 366 | 367 | def forward(self, x: torch.Tensor): 368 | assert len(x.size()) == 4 369 | 370 | x = self.row_att(x) 371 | x = x.transpose(-2, -1) 372 | x = self.col_att(x) 373 | x = x.transpose(-2, -1) 374 | 375 | return x 376 | 377 | 378 | # Dual Gated Feed-Forward Networ 379 | class FeedForward(nn.Module): 380 | def __init__(self, dim, ffn_expansion_factor, bias): 381 | super(FeedForward, self).__init__() 382 | 383 | hidden_features = int(dim * ffn_expansion_factor) 384 | 385 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 386 | 387 | self.dwconv = nn.Conv2d( 388 | hidden_features * 2, 389 | hidden_features * 2, 390 | kernel_size=3, 391 | stride=1, 392 | padding=1, 393 | groups=hidden_features * 2, 394 | bias=bias, 395 | ) 396 | 397 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 398 | 399 | def forward(self, x): 400 | x = self.project_in(x) 401 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 402 | x = F.gelu(x2) * x1 + F.gelu(x1) * x2 403 | x = self.project_out(x) 404 | return x 405 | 406 | 407 | # Axis-based Transformer Block 408 | class TransformerBlock(nn.Module): 409 | def __init__( 410 | self, 411 | dim, 412 | num_heads=1, 413 | ffn_expansion_factor=2.66, 414 | bias=True, 415 | LayerNorm_type="WithBias", 416 | ): 417 | super(TransformerBlock, self).__init__() 418 | 419 | self.norm1 = LayerNorm(dim, LayerNorm_type) 420 | self.attn = NextAttentionZ(dim, num_heads) 421 | self.norm2 = LayerNorm(dim, LayerNorm_type) 422 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 423 | 424 | def forward(self, x): 425 | x = x + self.attn(self.norm1(x)) 426 | x = x + self.ffn(self.norm2(x)) 427 | return x 428 | 429 | 430 | ########################################################################## 431 | # Overlapped image patch embedding with 3x3 Conv 432 | class OverlapPatchEmbed(nn.Module): 433 | def __init__(self, in_c=3, embed_dim=48, bias=False): 434 | super(OverlapPatchEmbed, self).__init__() 435 | 436 | self.proj = nn.Conv2d( 437 | in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias 438 | ) 439 | 440 | def forward(self, x): 441 | x = self.proj(x) 442 | 443 | return x 444 | 445 | 446 | ########################################################################## 447 | # Resizing modules 448 | class Downsample(nn.Module): 449 | def __init__(self, n_feat): 450 | super(Downsample, self).__init__() 451 | 452 | self.body = nn.Sequential( 453 | nn.Conv2d( 454 | n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False 455 | ), 456 | nn.PixelUnshuffle(2), 457 | ) 458 | 459 | def forward(self, x): 460 | return self.body(x) 461 | 462 | 463 | class Upsample(nn.Module): 464 | def __init__(self, n_feat): 465 | super(Upsample, self).__init__() 466 | 467 | self.body = nn.Sequential( 468 | nn.Conv2d( 469 | n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False 470 | ), 471 | nn.PixelShuffle(2), 472 | ) 473 | 474 | def forward(self, x): 475 | return self.body(x) 476 | 477 | 478 | # Tri-layer Attention Alignment Block 479 | class TAA(nn.Module): 480 | """Layer attention module""" 481 | 482 | def __init__(self, in_dim, bias=True): 483 | super(TAA, self).__init__() 484 | self.chanel_in = in_dim 485 | 486 | self.temperature = nn.Parameter(torch.ones(1)) 487 | 488 | self.qkv = nn.Conv2d( 489 | self.chanel_in, self.chanel_in * 3, kernel_size=1, bias=bias 
490 | ) 491 | self.qkv_dwconv = nn.Conv2d( 492 | self.chanel_in * 3, 493 | self.chanel_in * 3, 494 | kernel_size=3, 495 | stride=1, 496 | padding=1, 497 | groups=self.chanel_in * 3, 498 | bias=bias, 499 | ) 500 | self.project_out = nn.Conv2d( 501 | self.chanel_in, self.chanel_in, kernel_size=1, bias=bias 502 | ) 503 | 504 | def forward(self, x): 505 | """ 506 | inputs : 507 | x : input feature maps( B X N X C X H X W) 508 | returns : 509 | out : attention value + input feature 510 | attention: B X N X N 511 | """ 512 | m_batchsize, N, C, height, width = x.size() 513 | 514 | x_input = x.view(m_batchsize, N * C, height, width) 515 | qkv = self.qkv_dwconv(self.qkv(x_input)) 516 | q, k, v = qkv.chunk(3, dim=1) 517 | q = q.view(m_batchsize, N, -1) 518 | k = k.view(m_batchsize, N, -1) 519 | v = v.view(m_batchsize, N, -1) 520 | 521 | q = torch.nn.functional.normalize(q, dim=-1) 522 | k = torch.nn.functional.normalize(k, dim=-1) 523 | 524 | attn = (q @ k.transpose(-2, -1)) * self.temperature 525 | attn = attn.softmax(dim=-1) 526 | 527 | out_1 = attn @ v 528 | out_1 = out_1.view(m_batchsize, -1, height, width) 529 | 530 | out_1 = self.project_out(out_1) 531 | out_1 = out_1.view(m_batchsize, N, C, height, width) 532 | 533 | out = out_1 + x 534 | out = out.view(m_batchsize, -1, height, width) 535 | return out 536 | 537 | 538 | ########################################################################## 539 | # ---------- Dimension-Aware Transformer Block ----------------------- 540 | class Backbone(nn.Module): 541 | def __init__( 542 | self, 543 | inp_channels=3, 544 | out_channels=3, 545 | dim=3, 546 | num_blocks=[1, 2, 4, 8], 547 | num_refinement_blocks=2, 548 | heads=[1, 2, 4, 8], 549 | ffn_expansion_factor=2.66, 550 | bias=False, 551 | LayerNorm_type="WithBias", 552 | attention=True, 553 | ): 554 | super(Backbone, self).__init__() 555 | 556 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim) 557 | 558 | self.encoder_1 = nn.Sequential( 559 | *[ 560 | TransformerBlock( 561 | dim=dim, 562 | num_heads=heads[0], 563 | ffn_expansion_factor=ffn_expansion_factor, 564 | bias=bias, 565 | LayerNorm_type=LayerNorm_type, 566 | ) 567 | for _ in range(num_blocks[0]) 568 | ] 569 | ) 570 | 571 | self.encoder_2 = nn.Sequential( 572 | *[ 573 | TransformerBlock( 574 | dim=int(dim), 575 | num_heads=heads[0], 576 | ffn_expansion_factor=ffn_expansion_factor, 577 | bias=bias, 578 | LayerNorm_type=LayerNorm_type, 579 | ) 580 | for _ in range(num_blocks[0]) 581 | ] 582 | ) 583 | 584 | self.encoder_3 = nn.Sequential( 585 | *[ 586 | TransformerBlock( 587 | dim=int(dim), 588 | num_heads=heads[0], 589 | ffn_expansion_factor=ffn_expansion_factor, 590 | bias=bias, 591 | LayerNorm_type=LayerNorm_type, 592 | ) 593 | for _ in range(num_blocks[0]) 594 | ] 595 | ) 596 | 597 | self.layer_fussion = TAA(in_dim=int(dim * 3)) 598 | self.conv_fuss = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias) 599 | 600 | self.latent = nn.Sequential( 601 | *[ 602 | TransformerBlock( 603 | dim=int(dim), 604 | num_heads=heads[0], 605 | ffn_expansion_factor=ffn_expansion_factor, 606 | bias=bias, 607 | LayerNorm_type=LayerNorm_type, 608 | ) 609 | for _ in range(num_blocks[0]) 610 | ] 611 | ) 612 | 613 | self.trans_low = DFE() 614 | 615 | self.coefficient_1_0 = nn.Parameter( 616 | torch.ones((2, int(int(dim)))), requires_grad=attention 617 | ) 618 | 619 | self.refinement_1 = nn.Sequential( 620 | *[ 621 | TransformerBlock( 622 | dim=int(dim), 623 | num_heads=heads[0], 624 | ffn_expansion_factor=ffn_expansion_factor, 625 | bias=bias, 626 | 
LayerNorm_type=LayerNorm_type, 627 | ) 628 | for _ in range(num_refinement_blocks) 629 | ] 630 | ) 631 | self.refinement_2 = nn.Sequential( 632 | *[ 633 | TransformerBlock( 634 | dim=int(dim), 635 | num_heads=heads[0], 636 | ffn_expansion_factor=ffn_expansion_factor, 637 | bias=bias, 638 | LayerNorm_type=LayerNorm_type, 639 | ) 640 | for _ in range(num_refinement_blocks) 641 | ] 642 | ) 643 | self.refinement_3 = nn.Sequential( 644 | *[ 645 | TransformerBlock( 646 | dim=int(dim), 647 | num_heads=heads[0], 648 | ffn_expansion_factor=ffn_expansion_factor, 649 | bias=bias, 650 | LayerNorm_type=LayerNorm_type, 651 | ) 652 | for _ in range(num_refinement_blocks) 653 | ] 654 | ) 655 | 656 | self.layer_fussion_2 = TAA(in_dim=int(dim * 3)) 657 | self.conv_fuss_2 = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias) 658 | 659 | self.output = nn.Conv2d( 660 | int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias 661 | ) 662 | 663 | def forward(self, inp): 664 | inp_enc_encoder1 = self.patch_embed(inp) 665 | out_enc_encoder1 = self.encoder_1(inp_enc_encoder1) 666 | out_enc_encoder2 = self.encoder_2(out_enc_encoder1) 667 | out_enc_encoder3 = self.encoder_3(out_enc_encoder2) 668 | 669 | inp_fusion_123 = torch.cat( 670 | [ 671 | out_enc_encoder1.unsqueeze(1), 672 | out_enc_encoder2.unsqueeze(1), 673 | out_enc_encoder3.unsqueeze(1), 674 | ], 675 | dim=1, 676 | ) 677 | 678 | out_fusion_123 = self.layer_fussion(inp_fusion_123) 679 | out_fusion_123 = self.conv_fuss(out_fusion_123) 680 | 681 | out_enc = self.trans_low(out_fusion_123) 682 | 683 | out_fusion_123 = self.latent(out_fusion_123) 684 | 685 | out = ( 686 | self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123 687 | + self.coefficient_1_0[1, :][None, :, None, None] * out_enc 688 | ) 689 | 690 | out_1 = self.refinement_1(out) 691 | out_2 = self.refinement_2(out_1) 692 | out_3 = self.refinement_3(out_2) 693 | 694 | inp_fusion = torch.cat( 695 | [out_1.unsqueeze(1), out_2.unsqueeze(1), out_3.unsqueeze(1)], dim=1 696 | ) 697 | out_fusion_123 = self.layer_fussion_2(inp_fusion) 698 | out = self.conv_fuss_2(out_fusion_123) 699 | result = self.output(out) 700 | 701 | return result 702 | 703 | 704 | class DocShadow(nn.Module): 705 | def __init__(self, depth=2): 706 | super().__init__() 707 | self.backbone = Backbone() 708 | self.lap_pyramid = LapPyramidConv(depth) 709 | self.trans_high = TRM(3, num_high=depth) 710 | 711 | def forward(self, inp): 712 | pyr_inp = self.lap_pyramid.pyramid_decom(img=inp) 713 | out_low = self.backbone(pyr_inp[-1]) 714 | 715 | inp_up = nn.functional.interpolate( 716 | pyr_inp[-1], size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3]) 717 | ) 718 | out_up = nn.functional.interpolate( 719 | out_low, size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3]) 720 | ) 721 | high_with_low = torch.cat([pyr_inp[-2], inp_up, out_up], 1) 722 | 723 | pyr_inp_trans = self.trans_high(high_with_low, pyr_inp, out_low) 724 | 725 | result = self.lap_pyramid.pyramid_recons(pyr_inp_trans) 726 | 727 | return result 728 | -------------------------------------------------------------------------------- /DocShadow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import load_checkpoint 2 | -------------------------------------------------------------------------------- /DocShadow/utils/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 
| MODEL_URLS = { 7 | "sd7k": "https://drive.google.com/uc?export=download&confirm=t&id=1a1dGcSFYB1ocR5KohSTjXIQcpQ0w05V7", 8 | "jung": "https://drive.google.com/uc?export=download&confirm=t&id=19i0ms_5Cv2tOE6SmL7vyms_gCdZCZQDt", 9 | "kligler": "https://drive.google.com/uc?export=download&confirm=t&id=1JEmtyGeyhCNdZ9_yhhEYJSeFTgQr5XYw", 10 | } 11 | 12 | 13 | def load_checkpoint(model: torch.nn.Module, weights: str, device) -> None: 14 | # Check if local path 15 | if Path(weights).exists(): 16 | checkpoint = torch.load(weights, map_location=str(device)) 17 | else: 18 | # Download 19 | assert ( 20 | weights.lower() in MODEL_URLS.keys() 21 | ), f"DocShadow has only been trained on {MODEL_URLS.keys()}" 22 | checkpoint = torch.hub.load_state_dict_from_url( 23 | MODEL_URLS[weights.lower()], 24 | file_name=f"{weights.lower()}.pth", 25 | map_location=str(device), 26 | ) 27 | 28 | new_state_dict = OrderedDict() 29 | for key, value in checkpoint["state_dict"].items(): 30 | if key.startswith("module"): 31 | name = key[7:] 32 | else: 33 | name = key 34 | new_state_dict[name] = value 35 | 36 | model.load_state_dict(new_state_dict) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nick Chen 4 | Copyright (c) 2023 Fabio Milentiansen Sim 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![GitHub](https://img.shields.io/github/license/fabio-sim/DocShadow-ONNX-TensorRT)](/LICENSE) 2 | [![ONNX](https://img.shields.io/badge/ONNX-grey)](https://onnx.ai/) 3 | [![TensorRT](https://img.shields.io/badge/TensorRT-76B900)](https://developer.nvidia.com/tensorrt) 4 | [![GitHub Repo stars](https://img.shields.io/github/stars/fabio-sim/DocShadow-ONNX-TensorRT)](https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/stargazers) 5 | [![GitHub all releases](https://img.shields.io/github/downloads/fabio-sim/DocShadow-ONNX-TensorRT/total)](https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases) 6 | 7 | # DocShadow-ONNX-TensorRT 8 | Open Neural Network Exchange (ONNX) compatible implementation of [DocShadow: High-Resolution Document Shadow Removal via A Large-scale Real-world Dataset and A Frequency-aware Shadow Erasing Net](https://github.com/CXH-Research/DocShadow-SD7K). Supports TensorRT 🚀. 9 | 10 |

![Latency figure](assets/latency.png)
*DocShadow ONNX TensorRT provides up to a 2x speedup over PyTorch.*

11 | 12 | ## 🔥 ONNX Export 13 | 14 | Prior to exporting the ONNX models, please install the [requirements](/requirements.txt). 15 | 16 | To convert the DocShadow models to ONNX, run [`export.py`](/export.py). 17 | 18 |
19 | Export Example 20 |
21 | python export.py \
22 |     --weights sd7k \
23 |     --dynamic_img_size --dynamic_batch
24 | 
25 |
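The export can also be driven from Python. Below is a minimal sketch (assuming it is run from the repository root so that `export.py` is importable); it calls the `export_onnx` function defined in [`export.py`](/export.py) with the same options as the example above, followed by an optional `onnx.checker` sanity check.

```python
# Minimal sketch: programmatic equivalent of the CLI example above.
# With dynamic_img_size=True, export_onnx() writes weights/docshadow_sd7k.onnx by default.
import onnx

from export import export_onnx

export_onnx(weights="sd7k", dynamic_img_size=True, dynamic_batch=True)

# Optional sanity check on the exported graph.
onnx.checker.check_model(onnx.load("weights/docshadow_sd7k.onnx"))
```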
26 | 27 | If you would like to try out inference right away, you can download ONNX models that have already been exported [here](https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases) or run `./weights/download.sh`. 28 | 29 | ## ⚡ ONNX Inference 30 | 31 | With ONNX models in hand, one can perform inference in Python using ONNX Runtime (see [requirements-onnx.txt](/requirements-onnx.txt)). 32 | 33 | The DocShadow inference pipeline has been encapsulated into a runner class: 34 | 35 | ```python 36 | from onnx_runner import DocShadowRunner 37 | 38 | images = DocShadowRunner.preprocess(image_array) 39 | # images.shape == (B, 3, H, W) 40 | 41 | # Create ONNXRuntime runner 42 | runner = DocShadowRunner( 43 | onnx_path="weights/docshadow_sd7k.onnx", 44 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 45 | # TensorrtExecutionProvider 46 | ) 47 | 48 | # Run inference 49 | result = runner.run(images) 50 | ``` 51 | Alternatively, you can run [`infer.py`](/infer.py). 52 | 53 |
54 | Inference Example 55 |
56 | python infer.py \
57 |     --img_path assets/sample.jpg \
58 |     --img_size 256 256 \
59 |     --onnx_path weights/docshadow_sd7k.onnx \
60 |     --viz
61 | 
62 |
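Both the runner and [`infer.py`](/infer.py) return a float array of shape `(B, 3, H, W)`. The following is a minimal sketch for converting that result back into a saveable image; it assumes the values lie in the same `[0, 1]` range as the normalized input produced by `DocShadowRunner.preprocess`.

```python
# Minimal sketch: convert the runner output back into an image file.
# Assumes `result` has shape (B, 3, H, W) with values in [0, 1].
import numpy as np
from PIL import Image

restored = result[0].transpose(1, 2, 0)                        # (H, W, 3)
restored = (np.clip(restored, 0.0, 1.0) * 255).astype(np.uint8)
Image.fromarray(restored).save("result.png")
```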
63 | 64 | ## 🚀 TensorRT Support 65 | 66 | TensorRT offers the best performance and greatest memory efficiency. 67 | 68 | TensorRT inference is supported for the DocShadow model via the TensorRT Execution Provider in ONNXRuntime. Please follow the [official documentation](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) to install TensorRT. 69 | 70 |
71 | TensorRT Example 72 |
73 | CUDA_MODULE_LOADING=LAZY python infer.py \
74 |   --img_path assets/sample.jpg \
75 |   --onnx_path weights/docshadow_sd7k.onnx \
76 |   --img_size 256 256 \
77 |   --trt \
78 |   --viz
79 | 
80 |
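When constructing the runner directly in Python, the TensorRT Execution Provider can be selected by passing provider options. The sketch below mirrors the engine-cache options that [`infer.py`](/infer.py) and [`eval.py`](/eval.py) set internally; adjust the cache path and the FP16 flag as needed.

```python
# Minimal sketch: DocShadowRunner with the TensorRT Execution Provider,
# mirroring the provider options used by infer.py.
from onnx_runner import DocShadowRunner

providers = [
    (
        "TensorrtExecutionProvider",
        {
            "trt_fp16_enable": True,          # let TensorRT run in FP16
            "trt_engine_cache_enable": True,  # reuse the .engine/.profile cache
            "trt_engine_cache_path": "weights/cache",
        },
    ),
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]

runner = DocShadowRunner("weights/docshadow_sd7k.onnx", providers=providers)
```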
81 | 82 | The first run will take longer because TensorRT needs to initialise the `.engine` and `.profile` files. Subsequent runs should use the cached files. Only static input shapes are supported. Note that TensorRT will rebuild the cache if it encounters a different input shape. 83 | 84 | ## Credits 85 | If you use any ideas from the papers or code in this repo, please consider citing the authors of [DocShadow](https://arxiv.org/abs/2308.14221). Lastly, if the ONNX or TensorRT versions helped you in any way, please also consider starring this repository. 86 | 87 | ```bibtex 88 | @article{docshadow_sd7k, 89 | title={High-Resolution Document Shadow Removal via A Large-Scale Real-World Dataset and A Frequency-Aware Shadow Erasing Net}, 90 | author={Li, Zinuo and Chen, Xuhang and Pun, Chi-Man and Cun, Xiaodong}, 91 | journal={arXiv preprint arXiv:2308.14221}, 92 | year={2023} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /assets/latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabio-sim/DocShadow-ONNX-TensorRT/ec926bf36b4ac0778f836a2d34e27021447df27c/assets/latency.png -------------------------------------------------------------------------------- /assets/sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabio-sim/DocShadow-ONNX-TensorRT/ec926bf36b4ac0778f836a2d34e27021447df27c/assets/sample.jpg -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | 9 | def parse_args() -> argparse.Namespace: 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "framework", 13 | type=str, 14 | choices=["torch", "ort"], 15 | help="The framework to measure inference time. Options are 'torch' for PyTorch and 'ort' for ONNXRuntime.", 16 | ) 17 | parser.add_argument( 18 | "--img_path", 19 | type=str, 20 | default="assets/sample.jpg", 21 | required=False, 22 | help="Path to the root of the MegaDepth dataset.", 23 | ) 24 | parser.add_argument( 25 | "--img_size", 26 | nargs=2, 27 | type=int, 28 | default=[512, 512], 29 | required=False, 30 | help="Image size for inference. Please provide two integers (height width). 
Ensure that you have enough memory.", 31 | ) 32 | 33 | # ONNXRuntime-specific args 34 | parser.add_argument( 35 | "--onnx_path", 36 | type=str, 37 | default=None, 38 | required=False, 39 | help="Path to ONNX model (end2end).", 40 | ) 41 | # parser.add_argument( 42 | # "--fp16", 43 | # action="store_true", 44 | # help="Whether to enable half-precision for ONNXRuntime.", 45 | # ) 46 | parser.add_argument( 47 | "--trt", 48 | action="store_true", 49 | help="Whether to use TensorRT Execution Provider.", 50 | ) 51 | return parser.parse_args() 52 | 53 | 54 | def create_models(framework: str, fp16=False, onnx_path=None, trt=False): 55 | if framework == "torch": 56 | device = torch.device("cuda") 57 | 58 | model = DocShadow() 59 | load_checkpoint(model, "sd7k", device) 60 | model.eval().to(device) 61 | elif framework == "ort": 62 | if onnx_path is None: 63 | onnx_path = ( 64 | f"weights/docshadow_sd7k" 65 | f"{'_fp16' if fp16 and not trt else ''}" 66 | ".onnx" 67 | ) 68 | 69 | providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] 70 | sess_opts = ort.SessionOptions() 71 | if trt: 72 | providers.insert( 73 | 0, 74 | ( 75 | "TensorrtExecutionProvider", 76 | { 77 | "trt_fp16_enable": fp16, 78 | "trt_engine_cache_enable": True, 79 | "trt_engine_cache_path": "weights/cache", 80 | "trt_builder_optimization_level": 5, 81 | }, 82 | ), 83 | ) 84 | model = ort.InferenceSession( 85 | onnx_path, sess_options=sess_opts, providers=providers 86 | ) 87 | 88 | return model 89 | 90 | 91 | def get_inputs(framework: str, img_path, img_size, fp16, trt): 92 | img = Image.open(img_path).convert("RGB") 93 | H, W = img_size 94 | img = img.resize((W, H)) 95 | 96 | if framework == "torch": 97 | image = to_tensor(img)[None].cuda() 98 | elif framework == "ort": 99 | image = DocShadowRunner.preprocess(np.array(img)) 100 | if fp16 and not trt: 101 | image = image.astype(np.float16) 102 | 103 | return image 104 | 105 | 106 | def measure_inference(framework: str, model, images, fp16) -> float: 107 | if framework == "torch": 108 | start = torch.cuda.Event(enable_timing=True) 109 | end = torch.cuda.Event(enable_timing=True) 110 | 111 | start.record() 112 | with torch.inference_mode(): 113 | result = model(images) 114 | end.record() 115 | torch.cuda.synchronize() 116 | 117 | return start.elapsed_time(end) 118 | elif framework == "ort": 119 | model_inputs = {"image": images} 120 | model_outputs = ["result"] 121 | 122 | # Prepare IO-Bindings 123 | binding = model.io_binding() 124 | 125 | for name, arr in model_inputs.items(): 126 | binding.bind_cpu_input(name, arr) 127 | 128 | for name in model_outputs: 129 | binding.bind_output(name, "cuda") 130 | 131 | # Measure only matching time 132 | start = time.perf_counter() 133 | result = model.run_with_iobinding(binding) 134 | end = time.perf_counter() 135 | 136 | return (end - start) * 1000 137 | 138 | 139 | def evaluate( 140 | framework: str, 141 | img_path="assets/sample.jpg", 142 | img_size=[512, 512], 143 | fp16=False, 144 | onnx_path=None, 145 | trt=False, 146 | ): 147 | model = create_models( 148 | framework, 149 | fp16=fp16, 150 | onnx_path=onnx_path, 151 | trt=trt, 152 | ) 153 | 154 | # Warmup 155 | for _ in tqdm(range(5)): 156 | images = get_inputs(framework, img_path, img_size=img_size, fp16=fp16, trt=trt) 157 | _ = measure_inference(framework, model, images, fp16=fp16) 158 | 159 | # Measure 160 | timings = [] 161 | for _ in tqdm(range(1000)): 162 | images = get_inputs(framework, img_path, img_size=img_size, fp16=fp16, trt=trt) 163 | 164 | inference_time = 
measure_inference(framework, model, images, fp16=fp16) 165 | timings.append(inference_time) 166 | 167 | # Results 168 | timings = np.array(timings) 169 | print(timings) 170 | print(f"Mean inference time: {timings.mean():.2f} +/- {timings.std():.2f} ms") 171 | print(f"Median inference time: {np.median(timings):.2f} ms") 172 | 173 | 174 | if __name__ == "__main__": 175 | args = parse_args() 176 | if args.framework == "torch": 177 | import torch 178 | from torchvision.transforms.functional import to_tensor 179 | 180 | from DocShadow.models import DocShadow 181 | from DocShadow.utils import load_checkpoint 182 | elif args.framework == "ort": 183 | import onnxruntime as ort 184 | 185 | from onnx_runner import DocShadowRunner 186 | 187 | evaluate(**vars(args)) 188 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore", module="onnxconverter_common.float16") 5 | 6 | import onnx 7 | import torch 8 | from onnxconverter_common import float16 9 | 10 | from DocShadow.models import DocShadow 11 | from DocShadow.utils import load_checkpoint 12 | 13 | 14 | def parse_args() -> argparse.Namespace: 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--img_size", 18 | nargs=2, 19 | type=int, 20 | default=[256, 256], 21 | required=False, 22 | help="Sample image size for ONNX tracing. Please provide two integers (height width). Ensure that you have enough memory to run the export.", 23 | ) 24 | parser.add_argument( 25 | "--weights", 26 | type=str, 27 | default="sd7k", 28 | required=False, 29 | help="DocShadow has been trained on these datasets: ['sd7k', 'jung', 'kliger']. Defaults to 'sd7k' weights. You can also specify a local path to the weights.", 30 | ) 31 | parser.add_argument( 32 | "--onnx_path", 33 | type=str, 34 | default=None, 35 | required=False, 36 | help="Path to save the exported ONNX model.", 37 | ) 38 | parser.add_argument( 39 | "--dynamic_img_size", 40 | action="store_true", 41 | help="Whether to allow dynamic image sizes.", 42 | ) 43 | parser.add_argument( 44 | "--dynamic_batch", 45 | action="store_true", 46 | help="Whether to allow dynamic batch size.", 47 | ) 48 | parser.add_argument( 49 | "--fp16", 50 | action="store_true", 51 | help="Whether to also export float16 (half) ONNX model (CUDA only).", 52 | ) 53 | 54 | return parser.parse_args() 55 | 56 | 57 | def export_onnx( 58 | img_size=[256, 256], 59 | weights="sd7k", 60 | onnx_path=None, 61 | dynamic_img_size=False, 62 | dynamic_batch=False, 63 | fp16=False, 64 | ): 65 | # Handle args. 66 | H, W = img_size 67 | if onnx_path is None: 68 | onnx_path = ( 69 | f"weights/docshadow_{weights}" 70 | f"{f'_{H}x{W}' if not dynamic_img_size else ''}" 71 | ".onnx" 72 | ) 73 | 74 | # Load inputs and models. 75 | device = torch.device("cpu") # Device on which to export. 76 | 77 | img = torch.rand(1, 3, H, W, dtype=torch.float32, device=device) 78 | 79 | docshadow = DocShadow() 80 | load_checkpoint(docshadow, weights, device) 81 | docshadow.eval().to(device) 82 | 83 | # Export. 
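# dynamic_axes tells torch.onnx.export which input/output dimensions stay symbolic
# (batch_size, height, width) instead of being fixed to the traced sample's shape.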
84 | opset_version = 12 85 | dynamic_axes = {"image": {}, "result": {}} 86 | if dynamic_batch: 87 | dynamic_axes["image"].update({0: "batch_size"}) 88 | dynamic_axes["result"].update({0: "batch_size"}) 89 | if dynamic_img_size: 90 | dynamic_axes["image"].update({2: "height", 3: "width"}) 91 | dynamic_axes["result"].update({2: "height", 3: "width"}) 92 | 93 | torch.onnx.export( 94 | docshadow, 95 | img, 96 | onnx_path, 97 | input_names=["image"], 98 | output_names=["result"], 99 | opset_version=opset_version, 100 | dynamic_axes=dynamic_axes, 101 | ) 102 | if fp16: 103 | convert_fp16(onnx_path) 104 | 105 | 106 | def convert_fp16(onnx_model_path: str): 107 | onnx_model = onnx.load(onnx_model_path) 108 | fp16_model = float16.convert_float_to_float16(onnx_model) 109 | onnx.save(fp16_model, onnx_model_path.replace(".onnx", "_fp16.onnx")) 110 | 111 | 112 | if __name__ == "__main__": 113 | args = parse_args() 114 | export_onnx(**vars(args)) 115 | -------------------------------------------------------------------------------- /export_coreml.py: -------------------------------------------------------------------------------- 1 | """Export DocShadow to CoreML.""" 2 | import coremltools as ct # 6.3.0 3 | import torch 4 | 5 | from DocShadow.models import DocShadow 6 | from DocShadow.models.backbone import LayerNorm2d 7 | from DocShadow.models.model import WithBias_LayerNorm 8 | from DocShadow.utils import load_checkpoint 9 | 10 | H, W = 256, 256 11 | weights = "sd7k" # "jung", "kligler" 12 | 13 | 14 | # Patches for CoreML compatibility 15 | 16 | 17 | def WithBias_LayerNorm_forward(self, x): 18 | """Manually compute variance instead of using Tensor.var()""" 19 | mu = x.mean(-1, keepdim=True) 20 | sigma = (x - mu).pow(2).mean(-1, keepdim=True) 21 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 22 | 23 | 24 | WithBias_LayerNorm.forward = WithBias_LayerNorm_forward 25 | 26 | 27 | def LayerNorm2d_forward(self, x): 28 | """Layer Normalization over the channels dimension only.""" 29 | N, C, H, W = x.size() 30 | mu = x.mean(1, keepdim=True) 31 | var = (x - mu).pow(2).mean(1, keepdim=True) 32 | y = (x - mu) / (var + self.eps).sqrt() 33 | y = self.weight.view(1, C, 1, 1) * y + self.bias.view(1, C, 1, 1) 34 | return y 35 | 36 | 37 | LayerNorm2d.forward = LayerNorm2d_forward 38 | 39 | # Load inputs and models. 40 | device = torch.device("cpu") # Device on which to export. 
41 | 42 | img = torch.rand(1, 3, H, W, dtype=torch.float32, device=device) 43 | 44 | docshadow = DocShadow() 45 | load_checkpoint(docshadow, weights, device) 46 | docshadow.eval().to(device) 47 | 48 | docshadow.trans_high.spp_img.interpolation_type = "bilinear" # bicubic unsupported 49 | 50 | traced_docshadow = torch.jit.trace(docshadow, img) 51 | 52 | coreml_docshadow = ct.convert( 53 | traced_docshadow, 54 | # convert_to="mlprogram", 55 | inputs=[ct.TensorType(shape=img.shape)], 56 | ) 57 | 58 | coreml_docshadow.save(f"weights/docshadow_{weights}.mlmodel") 59 | # coreml_docshadow.save(f"weights/docshadow_{weights}.mlpackage") 60 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from onnx_runner import DocShadowRunner 7 | 8 | 9 | def parse_args() -> argparse.Namespace: 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--img_path", 13 | type=str, 14 | default="assets/sample.jpg", 15 | required=False, 16 | help="Path to input image for inference.", 17 | ) 18 | parser.add_argument( 19 | "--img_size", 20 | nargs=2, 21 | type=int, 22 | default=[512, 512], 23 | required=False, 24 | help="Image size for inference. Please provide two integers (height width). Ensure that you have enough memory.", 25 | ) 26 | parser.add_argument( 27 | "--onnx_path", 28 | type=str, 29 | default=None, 30 | required=False, 31 | help="Path to the ONNX model.", 32 | ) 33 | parser.add_argument( 34 | "--fp16", 35 | action="store_true", 36 | help="Whether to run inference using float16 (half) ONNX model (CUDA only).", 37 | ) 38 | parser.add_argument( 39 | "--trt", 40 | action="store_true", 41 | help="Whether to use TensorRT. Note that the end2end ONNX model must NOT be exported with --fp16. TensorRT will perform the conversion instead. Only static input shapes are supported.", 42 | ) 43 | parser.add_argument( 44 | "--viz", action="store_true", help="Whether to visualize the results." 45 | ) 46 | return parser.parse_args() 47 | 48 | 49 | def infer( 50 | img_path="assets/sample.jpg", 51 | img_size=[512, 512], 52 | onnx_path=None, 53 | fp16=False, 54 | trt=False, 55 | viz=False, 56 | ): 57 | img = Image.open(img_path).convert("RGB") 58 | orig_W, orig_H = img.size 59 | 60 | # Handle args. 
61 | if onnx_path is None: 62 | onnx_path = "weights/docshadow_sd7k.onnx" # default path 63 | 64 | # Preprocessing 65 | H, W = img_size 66 | image = DocShadowRunner.preprocess(np.array(img.resize((W, H)))) 67 | if fp16 and not trt: 68 | image = image.astype(np.float16) 69 | 70 | # Inference 71 | providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] 72 | if trt: 73 | providers.insert( 74 | 0, 75 | ( 76 | "TensorrtExecutionProvider", 77 | { 78 | "trt_fp16_enable": fp16, 79 | "trt_engine_cache_enable": True, 80 | "trt_engine_cache_path": "weights/cache", 81 | }, 82 | ), 83 | ) 84 | 85 | runner = DocShadowRunner(onnx_path, providers=providers) 86 | result = runner.run(image) 87 | 88 | # Visualisation 89 | if viz: 90 | import cv2 91 | 92 | result_img = result[0].transpose(1, 2, 0) 93 | result_img = cv2.resize(result_img, (orig_W, orig_H)) 94 | cv2.imshow("result", cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR)) 95 | cv2.waitKey(0) 96 | 97 | return result 98 | 99 | 100 | if __name__ == "__main__": 101 | args = parse_args() 102 | result = infer(**vars(args)) 103 | print(result) 104 | print(result.shape) 105 | -------------------------------------------------------------------------------- /onnx_runner/__init__.py: -------------------------------------------------------------------------------- 1 | from .docshadow import DocShadowRunner 2 | -------------------------------------------------------------------------------- /onnx_runner/docshadow.py: -------------------------------------------------------------------------------- 1 | # No dependency on PyTorch 2 | 3 | import numpy as np 4 | import onnxruntime as ort 5 | from PIL import Image 6 | 7 | 8 | class DocShadowRunner: 9 | def __init__( 10 | self, 11 | onnx_path=None, 12 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 13 | ): 14 | self.model = ort.InferenceSession(onnx_path, providers=providers) 15 | 16 | def run(self, images: np.ndarray) -> np.ndarray: 17 | result = self.model.run(None, {"image": images})[0] 18 | return result 19 | 20 | @staticmethod 21 | def preprocess(image: np.ndarray) -> np.ndarray: 22 | # image.shape == (H, W, C) 23 | image = np.asarray(image) / 255 24 | image = image[None].transpose(0, 3, 1, 2) 25 | image = image.astype(np.float32) 26 | return image 27 | -------------------------------------------------------------------------------- /requirements-onnx.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | onnx 3 | onnxruntime-gpu 4 | opencv-python # Only for visualisation 5 | pillow 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | numpy 3 | onnx 4 | onnxconverter-common 5 | onnxruntime-gpu 6 | opencv-python 7 | torch 8 | torchvision 9 | pillow 10 | -------------------------------------------------------------------------------- /weights/.gitkeep: -------------------------------------------------------------------------------- 1 | ONNX models will be exported here. 
2 | -------------------------------------------------------------------------------- /weights/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | RELEASE=v1.0.0 4 | 5 | curl -L https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases/download/${RELEASE}/docshadow_sd7k.onnx -o weights/docshadow_sd7k.onnx 6 | curl -L https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases/download/${RELEASE}/docshadow_jung.onnx -o weights/docshadow_jung.onnx 7 | curl -L https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases/download/${RELEASE}/docshadow_kligler.onnx -o weights/docshadow_kligler.onnx 8 | --------------------------------------------------------------------------------