├── README.md
├── images
│   └── __init__.py
├── logs
│   └── log.txt
├── main.py
├── main.sh
├── network
│   ├── Punet.py
│   ├── __init__.py
│   └── pconv.py
└── util.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self2Self Pytorch Implementation

## Introduction
This is a pytorch implementation of [Self2Self](https://openaccess.thecvf.com/content_CVPR_2020/papers/Quan_Self2Self_With_Dropout_Learning_Self-Supervised_Denoising_From_Single_Image_CVPR_2020_paper.pdf): "Yuhui Quan, Mingqin Chen, Tongyao Pang, Hui Ji; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020, pp. 1890-1898."

It is a pytorch reimplementation of the [tensorflow version](https://github.com/scut-mingqinchen/self2self).

You can simply run main.py with the default parameters:
```
sh main.sh
```
The denoised images will be saved in images/, and the logs will be saved in logs/.

## Details of reimplementation

There are some notable details in the conversion from tensorflow to pytorch that significantly affect performance.

### Partial Convolution 2D layer
There is an existing pytorch package implementing [Pconv2d](https://github.com/DesignStripe/torch_pconv). However, its implementation details differ from those of the tensorflow version. Specifically, the variable **mask** in the tensorflow version is a 4-d tensor with shape (1, channel, width, height), but a 3-d tensor with shape (1, width, height) in the pytorch package. We have implemented a Pconv2d layer consistent with the tensorflow version by modifying the source code of the pytorch package (see the sketch at the end of this README).

### Optimizer
The implementation details of Adam differ slightly between tensorflow and pytorch. Moreover, the suboptimal convergence of Adam in PyTorch compared to TensorFlow has been widely discussed.

## Update Log

### 2023-05-14
- Found and fixed a bug in "network/pconv.py" (the `output *= update_mask` line in `forward`, now commented out), which enables our implementation to achieve denoising performance comparable to the tensorflow version.
- Found that changing the optimizer from Adam to AdamW achieves better denoising performance. (We keep the Adam optimizer in order to stay consistent with the original tensorflow implementation; see the snippet at the end of this README.)

Thanks to @haimiaozh for all the contributions to improving this project!
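
## Appendix: illustrative snippets

The mask-shape difference described in the Partial Convolution section can be made concrete with a small sketch. This is illustrative only (the tensor sizes are made up); our modified `PConv2d` in `network/pconv.py` expects a 4-d mask with the same shape as the input, whereas the upstream torch_pconv package expects a 3-d mask without a channel dimension:

```python
import torch
from network.pconv import PConv2d  # our modified layer: 4-d mask, like the tensorflow version

x = torch.rand(1, 3, 64, 64)                     # (batch, channel, height, width), float32
mask = (torch.rand(1, 3, 64, 64) > 0.3).float()  # per-channel binary mask, same shape as x

pconv = PConv2d(in_channels=3, out_channels=48, kernel_size=3, bias=True)
out, updated_mask = pconv(x, mask)
print(out.shape, updated_mask.shape)             # (1, 48, 62, 62) each, since padding=0

# The upstream torch_pconv package would instead take a 3-d mask shared across
# channels, e.g. a tensor of shape (1, 64, 64).
```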
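The optimizer swap mentioned in the update log is a one-line change in `train()` in main.py. A minimal sketch (the `weight_decay` value shown is pytorch's default, chosen here for illustration; Adam remains the default in this repository):

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in for network.Punet.Punet

# default, matching the original tensorflow implementation
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# variant we found to denoise better (not enabled by default)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
```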

--------------------------------------------------------------------------------
/images/__init__.py:
--------------------------------------------------------------------------------
####

--------------------------------------------------------------------------------
/logs/log.txt:
--------------------------------------------------------------------------------
####

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
import network.Punet
import skimage.metrics
from argparse import ArgumentParser

import util
import cv2
import os

def data_arg(x, is_flip_lr, is_flip_ud):
    # random flip augmentation; applying it again with the same flags undoes the flip
    if is_flip_lr > 0:
        x = torch.flip(x, dims=[2])
    if is_flip_ud > 0:
        x = torch.flip(x, dims=[3])
    return x

def get_output(noisy, model, drop_rate=0.3, bs=1, device='cpu'):
    noisy_tensor = torch.tensor(noisy).permute(0, 3, 1, 2).to(device)
    is_flip_lr = np.random.randint(2)
    is_flip_ud = np.random.randint(2)
    noisy_tensor = data_arg(noisy_tensor, is_flip_lr, is_flip_ud)
    # 4-d Bernoulli mask (batch, channel, height, width), matching the modified PConv2d
    mask_tensor = torch.ones(noisy_tensor.shape).to(device)
    # F.dropout scales kept entries by 1/(1-p); multiplying by (1-p) restores a binary {0, 1} mask
    mask_tensor = F.dropout(mask_tensor, drop_rate) * (1 - drop_rate)
    input_tensor = noisy_tensor * mask_tensor
    output = model(input_tensor, mask_tensor)
    output = data_arg(output, is_flip_lr, is_flip_ud)
    output_numpy = output.detach().cpu().numpy().transpose(0, 2, 3, 1)
    return output_numpy

def get_loss(noisy, model, drop_rate=0.3, bs=1, device='cpu'):
    noisy_tensor = torch.tensor(noisy).permute(0, 3, 1, 2).to(device)
    is_flip_lr = np.random.randint(2)
    is_flip_ud = np.random.randint(2)
    noisy_tensor = data_arg(noisy_tensor, is_flip_lr, is_flip_ud)
    mask_tensor = torch.ones(noisy_tensor.shape).to(device)
    mask_tensor = F.dropout(mask_tensor, drop_rate) * (1 - drop_rate)
    input_tensor = noisy_tensor * mask_tensor
    output = model(input_tensor, mask_tensor)
    # the loss is evaluated only on the dropped (unobserved) pixels
    observe_tensor = 1.0 - mask_tensor
    loss = torch.sum((output - noisy_tensor).pow(2) * observe_tensor) / torch.count_nonzero(observe_tensor).float()
    return loss

def train(file_path, args, is_realnoisy=False):
    print(file_path)
    gt = util.load_np_image(file_path)
    _, w, h, c = gt.shape
    model_path = file_path[0:file_path.rfind(".")] + "/" + str(args.sigma) + "/model_" + args.model_type + "/"
    os.makedirs(model_path, exist_ok=True)
    noisy = util.add_gaussian_noise(gt, model_path, args.sigma, bs=args.bs)
    print('noisy shape:', noisy.shape)
    print('image shape:', gt.shape)

    # model
    model = network.Punet.Punet(channel=c, width=w, height=h, drop_rate=args.drop_rate).to(args.device)
    model.train()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # begin training
    avg_loss = 0
    for step in range(args.iteration):
        # one step
        loss = get_loss(noisy, model, drop_rate=args.drop_rate, bs=args.bs, device=args.device)
        avg_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if 'cuda' in args.device:  # torch.cuda.device() rejects 'cpu'
            with torch.cuda.device(args.device):
                torch.cuda.empty_cache()
        # test
        if (step + 1) % args.test_frequency == 0:
            # model.eval() is deliberately not called here: Self2Self averages
            # multiple dropout predictions, so dropout must stay active
            print("After %d training step(s)" % (step + 1),
                  "loss is {:.9f}".format(avg_loss / args.test_frequency))
            final_image = np.zeros(gt.shape)
            for j in range(args.num_prediction):
                output_numpy = get_output(noisy, model, drop_rate=args.drop_rate, bs=args.bs, device=args.device)
                final_image += output_numpy
                if 'cuda' in args.device:
                    with torch.cuda.device(args.device):
                        torch.cuda.empty_cache()
            final_image = np.squeeze(np.uint8(np.clip(final_image / args.num_prediction, 0, 1) * 255))
            cv2.imwrite(model_path + 'Self2Self-' + str(step + 1) + '.png', final_image)
            PSNR = skimage.metrics.peak_signal_noise_ratio(gt[0], final_image.astype(np.float32) / 255.0)
            print("psnr = ", PSNR)
            with open(args.log_pth, 'a') as f:
                f.write("After %d training step(s), " % (step + 1))
                f.write("loss is {:.9f}, ".format(avg_loss / args.test_frequency))
                f.write("psnr is {:.4f}".format(PSNR))
                f.write("\n")
            avg_loss = 0
            model.train()

    return PSNR

def main(args):
    path = args.path
    file_list = os.listdir(path)
    with open(args.log_pth, 'w') as f:
        f.write("Self2self algorithm!\n")
    avg_psnr = 0
    count = 0
    for file_name in file_list:
        if not os.path.isdir(path + file_name):
            PSNR = train(path + file_name, args)
            avg_psnr += PSNR
            count += 1
            break  # NOTE: only the first image in the folder is denoised; remove to process all files
    with open(args.log_pth, 'a') as f:
        f.write('average psnr is {:.4f}'.format(avg_psnr / count))

def build_args():
    parser = ArgumentParser()

    parser.add_argument("--iteration", type=int, default=150000)
    parser.add_argument("--test_frequency", type=int, default=1000)
    parser.add_argument("--drop_rate", type=float, default=0.3)
    parser.add_argument("--sigma", type=float, default=25.0)
    parser.add_argument("--bs", type=int, default=1)
    parser.add_argument("--model_type", type=str, default='dropout')
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--num_prediction", type=int, default=100)
    parser.add_argument("--log_pth", type=str, default='./logs/log.txt')
    parser.add_argument("--path", type=str, default='./testsets/Set9/')
    parser.add_argument("--device", type=str, default='cpu')

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = build_args()
    main(args)

--------------------------------------------------------------------------------
/main.sh:
--------------------------------------------------------------------------------
python main.py \
    --path './testsets/test_Set9/' \
    --bs 1 \
    --sigma 25 \
    --iteration 150000 \
    --lr 1e-4 \
    --model_type 'dropout' \
    --test_frequency 1000 \
    --log_pth './logs/log_dropout.txt' \
    --device 'cuda:0'

--------------------------------------------------------------------------------
/network/Punet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch_pconv import PConv2d
from network.pconv import PConv2d

class Pconv_lr(nn.Module):  # padding handled differently from the tf source code
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.Pconv2d_bias = PConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=0,  # 'valid' padding; replicate padding is applied manually in forward
            bias=True,
        )

    def forward(self, x, mask):
        # replicate-pad the features and ones-pad the mask to mimic the tf padding scheme
        padding_shape = (1, 1, 1, 1)
        x = F.pad(x, padding_shape, "replicate")
        mask = F.pad(mask, padding_shape, "constant", value=1)
        x, mask = self.Pconv2d_bias(x, mask)
        x = F.leaky_relu(x, negative_slope=0.1)
        return x, mask


class conv_lr(nn.Module):  # padding handled differently from the tf source code
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv2d_bias = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding='valid',
            bias=True,
        )

    def forward(self, x, drop_rate=0.3):
        x = F.dropout(x, drop_rate)
        padding_shape = (1, 1, 1, 1)
        x = F.pad(x, padding_shape, "replicate")
        x = self.conv2d_bias(x)
        x = F.leaky_relu(x, negative_slope=0.1)
        return x


class conv(nn.Module):  # padding handled differently from the tf source code
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv2d_bias = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding='valid',
            bias=True,
        )
        self.Sigmoid = torch.nn.Sigmoid()

    def forward(self, x, drop_rate=0.3):
        x = F.dropout(x, drop_rate)
        padding_shape = (1, 1, 1, 1)
        x = F.pad(x, padding_shape, "replicate")
        x = self.conv2d_bias(x)
        x = self.Sigmoid(x)
        return x


class Punet(nn.Module):
    def __init__(self, channel=3, width=256, height=256, drop_rate=0.3):
        super().__init__()

        self.channel = channel
        self.width = width
        self.height = height
        self.drop_rate = drop_rate

        # encoder
        self.env_conv0 = Pconv_lr(self.channel, 48)  # image channels in, 48 feature channels out
        self.env_conv1 = Pconv_lr(48, 48)
        self.env_conv2 = Pconv_lr(48, 48)
        self.env_conv3 = Pconv_lr(48, 48)
        self.env_conv4 = Pconv_lr(48, 48)
        self.env_conv5 = Pconv_lr(48, 48)
        self.env_conv6 = Pconv_lr(48, 48)
        # decoder
        self.dec_conv5 = conv_lr(96, 96)
        self.dec_conv5b = conv_lr(96, 96)
        self.dec_conv4 = conv_lr(144, 96)
        self.dec_conv4b = conv_lr(96, 96)
        self.dec_conv3 = conv_lr(144, 96)
        self.dec_conv3b = conv_lr(96, 96)
        self.dec_conv2 = conv_lr(144, 96)
        self.dec_conv2b = conv_lr(96, 96)
        self.dec_conv1a = conv_lr(99, 64)
        self.dec_conv1b = conv_lr(64, 32)
        self.dec_conv1 = conv(32, self.channel)

    def Pmaxpool2d(self, x, mask, kernel_size=2):
        # pool the features and the mask together so they stay spatially aligned
        pooling = nn.MaxPool2d(kernel_size=kernel_size)
        x = pooling(x)
        mask = pooling(mask)
        return x, mask

    def encoder(self, x, mask):
        skips = [x]

        x, mask = self.env_conv0(x, mask)
        x, mask = self.env_conv1(x, mask)
        x, mask = self.Pmaxpool2d(x, mask)
        skips.append(x)

        x, mask = self.env_conv2(x, mask)
        x, mask = self.Pmaxpool2d(x, mask)
        skips.append(x)

        x, mask = self.env_conv3(x, mask)
        x, mask = self.Pmaxpool2d(x, mask)
        skips.append(x)

        x, mask = self.env_conv4(x, mask)
        x, mask = self.Pmaxpool2d(x, mask)
        skips.append(x)

        x, mask = self.env_conv5(x, mask)
        x, mask = self.Pmaxpool2d(x, mask)
        x, mask = self.env_conv6(x, mask)

        return x, skips

    def decoder(self, x, skips):
        x = F.interpolate(x, scale_factor=2)  # F.upsample is deprecated; interpolate is equivalent here
        x = torch.cat([x, skips.pop()], dim=1)
        x = self.dec_conv5(x, self.drop_rate)
        x = self.dec_conv5b(x, self.drop_rate)

        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, skips.pop()], dim=1)
        x = self.dec_conv4(x, self.drop_rate)
        x = self.dec_conv4b(x, self.drop_rate)

        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, skips.pop()], dim=1)
        x = self.dec_conv3(x, self.drop_rate)
        x = self.dec_conv3b(x, self.drop_rate)

        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, skips.pop()], dim=1)
        x = self.dec_conv2(x, self.drop_rate)
        x = self.dec_conv2b(x, self.drop_rate)

        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, skips.pop()], dim=1)
        x = self.dec_conv1a(x, self.drop_rate)
        x = self.dec_conv1b(x, self.drop_rate)
        x = self.dec_conv1(x, self.drop_rate)
        return x

    def forward(self, x, mask):
        x, skips = self.encoder(x, mask)
        x = self.decoder(x, skips)
        return x

--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
####

--------------------------------------------------------------------------------
/network/pconv.py:
--------------------------------------------------------------------------------
from tensor_type import Tensor4d, Tensor3d, Tensor
import math
from typing import Tuple, Union
import torch
from torch import nn

TupleInt = Union[int, Tuple[int, int]]


class PConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: TupleInt = 1,
        stride: TupleInt = 1,
        padding: TupleInt = 0,
        dilation: TupleInt = 1,
        bias: bool = False,
        legacy_behaviour: bool = False,
    ):
        """Partial Convolution on 2D input.

        :param in_channels: see torch.nn.Conv2d
        :param out_channels: see torch.nn.Conv2d
        :param kernel_size: see torch.nn.Conv2d
        :param stride: see torch.nn.Conv2d
        :param padding: see torch.nn.Conv2d
        :param dilation: see torch.nn.Conv2d
        :param bias: see torch.nn.Conv2d
        :param legacy_behaviour: Tries to replicate Guilin's implementation's numerical error when handling
            the bias, but in doing so performs extraneous operations that could be avoided and still produce
            *almost* the same result, at a tolerance of 0.00000458 % on the cuDNN 11.4 backend. Can safely be
            False for real-life applications.
        """
        super().__init__()

        # Set this to True, and the output is guaranteed to be exactly the same as PConvGuilin and PConvRFR.
        # Set this to False, and the output will be very close, but with some numerical errors removed/added,
        # even though formally the maths are equivalent.
        self.legacy_behaviour = legacy_behaviour

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._to_int_tuple(kernel_size)
        self.stride = self._to_int_tuple(stride)
        self.padding = self._to_int_tuple(padding)
        self.dilation = self._to_int_tuple(dilation)
        self.use_bias = bias

        conv_kwargs = dict(
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=1,
            bias=False,
        )

        # Don't use a bias here; we handle the bias manually to speed up computation
        self.regular_conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, **conv_kwargs)

        # The upstream torch_pconv package uses a 1-in/1-out mask convolution plus scaling and broadcasting,
        # which saves memory and computation but assumes a single mask shared across channels.
        # To stay consistent with the tensorflow version, which carries a full
        # (batch, channel, height, width) mask, we use a full in_channels -> out_channels mask convolution.
        # self.mask_conv = nn.Conv2d(in_channels=1, out_channels=1, **conv_kwargs)
        self.mask_conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, **conv_kwargs)

        # Inits
        self.regular_conv.apply(
            lambda m: nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
        )

        # the mask convolution should be a constant operation
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        for param in self.mask_conv.parameters():
            param.requires_grad = False

        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(1, self.out_channels, 1, 1))
        else:
            self.register_parameter("bias", None)

        with torch.no_grad():
            # This is how nn._ConvNd initialises its weights
            nn.init.kaiming_uniform_(self.regular_conv.weight, a=math.sqrt(5))

            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                    self.regular_conv.weight
                )
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias.view(self.out_channels), -bound, bound)

    def forward(self, x: Tensor4d, mask: Tensor4d) -> Tuple[Tensor4d, Tensor4d]:
        """Performs the 2D partial convolution.

        About the mask:
        - its dtype should be torch.float32
        - its values should be EITHER 0.0 OR 1.0, not in between
        - it should have the same (batch, channel, height, width) shape as x (unlike the upstream
          torch_pconv package, which expects a 3d mask without a channel dimension)
        The returned mask is guaranteed to also match these criteria.

        This returns a tuple containing:
        - the result of the partial convolution on the input x.
        - the "updated mask", which is slightly "closed off". It is a "binary" mask of dtype float,
          containing values of either 0.0 or 1.0 (nothing in between).

        :param x: The input image batch, a 4d tensor of traditional batch, channel, height, width.
        :param mask: a 4d binary (0.0 OR 1.0) mask of dtype=torch.float32, same shape as x
        :return: a tuple (output, updated_mask)
        """
        Tensor4d.check(x)
        batch, channels, h, w = x.shape
        Tensor[batch, channels, h, w].check(mask)

        if mask.dtype != torch.float32:
            raise TypeError(
                "mask should have dtype=torch.float32 with values being either 0.0 or 1.0"
            )

        if x.dtype != torch.float32:
            raise TypeError("x should have dtype=torch.float32")

        output = self.regular_conv(x * mask)
        _, _, conv_h, conv_w = output.shape

        update_mask: Tensor[batch, channels, conv_h, conv_w]
        mask_ratio: Tensor[batch, channels, conv_h, conv_w]
        with torch.no_grad():
            mask_ratio, update_mask = self.compute_masks(mask)

        if self.use_bias:
            if self.legacy_behaviour:
                # Doing this is entirely pointless. However, the legacy Guilin implementation does it,
                # and skipping it changes the result by a relative numerical error of about 0.00000458 %.
                output += self.bias
                output -= self.bias

            output *= mask_ratio  # Multiply by the sum(1)/sum(mask) ratios
            output += self.bias  # Add the bias *after* applying mask_ratio, not before!
            # output *= update_mask  # Nullify pixels outside the valid mask
            # (commenting out the line above is the 2023-05-14 bug fix mentioned in the README update log)
        else:
            output *= mask_ratio

        return output, update_mask

    def compute_masks(self, mask: Tensor4d) -> Tuple[Tensor4d, Tensor4d]:
        """
        This computes two masks:
        - update_mask is a binary mask that has 1 if the pixel was used in the convolution, and 0 otherwise
        - mask_ratio has value sum(1)/sum(mask) if the pixel was used in the convolution, and 0 otherwise

        * sum(1) means the sum of a kernel full of ones of size equivalent to self.regular_conv's kernel.
          It is usually calculated as self.in_channels * self.kernel_size ** 2, assuming a square kernel.
        * sum(mask) means the sum of ones and zeros of the mask in a particular region.
          If the region is entirely valid, then sum(mask) = sum(1), but if the region is only partially
          within the mask, then 0 < sum(mask) < sum(1). sum(mask) is calculated in the vicinity of the
          pixel and is pixel dependent.
        * mask_ratio is a Tensor4d and is NOT binary. It has values between 0 and sum(1) (inclusive).
        * update_mask is a Tensor4d and is "binary" (either 0.0 or 1.0).

        :param mask: the input "binary" mask. It has to be dtype=float32, containing only values 0.0 or 1.0.
        :return: mask_ratio, update_mask
        """
        update_mask = self.mask_conv(mask)
        # The epsilon avoids division by zero where update_mask == 0; those entries become very large
        # here and are cancelled out below. Elsewhere this computes the sum(1)/sum(mask) ratio.
        # noinspection PyTypeChecker
        mask_ratio = self.kernel_size[0] * self.kernel_size[1] / (update_mask + 1e-8)
        # Once the ratios have been saved in mask_ratio, update_mask can be reduced to a binary mask
        update_mask = torch.clamp(update_mask, 0, 1)
        # Multiply the very large values by zero to cancel them out
        mask_ratio *= update_mask

        return mask_ratio, update_mask

    @staticmethod
    def _to_int_tuple(v: TupleInt) -> Tuple[int, int]:
        if not isinstance(v, tuple):
            return v, v
        else:
            return v

    def set_weight(self, w):
        with torch.no_grad():
            self.regular_conv.weight.copy_(w)

        return self

    def set_bias(self, b):
        with torch.no_grad():
            self.bias.copy_(b.view(1, self.out_channels, 1, 1))

        return self

    def get_weight(self):
        return self.regular_conv.weight

    def get_bias(self):
        return self.bias

--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import scipy.io as sio
import random


def add_gaussian_noise(img, model_path, sigma, bs=1):
    index = model_path.rfind("/")
    if sigma > 0:
        noise = np.random.normal(scale=sigma / 255., size=[bs, img.shape[1], img.shape[2], img.shape[3]]).astype(np.float32)
        sio.savemat(model_path[0:index] + '/noise.mat', {'noise': noise})
        noisy_img = (img + noise).astype(np.float32)
    else:
        noisy_img = img.astype(np.float32)
    # write as uint8: cv2.imwrite does not support int32 images
    cv2.imwrite(model_path[0:index] + '/noisy.png',
                np.squeeze(np.uint8(np.clip(noisy_img, 0, 1) * 255.)))
    return noisy_img


def load_np_image(path, is_scale=True):
    img = cv2.imread(path, -1)
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    img = np.expand_dims(img, axis=0)
    if is_scale:
        img = np.array(img).astype(np.float32) / 255.
    return img


def mask_pixel(img, model_path, rate):
    # zeroes out a random fraction `rate` of pixel locations; returns the masked image and the mask
    index = model_path.rfind("/")
    masked_img = img.copy()
    mask = np.ones_like(masked_img)
    perm_idx = [i for i in range(np.shape(img)[1] * np.shape(img)[2])]
    random.shuffle(perm_idx)
    for i in range(np.int32(np.shape(img)[1] * np.shape(img)[2] * rate)):
        x, y = np.divmod(perm_idx[i], np.shape(img)[2])
        masked_img[:, x, y, :] = 0
        mask[:, x, y, :] = 0
    cv2.imwrite(model_path[0:index] + '/masked_img.png', np.squeeze(np.uint8(np.clip(masked_img, 0, 1) * 255.)))
    cv2.imwrite(model_path[0:index] + '/mask.png', np.squeeze(np.uint8(np.clip(mask, 0, 1) * 255.)))
    return masked_img, mask

--------------------------------------------------------------------------------
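
A quick numeric sanity check of the sum(1)/sum(mask) arithmetic described in the `compute_masks` docstring of network/pconv.py. This is an illustrative sketch, not part of the repository's code path; it assumes the modified 4-d mask convention above:

```python
import torch
from network.pconv import PConv2d

# 1-channel 4x4 input whose left two columns are marked invalid
pconv = PConv2d(in_channels=1, out_channels=1, kernel_size=3)
mask = torch.ones(1, 1, 4, 4)
mask[:, :, :, :2] = 0.0

mask_ratio, update_mask = pconv.compute_masks(mask)

print(update_mask)  # binary: 1.0 wherever at least one valid pixel fell under the 3x3 kernel
print(mask_ratio)   # sum(1)/sum(mask): 9/3 = 3.0 for the left output column, 9/6 = 1.5 for the right
```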