├── util
│   ├── loss.py
│   ├── dataset.py
│   └── augment.py
├── demo.py
├── README.md
├── model
│   └── RIDNet.py
└── train.py

/util/loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class L1_Loss(nn.Module):
    def __init__(self):
        super(L1_Loss, self).__init__()

    def forward(self, x, y):
        loss = F.l1_loss(x, y, reduction='mean')
        return loss * 1000  # inputs are normalized to [0, 1], so scale the loss up


class Smooth_L1_Loss(nn.Module):
    def __init__(self):
        super(Smooth_L1_Loss, self).__init__()

    def forward(self, x, y):
        loss = F.smooth_l1_loss(x, y, reduction='mean')
        return loss * 1000  # same scaling as L1_Loss


class L1_L2_Loss(nn.Module):
    def __init__(self, ratio):
        super(L1_L2_Loss, self).__init__()
        self.ratio = ratio

    def forward(self, x, y):
        L1_loss = F.l1_loss(x, y, reduction='mean')
        L2_loss = F.mse_loss(x, y, reduction='mean')
        L1_L2 = self.ratio * L1_loss + (1 - self.ratio) * L2_loss

        return L1_L2
--------------------------------------------------------------------------------
/util/dataset.py:
--------------------------------------------------------------------------------
import glob

import cv2
import numpy as np
import torch


class Denoising_dataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, train_val, transform):
        super(Denoising_dataset, self).__init__()

        self.img_paths = glob.glob(img_dir + '/**/*.jpg', recursive=True)
        self.train_val = train_val
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]

        clean = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        noisy = self.gaussian_noise(clean)

        data = {'noisy': noisy, 'clean': clean}

        if self.transform:
            data = self.transform(data)

        return data

    def gaussian_noise(self, img, noise_level=[15, 25, 50]):
        # additive white Gaussian noise with a sigma drawn at random per image
        sigma = np.random.choice(noise_level)
        gaussian_noise = np.random.normal(0, sigma, (img.shape[0], img.shape[1]))

        noisy_img = img + gaussian_noise
        noisy_img = np.clip(noisy_img, 0, 255).astype(np.uint8)

        return noisy_img
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import copy

import cv2
import numpy as np
import torch

from model.RIDNet import RIDNet


def gaussian_noise(img, noise_level=[5, 10, 15, 20, 25, 30]):
    sigma = np.random.choice(noise_level)
    gaussian_noise = np.random.normal(0, sigma, (img.shape[0], img.shape[1]))

    noisy_img = img + gaussian_noise
    noisy_img = np.clip(noisy_img, 0, 255).astype(np.uint8)
    return noisy_img


def demo():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    # num_features must match the value the checkpoint was trained with
    model = RIDNet(in_channels=1, out_channels=1, num_features=32)

    checkpoint = torch.load('./weight/weight.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    img = cv2.imread('your image path', cv2.IMREAD_GRAYSCALE)
    origin_img = copy.deepcopy(img)

    img = gaussian_noise(img, noise_level=[15])
    img = np.expand_dims(img, -1)   # (H, W) -> (H, W, 1)
    img = img / 255.
    img = np.expand_dims(img, 0)    # (H, W, 1) -> (1, H, W, 1)
    img = torch.from_numpy(img).type(torch.float32)
    img = img.permute(0, 3, 1, 2)   # (N, H, W, C) -> (N, C, H, W)
    img = img.to(device)

    with torch.no_grad():
        pred = model(img)

    output = pred[0].cpu().numpy().transpose(1, 2, 0)
    output = output * 255
    output = np.clip(output, 0, 255).astype(np.uint8)

    cv2.imwrite('input.jpg', origin_img)
    cv2.imwrite('output.jpg', output)


if __name__ == '__main__':
    demo()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch RIDNet Implementation (unofficial code)

## [[Paper]](https://openaccess.thecvf.com/content_ICCV_2019/papers/Anwar_Real_Image_Denoising_With_Feature_Attention_ICCV_2019_paper.pdf)
## Real Image Denoising with Feature Attention (ICCV, 2019)


***Abstract***


*Deep convolutional neural networks perform better on images containing spatially invariant noise (synthetic noise); however, their performance is limited on real-noisy photographs and requires multiple stage network modeling. To advance the practicability of denoising algorithms, this paper proposes a novel single-stage blind real image denoising network (RIDNet) by employing a modular architecture. We use a residual on the residual structure to ease the flow of low-frequency information and apply feature attention to exploit the channel dependencies. Furthermore, the evaluation in terms of quantitative metrics and visual quality on three synthetic and four real noisy datasets against 19 state-of-the-art algorithms demonstrate the superiority of our RIDNet.*



![image](https://user-images.githubusercontent.com/33386742/152332696-244dd263-b210-45ab-b551-5f6e6d7d8df7.png)

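## Model

A quick sanity check for the model in `model/RIDNet.py` (a minimal sketch; `num_features=128` follows `train.py` and must match whatever checkpoint you load):

```
import torch
from model.RIDNet import RIDNet

model = RIDNet(in_channels=1, out_channels=1, num_features=128)
x = torch.randn(1, 1, 64, 64)  # (N, C, H, W) grayscale patch
out = model(x)                 # denoised estimate, same shape as the input
print(out.shape)               # torch.Size([1, 1, 64, 64])
```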

## Train
```
> python train.py --epochs 100 --batch_size 16
```
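## Demo

To denoise a single image with a trained checkpoint, run `demo.py`. It loads `./weight/weight.pth`, corrupts the input image with sigma = 15 Gaussian noise, and writes `input.jpg` / `output.jpg`:
```
> python demo.py
```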

## Result
### Ground Truth / Noised image / Denoised image
![image](https://user-images.githubusercontent.com/33386742/152548263-d466771c-ecc6-43c2-bea3-ec4eb59ad95f.png)
![image](https://user-images.githubusercontent.com/33386742/152549070-dd95cdc6-e77a-4bbb-a8f5-59b9220275e3.png)
![image](https://user-images.githubusercontent.com/33386742/152549416-c3f1bff2-61bb-4f46-b149-f166ee2bf550.png)
![image](https://user-images.githubusercontent.com/33386742/152549553-07bbe83a-d5b0-493a-ae07-227f2e0134dc.png)
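## Noise model

Training pairs are generated on the fly in `util/dataset.py` by adding white Gaussian noise to each clean image, with sigma drawn at random from {15, 25, 50} per sample (`demo.py` uses sigma = 15). A minimal sketch of the same corruption:

```
import numpy as np

def gaussian_noise(img, noise_level=[15, 25, 50]):
    sigma = np.random.choice(noise_level)          # pick a noise level per image
    noise = np.random.normal(0, sigma, img.shape)  # zero-mean AWGN
    return np.clip(img + noise, 0, 255).astype(np.uint8)
```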

## Reference
* [Official code](https://github.com/saeed-anwar/RIDNet)
--------------------------------------------------------------------------------
/model/RIDNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(ChannelAttention, self).__init__()

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels // reduction, 1, 1, 0)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // reduction, in_channels, 1, 1, 0)
        self.sigmoid2 = nn.Sigmoid()

    def forward(self, x):
        gap = self.gap(x)             # squeeze: global average pooling
        x_out = self.conv1(gap)
        x_out = self.relu1(x_out)
        x_out = self.conv2(x_out)
        x_out = self.sigmoid2(x_out)  # per-channel attention weights
        x_out = x_out * x             # rescale the input features
        return x_out


class EAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, reduction=16):
        super(EAM, self).__init__()

        # Merge and run block: two parallel paths with growing dilation
        self.path1_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
        self.path1_relu1 = nn.ReLU()
        self.path1_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=2, dilation=2)
        self.path1_relu2 = nn.ReLU()

        self.path2_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=3, dilation=3)
        self.path2_relu1 = nn.ReLU()
        self.path2_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=4, dilation=4)
        self.path2_relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels * 2, out_channels, kernel_size, stride=1, padding=1)
        self.relu3 = nn.ReLU()

        # Residual block
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU()

        # Enhanced residual block
        self.conv6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu7 = nn.ReLU()
        self.conv8 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu8 = nn.ReLU()

        # Channel attention
        self.ca = ChannelAttention(in_channels, out_channels, reduction=reduction)

    def forward(self, x):
        # Merge and run block
        x1 = self.path1_conv1(x)
        x1 = self.path1_relu1(x1)
        x1 = self.path1_conv2(x1)
        x1 = self.path1_relu2(x1)

        x2 = self.path2_conv1(x)
        x2 = self.path2_relu1(x2)
        x2 = self.path2_conv2(x2)
        x2 = self.path2_relu2(x2)

        x3 = torch.cat([x1, x2], dim=1)
        x3 = self.conv3(x3)
        x3 = self.relu3(x3)
        x3 = x3 + x  # local skip connection

        # Residual block
        x4 = self.conv4(x3)
        x4 = self.relu4(x4)
        x4 = self.conv5(x4)
        x5 = x4 + x3
        x5 = self.relu5(x5)

        # Enhanced residual block
        x6 = self.conv6(x5)
        x6 = self.relu6(x6)
        x7 = self.conv7(x6)
        x7 = self.relu7(x7)
        x8 = self.conv8(x7)
        x8 = x8 + x5
        x8 = self.relu8(x8)

        x_ca = self.ca(x8)

        return x_ca + x  # residual over the whole module


class RIDNet(nn.Module):
    def __init__(self, in_channels, out_channels, num_features):
        super(RIDNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, num_features, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=False)

        self.eam1 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam2 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam3 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam4 = EAM(in_channels=num_features, out_channels=num_features)

        self.last_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, stride=1, padding=1, dilation=1)

        self.init_weights()

    def forward(self, x):
        x1 = self.conv1(x)  # feature extraction module
        x1 = self.relu1(x1)

        x_eam = self.eam1(x1)
        x_eam = self.eam2(x_eam)
        x_eam = self.eam3(x_eam)
        x_eam = self.eam4(x_eam)

        x_lsc = x_eam + x1             # long skip connection
        x_out = self.last_conv(x_lsc)  # reconstruction module
        x_out = x_out + x              # global residual connection

        return x_out

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
--------------------------------------------------------------------------------
/util/augment.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torch


class ToTensor(object):
    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        # (512, 512) -> (512, 512, 1)
        noisy = np.expand_dims(noisy, -1)
        clean = np.expand_dims(clean, -1)

        noisy = torch.from_numpy(noisy.copy()).type(torch.float32)
        clean = torch.from_numpy(clean.copy()).type(torch.float32)

        # (H, W, C) -> (C, H, W)
        noisy = noisy.permute(2, 0, 1)
        clean = clean.permute(2, 0, 1)

        data = {'noisy': noisy, 'clean': clean}

        return data


class Normalize(object):
    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        noisy = noisy / 255.
        clean = clean / 255.

        data = {'noisy': noisy, 'clean': clean}

        return data


class Random_Brightness(object):
    def __init__(self, p, sigma1):
        self.p = p
        self.sigma1 = sigma1

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if self.p >= np.random.random():
            # sample a fresh factor per call instead of overwriting self.sigma1
            sigma = np.random.uniform(low=-self.sigma1, high=self.sigma1)  # e.g. -0.3 ~ 0.3
            noisy = cv2.add(noisy, np.mean(noisy) * sigma)

        data = {'noisy': noisy, 'clean': clean}

        return data


class Horizontal_Flip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if np.random.rand() <= self.p:
            noisy = noisy[:, ::-1]
            clean = clean[:, ::-1]

        data = {'noisy': noisy, 'clean': clean}

        return data


class Vertical_Flip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if np.random.rand() <= self.p:
            noisy = noisy[::-1, :]
            clean = clean[::-1, :]

        data = {'noisy': noisy, 'clean': clean}

        return data


class Rotation(object):
    def __init__(self, p=0.5, angle=(-30, 30)):
        self.p = p
        self.angle = angle

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if self.p >= np.random.random():
            h, w = clean.shape
            rotation_angle = np.random.randint(self.angle[0], self.angle[1])
            # OpenCV expects the center as (x, y) and dsize as (width, height)
            rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), rotation_angle, 1)

            noisy = cv2.warpAffine(noisy, rotation_matrix, (w, h))
            clean = cv2.warpAffine(clean, rotation_matrix, (w, h))

        data = {'noisy': noisy, 'clean': clean}

        return data


class Shift_X(object):
    def __init__(self, p, dx=30):
        self.p = p
        self.dx = dx

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if self.p >= np.random.random():
            h, w = clean.shape
            # sample a fresh offset per call instead of fixing it in __init__
            dx = np.random.randint(low=-self.dx, high=self.dx)
            shifted_noisy = np.zeros(noisy.shape).astype(np.uint8)
            shifted_clean = np.zeros(clean.shape).astype(np.uint8)

            if dx > 0:  # shift right
                shifted_noisy[:, dx:] = noisy[:, :w - dx]
                shifted_clean[:, dx:] = clean[:, :w - dx]
            else:  # shift left
                shifted_noisy[:, :w + dx] = noisy[:, -dx:]
                shifted_clean[:, :w + dx] = clean[:, -dx:]

            data = {'noisy': shifted_noisy, 'clean': shifted_clean}
        else:
            data = {'noisy': noisy, 'clean': clean}

        return data


class Shift_Y(object):
    def __init__(self, p, dy=30):
        self.p = p
        self.dy = dy

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if self.p >= np.random.random():
            h, w = clean.shape
            # sample a fresh offset per call instead of fixing it in __init__
            dy = np.random.randint(low=-self.dy, high=self.dy)
            shifted_noisy = np.zeros(noisy.shape).astype(np.uint8)
            shifted_clean = np.zeros(clean.shape).astype(np.uint8)

            if dy > 0:  # shift up
                shifted_noisy[:h - dy, :] = noisy[dy:, :]
                shifted_clean[:h - dy, :] = clean[dy:, :]
            else:  # shift down
                shifted_noisy[-dy:, :] = noisy[:(h + dy), :]
                shifted_clean[-dy:, :] = clean[:(h + dy), :]

            data = {'noisy': shifted_noisy, 'clean': shifted_clean}
        else:
            data = {'noisy': noisy, 'clean': clean}

        return data


class Random_Crop(object):
    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        h, w = clean.shape

        # +1 keeps randint valid when the image is exactly the patch size
        top = np.random.randint(0, h - self.patch_size[0] + 1)
        bottom = top + self.patch_size[0]
        left = np.random.randint(0, w - self.patch_size[1] + 1)
        right = left + self.patch_size[1]

        noisy_patch = noisy[top:bottom, left:right]
        clean_patch = clean[top:bottom, left:right]

        data = {'noisy': noisy_patch, 'clean': clean_patch}

        return data
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import random
from math import inf

import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from torchsummary import summary

from model.RIDNet import RIDNet
from util.dataset import *
from util.loss import *
from util.augment import *


def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore


def train():
    parser = argparse.ArgumentParser(description='RIDNet training arguments')
    parser.add_argument('--epochs',
                        type=int,
                        help='number of training epochs',
                        default=300,
                        dest='epochs')

    parser.add_argument('--batch_size',
                        type=int,
                        help='batch size',
                        default=8,
                        dest='batch_size')

    args = parser.parse_args()

    # hyperparameters
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    train_transform = transforms.Compose([
        Random_Brightness(p=0.5,
                          sigma1=0.3),
        Horizontal_Flip(p=0.5),
        Vertical_Flip(p=0.5),
        Shift_X(p=0.5,
                dx=30),
        Shift_Y(p=0.5,
                dy=30),
        Rotation(p=0.5,
                 angle=(-30, 30)),
        Random_Crop(patch_size=(64, 64)),  # for patch-wise training
        Normalize(),
        ToTensor()
    ])

    train_dataset = Denoising_dataset(img_dir='your dataset path',
                                      train_val='train',
                                      transform=train_transform)

    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=0)

    val_transform = transforms.Compose([
        Normalize(),
        ToTensor()
    ])

    val_dataset = Denoising_dataset(img_dir='your dataset path',
                                    train_val='val',
                                    transform=val_transform)

    val_loader = DataLoader(val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=0)

    model = RIDNet(in_channels=1, out_channels=1, num_features=128)
    model.to(device)
    summary(model, (1, 512, 512), batch_size=BATCH_SIZE)

    criterion = L1_Loss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5)

    # tensorboard
    writer = SummaryWriter('runs/')

    os.makedirs('weight', exist_ok=True)
    best_val_loss = inf

    for epoch in range(1, EPOCHS + 1):
        train_loss = 0.
        val_loss = 0.

        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)

        model.train()
        for i, data in loop:
            noisy = data['noisy'].to(device)
            clean = data['clean'].to(device)

            optimizer.zero_grad()
            pred = model(noisy)
            loss = criterion(pred, clean)  # pred, gt
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            loop.set_description(f'Epoch [{epoch}/{EPOCHS}]')

        current_lr = scheduler.optimizer.param_groups[0]['lr']
        writer.add_scalar('lr', current_lr, epoch)
        scheduler.step()

        model.eval()
        with torch.no_grad():
            loop = tqdm(enumerate(val_loader), total=len(val_loader), leave=False)

            for j, data in loop:
                noisy = data['noisy'].to(device)
                clean = data['clean'].to(device)

                pred = model(noisy)

                loss = criterion(pred, clean)
                val_loss += loss.item()
                loop.set_description('valid')

        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)

        print(f'Epoch: {epoch}\t train_loss: {train_loss}\t val_loss: {val_loss}')

        if best_val_loss > val_loss:
            print('=' * 100)
            print(f'val_loss improved from {best_val_loss:.8f} to {val_loss:.8f}\t saving current weights')
            print('=' * 100)
            best_val_loss = val_loss

            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': criterion},
                       f'weight/{str(criterion).split("()")[0]}_model_{epoch:05d}_valloss_{best_val_loss:.4f}.pth')

    writer.close()


if __name__ == '__main__':
    seed_everything(42)
    train()
--------------------------------------------------------------------------------