├── LICENSE
├── README.md
├── assets
│   └── network.png
├── model
│   └── splitsr.py
└── modules
    └── blocks.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Luka Chkhetiani

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SplitSR
Unofficial implementation of [SplitSR: An End-to-End Approach to Super-Resolution on Mobile Devices](https://arxiv.org/abs/2101.07996)

![a) SplitSRBlock, b) SplitSR](assets/network.png)

## Keys from the Paper
- Split convolution splits the input along the channel dimension by the ratio 𝛼: only the first 𝛼 fraction of channels is convolved.
- The conv-processed chunk is concatenated back at the end of the tensor.
- Because the processed chunk rotates to the back, every channel gets convolved once every 1/𝛼 blocks (the sketch at the end of modules/blocks.py checks this).
- The theoretical computation reduction from using SplitSR is 𝛼^2, where 𝛼 ∈ (0, 1]; see the sanity check below.
- The architecture closely follows RCAN, with the channel-attention blocks replaced by split convolutions.
- Many details in the paper are ambiguous, so parts of this implementation are educated guesses.

## Config
- 𝛼 = 0.250
- Groups = 6, Blocks = 6
- Hybrid Index = 3
- Loss - L1
- Base LR - 1e-4
- LR Decay - halved every 2 × 10^5 steps
- Adam, 𝛽1 = 0.9, 𝛽2 = 0.999
- 𝜖 = 1e−7
- Steps = 6 × 10^5

## Progress
- Split convolution block is done.
- Residual block is done.
- Mean shift layer is done.
- Beta version of the model is ready.
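
## Sanity Check: the 𝛼^2 Reduction
A quick arithmetic sketch (ours, not from the paper's code) of the claimed saving. A k×k convolution mapping C channels to C channels on an H×W feature map costs k^2 · C^2 · H · W multiply-accumulates; convolving only the first 𝛼C channels costs k^2 · (𝛼C)^2 · H · W, i.e. 𝛼^2 of the full cost:

```python
# Multiply-accumulates of a full 3x3 conv (C -> C channels on an HxW map)
# versus the split conv inside SplitSRBlock, which touches only alpha*C channels.
C, k, H, W = 64, 3, 96, 96
alpha = 0.250

full_macs = k * k * C * C * H * W
split_macs = k * k * int(alpha * C) ** 2 * H * W

print(full_macs / split_macs)  # 16.0, i.e. 1 / alpha**2 for alpha = 0.25
```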
--------------------------------------------------------------------------------
/assets/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepconsc/SplitSR/e9fbc203c74c43f3681ee7878c540f71ec8c1ccf/assets/network.png
--------------------------------------------------------------------------------
/model/splitsr.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from modules.blocks import MeanShift, ResidualBlock, Upsample, SplitSRBlock

class SplitSR(nn.Module):
    def __init__(self):
        super(SplitSR, self).__init__()

        # One residual group: a hybrid of plain residual blocks and SplitSR blocks.
        def make_group():
            return nn.Sequential(
                ResidualBlock(channels=64),
                SplitSRBlock(channels=64, kernel=3, alpha=0.250),
                SplitSRBlock(channels=64, kernel=3, alpha=0.250),
                ResidualBlock(channels=64),
                SplitSRBlock(channels=64, kernel=3, alpha=0.250),
                SplitSRBlock(channels=64, kernel=3, alpha=0.250),
            )

        # Six groups with independent weights; calling a single group instance
        # six times would silently share weights across all of them.
        self.body = nn.Sequential(*[make_group() for _ in range(6)])
        self.conv_head = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=3//2, bias=True)
        self.conv_back = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=3//2, bias=True)
        self.upsample = Upsample(64)
        self.MeanSubtract = MeanShift(-1)
        self.MeanAdd = MeanShift(1)

    def forward(self, x):
        x = self.MeanSubtract(x)
        x = self.conv_head(x)

        x = self.body(x)

        x = self.upsample(x)
        x = self.conv_back(x)
        x = self.MeanAdd(x)

        return x

if __name__ == '__main__':
    model = SplitSR()

    x = torch.randn(1, 3, 96, 96)
    y = model(x)

    # x4 super-resolution: output spatial dims are 4x the input's.
    assert y.shape[-1] == x.shape[-1] * 4 and y.shape[-2] == x.shape[-2] * 4

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total params: {params}')
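
    # Hedged usage sketch (not from the paper): the MeanShift layers are built
    # with 255-scaled RGB means, so inputs are expected in the [0, 255] range,
    # as in the EDSR/RCAN codebases.
    model.eval()
    with torch.no_grad():
        lr = torch.rand(1, 3, 64, 64) * 255  # dummy low-res image, 0-255 range
        sr = model(lr)                       # (1, 3, 256, 256) after x4 upscaling
    print(f'SR output shape: {tuple(sr.shape)}')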
--------------------------------------------------------------------------------
/modules/blocks.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

class SplitSRBlock(nn.Module):
    def __init__(self, channels, kernel, alpha):
        super(SplitSRBlock, self).__init__()
        self.alpharatio = int(channels * alpha)
        self.channels = channels
        self.conv = nn.Conv2d(in_channels=self.alpharatio, out_channels=self.alpharatio, kernel_size=kernel, stride=1, padding=kernel//2, bias=True)
        self.batchnorm = nn.BatchNorm2d(self.alpharatio)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Convolve only the first alpha*C channels, then rotate them to the back.
        active, passive = x[:, :self.alpharatio], x[:, self.alpharatio:]
        active = self.conv(active)               # In: (1, 64 * α, H, W) | Out: (1, 64 * α, H, W)
        active = self.batchnorm(active)
        active = self.relu(active)
        x = torch.cat([passive, active], dim=1)  # Out: (1, 64, H, W)
        return x

class Upsample(nn.Module):
    def __init__(self, channels):
        super(Upsample, self).__init__()
        self.channels = channels
        # Two separate convs: reusing one conv for both x2 stages would share weights.
        self.conv1 = nn.Conv2d(in_channels=self.channels, out_channels=self.channels*4, kernel_size=3, stride=1, padding=3//2, bias=True)
        self.conv2 = nn.Conv2d(in_channels=self.channels, out_channels=self.channels*4, kernel_size=3, stride=1, padding=3//2, bias=True)
        self.pixelshuffle = nn.PixelShuffle(2)

    def forward(self, x):
        x = self.conv1(x)         # In: (1, 64, H, W) | Out: (1, 256, H, W)
        x = self.pixelshuffle(x)  # In: (1, 256, H, W) | Out: (1, 64, H*2, W*2)

        x = self.conv2(x)         # In: (1, 64, H*2, W*2) | Out: (1, 256, H*2, W*2)
        x = self.pixelshuffle(x)  # In: (1, 256, H*2, W*2) | Out: (1, 64, H*4, W*4)

        return x

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        # Two distinct conv/BN pairs; reusing a single pair twice would tie the weights.
        self.conv1 = nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=3, stride=1, padding=3//2, bias=True)
        self.conv2 = nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=3, stride=1, padding=3//2, bias=True)
        self.batchnorm1 = nn.BatchNorm2d(self.channels)
        self.batchnorm2 = nn.BatchNorm2d(self.channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        x += residual
        x = self.relu(x)

        return x


class MeanShift(nn.Conv2d):
    def __init__(self, coeff):
        # Fixed 1x1 conv that subtracts (coeff = -1) or adds back (coeff = 1)
        # the DIV2K RGB mean, scaled to a [0, 255] input range.
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor([1.0, 1.0, 1.0])
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = coeff * 255 * torch.Tensor([0.4488, 0.4371, 0.4040])
        self.bias.data.div_(std)
        # Freeze the parameters themselves; setting self.requires_grad on the
        # module object has no effect on them.
        for p in self.parameters():
            p.requires_grad = False

if __name__ == '__main__':
    block_alphas = [0.125, 0.250, 0.500, 1.000]

    for alpha in block_alphas:
        block = SplitSRBlock(128, 3, alpha)

        x = torch.randn(1, 128, 112, 112)
        y = block(x)

        assert y.shape == x.shape

        params = sum(p.numel() for p in block.parameters() if p.requires_grad)
        print(f'Total params for block with α={alpha}: {params}')
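
    # Hedged sketch (not from the paper's code): check the README claim that
    # every channel is convolved once after 1/α blocks. Each SplitSRBlock
    # convolves the first int(alpha * C) channels and rotates them to the back.
    C, split = 8, 2  # toy sizes; split = int(alpha * C) with alpha = 0.25
    order, convolved = list(range(C)), set()
    for _ in range(C // split):  # 1/alpha blocks
        convolved.update(order[:split])
        order = order[split:] + order[:split]
    assert convolved == set(range(C))  # every channel touched after 1/alpha blocks
--------------------------------------------------------------------------------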