├── NBNet.py
└── README.md

/NBNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Two 3x3 conv + LeakyReLU layers with a 1x1 convolution shortcut."""

    def __init__(self, in_channel, out_channel, strides=1):
        super(ConvBlock, self).__init__()
        self.strides = strides
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        # 1x1 shortcut; both branches use `strides`, so their shapes only match for
        # strides=1, which is what the network below always uses.
        self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)

    def forward(self, x):
        out1 = self.block(x)
        out2 = self.conv11(x)
        out = out1 + out2
        return out


class SSA(nn.Module):
    """SSA module: generates K = 16 basis maps from the concatenated features and
    projects each channel of the skip feature onto the subspace they span."""

    def __init__(self, in_channel, strides=1):
        super(SSA, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, 16, kernel_size=3, stride=strides, padding=1)
        self.relu1 = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=strides, padding=1)
        self.relu2 = nn.LeakyReLU(inplace=True)

        self.conv11 = nn.Conv2d(in_channel, 16, kernel_size=1, stride=strides, padding=0)

    def forward(self, input1, input2):
        # input1: encoder skip feature, input2: upsampled decoder feature.
        # NOTE: the reshapes below drop the batch dimension, so this module assumes batch size 1.
        input1 = input1.permute(0, 2, 3, 1)
        input2 = input2.permute(0, 2, 3, 1)
        cat = torch.cat([input1, input2], 3)
        cat = cat.permute(0, 3, 1, 2)
        # Generate the K = 16 basis maps.
        out1 = self.relu1(self.conv1(cat))
        out1 = self.relu2(self.conv2(out1))
        out2 = self.conv11(cat)
        conv = (out1 + out2).permute(0, 2, 3, 1)
        H, W, K = conv.shape[1], conv.shape[2], conv.shape[3]
        V = conv.reshape(H * W, K)
        # Orthogonal projection onto the column space of V: P = V (V^T V)^-1 V^T.
        Vtrans = torch.transpose(V, 1, 0)
        Vinverse = torch.inverse(torch.mm(Vtrans, V))
        Projection = torch.mm(torch.mm(V, Vinverse), Vtrans)
        # Project the skip feature onto the learned subspace.
        H1, W1, C1 = input1.shape[1], input1.shape[2], input1.shape[3]
        X1 = input1.reshape(H1 * W1, C1)
        Yproj = torch.mm(Projection, X1)
        Y = Yproj.reshape(H1, W1, C1).unsqueeze(0)
        Y = Y.permute(0, 3, 1, 2)
        return Y


class NBNet(nn.Module):
    def __init__(self, num_classes=10):  # num_classes is unused
        super(NBNet, self).__init__()

        # Encoder: four ConvBlock + 2x max-pool stages; each stage also has a chain of
        # ConvBlocks on the skip path and an SSA module that fuses it with the decoder feature.
        self.ConvBlock1 = ConvBlock(3, 32, strides=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.skip1 = nn.Sequential(ConvBlock(32, 32, strides=1), ConvBlock(32, 32, strides=1), ConvBlock(32, 32, strides=1), ConvBlock(32, 32, strides=1))
        self.ssa1 = SSA(64, strides=1)

        self.ConvBlock2 = ConvBlock(32, 64, strides=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.skip2 = nn.Sequential(ConvBlock(64, 64, strides=1), ConvBlock(64, 64, strides=1), ConvBlock(64, 64, strides=1))
        self.ssa2 = SSA(128, strides=1)

        self.ConvBlock3 = ConvBlock(64, 128, strides=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.skip3 = nn.Sequential(ConvBlock(128, 128, strides=1), ConvBlock(128, 128, strides=1))
        self.ssa3 = SSA(256, strides=1)

        self.ConvBlock4 = ConvBlock(128, 256, strides=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        self.skip4 = nn.Sequential(ConvBlock(256, 256, strides=1))
        self.ssa4 = SSA(512, strides=1)

        self.ConvBlock5 = ConvBlock(256, 512, strides=1)

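        # Decoder: each stage upsamples 2x with a transposed convolution, projects the
        # corresponding encoder skip feature through SSA, concatenates the two, and
        # refines the result with a ConvBlock.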
        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.ConvBlock6 = ConvBlock(512, 256, strides=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.ConvBlock7 = ConvBlock(256, 128, strides=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.ConvBlock8 = ConvBlock(128, 64, strides=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.ConvBlock9 = ConvBlock(64, 32, strides=1)

        self.conv10 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Encoder
        conv1 = self.ConvBlock1(x)
        pool1 = self.pool1(conv1)

        conv2 = self.ConvBlock2(pool1)
        pool2 = self.pool2(conv2)

        conv3 = self.ConvBlock3(pool2)
        pool3 = self.pool3(conv3)

        conv4 = self.ConvBlock4(pool3)
        pool4 = self.pool4(conv4)

        conv5 = self.ConvBlock5(pool4)

        # Decoder with SSA-projected skip connections
        up6 = self.upv6(conv5)
        skip4 = self.skip4(conv4)
        skip4 = self.ssa4(skip4, up6)
        up6 = torch.cat([up6, skip4], 1)
        conv6 = self.ConvBlock6(up6)

        up7 = self.upv7(conv6)
        skip3 = self.skip3(conv3)
        skip3 = self.ssa3(skip3, up7)
        up7 = torch.cat([up7, skip3], 1)
        conv7 = self.ConvBlock7(up7)

        up8 = self.upv8(conv7)
        skip2 = self.skip2(conv2)
        skip2 = self.ssa2(skip2, up8)
        up8 = torch.cat([up8, skip2], 1)
        conv8 = self.ConvBlock8(up8)

        up9 = self.upv9(conv8)
        skip1 = self.skip1(conv1)
        skip1 = self.ssa1(skip1, up9)
        up9 = torch.cat([up9, skip1], 1)
        conv9 = self.ConvBlock9(up9)

        conv10 = self.conv10(conv9)
        # The network predicts a residual that is added back to the noisy input.
        out = x + conv10

        return out


if __name__ == "__main__":
    # Quick sanity check on a random 128x128 input.
    model = NBNet()
    x = torch.randn(1, 3, 128, 128)
    y = model(x)
    print(y.shape)  # torch.Size([1, 3, 128, 128])
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NBNet

A PyTorch re-implementation of "NBNet: Noise Basis Learning for Image Denoising with Subspace Projection".

This is NOT an official implementation!

The architecture of NBNet follows the paper: https://arxiv.org/abs/2012.15028
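
## Usage

A minimal sketch of running the model on a random image, assuming only what `NBNet.py` provides (the `NBNet` class, batch size 1, spatial size divisible by 16):

```python
import torch
from NBNet import NBNet

model = NBNet().eval()

# The SSA module reshapes features without a batch dimension, so use batch size 1.
# Height and width should be divisible by 16 (four 2x down/upsampling stages).
noisy = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    denoised = model(noisy)  # the model adds its predicted residual to the noisy input

print(denoised.shape)  # torch.Size([1, 3, 128, 128])
```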