├── util
│   ├── loss.py
│   ├── dataset.py
│   └── augment.py
├── demo.py
├── README.md
├── model
│   └── RIDNet.py
└── train.py
/util/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class L1_Loss(nn.Module):
    """Mean absolute error between prediction and target, multiplied by 1000."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``mean(|x - y|) * 1000`` as a scalar tensor."""
        mean_abs_err = F.l1_loss(x, y, reduction='mean')
        scaled = mean_abs_err * 1000
        return scaled
12 |
13 |
class Smooth_L1_Loss(nn.Module):
    """Mean smooth-L1 loss between prediction and target, multiplied by 1000."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``smooth_l1_loss(x, y)`` (mean-reduced) times 1000."""
        mean_smooth_l1 = F.smooth_l1_loss(x, y, reduction='mean')
        scaled = mean_smooth_l1 * 1000
        return scaled
21 |
22 |
class L1_L2_Loss(nn.Module):
    """Convex-style blend of mean L1 and mean L2 (MSE) losses.

    The result is ``ratio * L1 + (1 - ratio) * L2``; unlike the other losses
    in this module it is not rescaled.
    """

    def __init__(self, ratio):
        super().__init__()
        # Weight given to the L1 term; the L2 term gets the remainder.
        self.ratio = ratio

    def forward(self, x, y):
        """Return the weighted sum of the two mean-reduced losses."""
        mae = F.l1_loss(x, y, reduction='mean')
        mse = F.mse_loss(x, y, reduction='mean')
        return self.ratio * mae + (1 - self.ratio) * mse
34 |
--------------------------------------------------------------------------------
/util/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import glob
4 | import cv2
5 | import numpy as np
6 | import copy
7 |
class Denoising_dataset(torch.utils.data.Dataset):
    """Dataset that builds (noisy, clean) grayscale pairs on the fly.

    Every ``*.jpg`` found recursively below ``img_dir`` is one sample. The
    clean image is read from disk and a noisy copy is synthesised with
    additive Gaussian noise each time the sample is requested.

    Args:
        img_dir: Root directory that is searched recursively for ``.jpg`` files.
        train_val: Split tag stored on the instance; it is not used to select
            files by this class.
        transform: Optional callable applied to the
            ``{'noisy': ..., 'clean': ...}`` dict before it is returned.
    """

    def __init__(self, img_dir, train_val, transform):
        super(Denoising_dataset, self).__init__()

        # Sorted so the index -> file mapping does not depend on the order in
        # which the filesystem happens to list the files.
        self.img_dir = sorted(glob.glob(img_dir + '/**/*.jpg', recursive=True))
        self.train_val = train_val
        self.transform = transform

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img_dir)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its noisy/clean pair.

        Raises:
            IOError: If the file cannot be decoded as an image.
        """
        img_path = self.img_dir[idx]

        clean = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if clean is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f'could not read image: {img_path}')

        noisy = self.gaussian_noise(clean)

        data = {'noisy': noisy, 'clean': clean}

        if self.transform:
            data = self.transform(data)

        return data

    def gaussian_noise(self, img, noise_level=(15, 25, 50)):
        """Return ``img`` with zero-mean Gaussian noise added, as uint8.

        Args:
            img: Image array; the noise is drawn with the same shape.
            noise_level: Candidate standard deviations; one is picked at
                random for each call.

        Returns:
            The noisy image clipped to [0, 255] and cast to ``np.uint8``.
        """
        sigma = np.random.choice(noise_level)
        noise = np.random.normal(0, sigma, img.shape)

        noisy_img = np.clip(img + noise, 0, 255).astype(np.uint8)

        return noisy_img
43 |
44 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2
4 | import numpy as np
5 |
6 | from model.RIDNet import RIDNet
7 | from util.dataset import *
8 |
9 |
def gaussian_noise(img, noise_level=(5, 10, 15, 20, 25, 30)):
    """Return ``img`` with zero-mean Gaussian noise added, as uint8.

    Args:
        img: Image array; the noise is drawn with the same shape.
        noise_level: Candidate standard deviations; one is picked at random
            for each call.

    Returns:
        The noisy image clipped to [0, 255] and cast to ``np.uint8``.
    """
    sigma = np.random.choice(noise_level)
    noise = np.random.normal(0, sigma, img.shape)

    noisy_img = np.clip(img + noise, 0, 255).astype(np.uint8)
    return noisy_img
17 |
18 |
def demo(img_path=None, weight_path='./weight/weight.pth', num_features=32):
    """Denoise one grayscale image with a trained RIDNet and save the result.

    The image is read, corrupted with Gaussian noise (sigma 15), pushed
    through the network, and two files are written to the working directory:
    ``input.jpg`` (the image as read from disk) and ``output.jpg`` (the
    network output).

    Args:
        img_path: Path of the image to denoise.
        weight_path: Checkpoint file holding a ``'model_state_dict'`` entry.
        num_features: Feature width used to build the network; it has to match
            the width the checkpoint was trained with.
            NOTE(review): train.py builds the model with 128 features while
            this demo defaults to 32 - confirm which one the shipped weights use.

    Raises:
        ValueError: If ``img_path`` is not given.
        IOError: If the image cannot be read.
    """
    if img_path is None:
        raise ValueError('demo() needs the path of the image to denoise (img_path)')

    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    model = RIDNet(in_channels=1, out_channels=1, num_feautres=num_features)

    # map_location lets a checkpoint saved on GPU be opened on a CPU-only host.
    # NOTE(review): the checkpoint written by train.py also pickles the loss
    # module; PyTorch versions that default to weights_only=True may refuse to
    # load it - confirm against the installed version.
    checkpoint = torch.load(weight_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError(f'could not read image: {img_path}')
    origin_img = img.copy()

    # (H, W) uint8 -> (1, 1, H, W) float32 scaled by 1/255
    img = gaussian_noise(img, noise_level=[15])
    img = np.expand_dims(img, -1)
    img = img / 255.
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).type(torch.float32)
    img = img.permute(0, 3, 1, 2)
    img = img.to(device)

    # Without no_grad the output requires grad and .numpy() would raise.
    with torch.no_grad():
        pred = model(img)

    output = pred[0].cpu().numpy().transpose(1, 2, 0)
    output = output * 255
    output = np.clip(output, 0, 255).astype(np.uint8)

    cv2.imwrite('input.jpg', origin_img)
    cv2.imwrite('output.jpg', output)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='RIDNet denoising demo')
    parser.add_argument('img_path', help='path of the image to denoise')
    parser.add_argument('--weight', default='./weight/weight.pth',
                        help='checkpoint file to load')
    parser.add_argument('--num_features', type=int, default=32,
                        help='feature width the checkpoint was trained with')
    cli_args = parser.parse_args()

    demo(img_path=cli_args.img_path,
         weight_path=cli_args.weight,
         num_features=cli_args.num_features)
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch RIDNet Implementation (unofficial code)
2 |
3 | ## [[Paper]](https://openaccess.thecvf.com/content_ICCV_2019/papers/Anwar_Real_Image_Denoising_With_Feature_Attention_ICCV_2019_paper.pdf)
4 | ## Real Image Denoising with Feature Attention (ICCV, 2019)
5 |
6 |
7 | ***Abstract***
8 |
9 |
10 | *Deep convolutional neural networks perform better
11 | on images containing spatially invariant noise (synthetic
12 | noise); however, their performance is limited on real-noisy
13 | photographs and requires multiple stage network modeling. To advance the practicability of denoising algorithms,
14 | this paper proposes a novel single-stage blind real image
15 | denoising network (RIDNet) by employing a modular architecture. We use a residual on the residual structure to
16 | ease the flow of low-frequency information and apply feature attention to exploit the channel dependencies. Furthermore, the evaluation in terms of quantitative metrics and visual quality on three synthetic and four real noisy datasets
17 | against 19 state-of-the-art algorithms demonstrate the superiority of our RIDNet.*
18 |
19 |
20 |
21 | 
22 |
23 |
24 |
25 |
26 | ## Train
27 | ```
28 | > python train.py --epochs 100 --batch_size 16
29 | ```
30 |
31 |
32 | ## Result
33 | ### Ground Truth / Noisy image / Denoised image
34 | 
35 | 
36 | 
37 | 
38 |
39 |
40 | ## Reference
41 | * [Official code](https://github.com/saeed-anwar/RIDNet)
42 |
--------------------------------------------------------------------------------
/model/RIDNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class ChannelAttention(nn.Module):
    """Per-channel gating computed from globally pooled features.

    The input is average-pooled to 1x1, passed through a two-layer 1x1-conv
    bottleneck (ReLU, then sigmoid) and the resulting per-channel weights are
    multiplied onto the input.

    Args:
        in_channels: Channels of the input (and of the returned tensor).
        out_channels: Channel count the bottleneck width is derived from.
        reduction: Divisor applied to ``out_channels`` to get the bottleneck
            width. The width is never allowed to drop below one channel.
    """

    def __init__(self, in_channels, out_channels, reduction=16):
        super(ChannelAttention, self).__init__()

        # out_channels // reduction is 0 whenever out_channels < reduction,
        # which would leave the gate with no hidden units at all.
        mid_channels = max(1, out_channels // reduction)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, 1, 0)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, in_channels, 1, 1, 0)
        self.sigmoid2 = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned gate."""
        gap = self.gap(x)
        x_out = self.conv1(gap)
        x_out = self.relu1(x_out)
        x_out = self.conv2(x_out)
        x_out = self.sigmoid2(x_out)
        # (N, C, 1, 1) gate broadcast over the spatial dimensions
        x_out = x_out * x
        return x_out
23 |
24 |
class EAM(nn.Module):
    """Feature block: merge-and-run, residual, enhanced residual, attention.

    The block keeps both the channel count and the spatial size of its input
    and adds the input back onto its result.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output; has to equal ``in_channels``
            because the block adds its input to intermediate results.
        kernel_size: Kernel size of the merge-and-run convolutions. Paddings
            are derived from it so odd sizes keep the spatial size.
        stride, padding, dilation, reduciton: Accepted for interface
            compatibility but not used; the channel attention always runs
            with ``reduction=16``.

    Raises:
        ValueError: If ``in_channels`` differs from ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, reduciton=4):
        super(EAM, self).__init__()

        if in_channels != out_channels:
            # The skip additions in forward() cannot work otherwise.
            raise ValueError(
                f'EAM needs in_channels == out_channels, got {in_channels} and {out_channels}')

        def same_pad(rate):
            # Padding that keeps the spatial size for an odd kernel dilated by `rate`.
            if isinstance(kernel_size, int):
                return rate * (kernel_size - 1) // 2
            return tuple(rate * (k - 1) // 2 for k in kernel_size)

        # Merge and run block
        self.path1_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=same_pad(1))
        self.path1_relu1 = nn.ReLU()
        self.path1_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=same_pad(2), dilation=2)
        self.path1_relu2 = nn.ReLU()

        self.path2_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=same_pad(3), dilation=3)
        self.path2_relu1 = nn.ReLU()
        self.path2_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=same_pad(4), dilation=4)
        self.path2_relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels*2, out_channels, kernel_size, stride=1, padding=same_pad(1))
        self.relu3 = nn.ReLU()

        # Residual block
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU()

        # Enhance Residual block
        self.conv6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu7 = nn.ReLU()
        self.conv8 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu8 = nn.ReLU()

        # Channel Attention
        self.ca = ChannelAttention(in_channels, out_channels, reduction=16)

    def forward(self, x):
        """Run the four stages and add the block input to the result."""
        # Merge and run block: two dilated branches, concatenated and fused
        x1 = self.path1_conv1(x)
        x1 = self.path1_relu1(x1)
        x1 = self.path1_conv2(x1)
        x1 = self.path1_relu2(x1)

        x2 = self.path2_conv1(x)
        x2 = self.path2_relu1(x2)
        x2 = self.path2_conv2(x2)
        x2 = self.path2_relu2(x2)

        x3 = torch.cat([x1, x2], dim=1)
        x3 = self.conv3(x3)
        x3 = self.relu3(x3)
        x3 = x3 + x

        # Residual block
        x4 = self.conv4(x3)
        x4 = self.relu4(x4)
        x4 = self.conv5(x4)
        x5 = x4 + x3
        x5 = self.relu5(x5)

        # Enhance Residual block
        x6 = self.conv6(x5)
        x6 = self.relu6(x6)
        x7 = self.conv7(x6)
        x7 = self.relu7(x7)
        x8 = self.conv8(x7)
        x8 = x8 + x5
        x8 = self.relu8(x8)

        x_ca = self.ca(x8)

        return x_ca + x
96 |
97 |
98 |
99 |
100 |
class RIDNet(nn.Module):
    """Denoising network: conv head, four chained EAM blocks, conv tail.

    Two long skip connections are used: the head features are added to the
    output of the last EAM, and the network input is added to the tail output.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output image; has to match
            ``in_channels`` for the final skip addition.
        num_feautres: Feature width used by the head and every EAM block.
    """

    def __init__(self, in_channels, out_channels, num_feautres):
        super().__init__()

        # feature extraction module
        self.conv1 = nn.Conv2d(in_channels, num_feautres, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=False)

        # eam1 .. eam4, registered in order
        for index in range(1, 5):
            setattr(self, f'eam{index}', EAM(in_channels=num_feautres, out_channels=num_feautres))

        # reconstruction module
        self.last_conv = nn.Conv2d(num_feautres, out_channels, kernel_size=3, stride=1, padding=1, dilation=1)

        self.init_weights()

    def forward(self, x):
        """Return the denoised estimate for ``x``."""
        shallow = self.relu1(self.conv1(x))

        deep = shallow
        for block in (self.eam1, self.eam2, self.eam3, self.eam4):
            deep = block(deep)

        # Long skip connection around the EAM chain, then around the whole net
        reconstructed = self.last_conv(deep + shallow)
        return reconstructed + x

    def init_weights(self):
        """Xavier-initialise conv weights; reset any batch-norm to identity."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
139 |
140 |
141 |
142 |
--------------------------------------------------------------------------------
/util/augment.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cv2
3 | import numpy as np
4 | import torch
5 |
6 |
class ToTensor(object):
    """Convert the two (H, W) numpy images of a sample to (1, H, W) float tensors."""

    @staticmethod
    def _as_chw_tensor(image):
        # (H, W) -> (H, W, 1); copy() gives torch an array it can wrap even
        # when `image` is a view produced by an earlier transform.
        with_channel = np.expand_dims(image, -1)
        tensor = torch.from_numpy(with_channel.copy()).type(torch.float32)
        # (H, W, C) -> (C, H, W)
        return tensor.permute(2, 0, 1)

    def __call__(self, data):
        noisy_img, clean_img = data['noisy'], data['clean']
        return {
            'noisy': self._as_chw_tensor(noisy_img),
            'clean': self._as_chw_tensor(clean_img),
        }
25 |
class Normalize(object):
    """Divide both images of a sample by 255."""

    def __call__(self, data):
        noisy_img, clean_img = data['noisy'], data['clean']
        divisor = 255.
        return {'noisy': noisy_img / divisor, 'clean': clean_img / divisor}
36 |
37 |
class Random_Brightness(object):
    """Randomly brighten or darken the noisy image of a sample.

    With probability ``p`` an offset of ``mean(noisy) * factor`` is added to
    the noisy image, where ``factor`` is drawn uniformly from
    ``[-sigma1, sigma1)`` on every call.

    NOTE(review): only the noisy image is changed, the clean target is left
    as is - confirm this is intended for paired denoising data.

    Args:
        p: Probability of applying the change.
        sigma1: Largest relative brightness change, e.g. 0.3.
    """

    def __init__(self, p, sigma1):
        self.p = p
        self.sigma1 = sigma1

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if self.p >= np.random.random():
            # Keep the draw in a local. Storing it back on self.sigma1 made
            # the next call sample from the range of the previous draw, so the
            # augmentation strength shrank towards zero during training.
            factor = np.random.uniform(low=-(self.sigma1), high=(self.sigma1))  # e.g. -0.3 ~ 0.3
            noisy = cv2.add(noisy, np.mean(noisy) * factor)

        data = {'noisy': noisy, 'clean': clean}

        return data
53 |
54 |
class Horizontal_Flip(object):
    """Mirror both images of a sample left-to-right with probability ``p``."""

    # Index that keeps every row and reverses the column order.
    _MIRROR_COLUMNS = (slice(None), slice(None, None, -1))

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data):
        noisy_img, clean_img = data['noisy'], data['clean']

        flip = np.random.rand() <= self.p
        if not flip:
            return {'noisy': noisy_img, 'clean': clean_img}

        return {
            'noisy': noisy_img[self._MIRROR_COLUMNS],
            'clean': clean_img[self._MIRROR_COLUMNS],
        }
69 |
70 |
class Vertical_Flip(object):
    """Mirror both images of a sample top-to-bottom with probability ``p``."""

    # Index that reverses the row order and keeps every column.
    _MIRROR_ROWS = (slice(None, None, -1), slice(None))

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data):
        noisy_img, clean_img = data['noisy'], data['clean']

        flip = np.random.rand() <= self.p
        if not flip:
            return {'noisy': noisy_img, 'clean': clean_img}

        return {
            'noisy': noisy_img[self._MIRROR_ROWS],
            'clean': clean_img[self._MIRROR_ROWS],
        }
85 |
86 |
class Rotation(object):
    """Rotate both images of a sample about their centre with probability ``p``.

    Args:
        p: Probability of applying the rotation.
        angle: ``(low, high)`` in degrees; an integer angle is drawn from
            ``[low, high)`` on every call, so ``high`` itself is never used.
    """

    def __init__(self, p=0.5, angle=(-30, 30)):
        self.p = p
        self.angle = angle

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if self.p >= np.random.random():
            h, w = clean.shape
            rotation_angle = np.random.randint(self.angle[0], self.angle[1])

            # OpenCV takes points and sizes as (x, y) = (width, height).
            # Passing (h, w) only happened to work for square images; for any
            # other shape it moved the centre and swapped the output size.
            rotation_matrix = cv2.getRotationMatrix2D((w/2, h/2), rotation_angle, 1)

            noisy = cv2.warpAffine(noisy, rotation_matrix, (w, h))
            clean = cv2.warpAffine(clean, rotation_matrix, (w, h))

        data = {'noisy': noisy, 'clean': clean}

        return data
106 |
107 |
class Shift_X(object):
    """Shift both images of a sample horizontally with probability ``p``.

    A new integer shift in ``[-dx, dx]`` is drawn for every sample that is
    shifted (positive = right, negative = left). Pixels that move in from
    outside the image are zero.

    Args:
        p: Probability of applying the shift.
        dx: Largest shift in pixels.

    Attributes:
        max_dx: Largest shift magnitude that can be drawn.
        dx: Shift applied by the most recent shifted call (0 before any).
    """

    def __init__(self, p, dx=30):
        self.p = p
        self.max_dx = abs(int(dx))
        self.dx = 0

    @staticmethod
    def _shift_columns(img, shift):
        # Zero-filled copy of `img` moved `shift` columns to the right.
        width = img.shape[1]
        shifted = np.zeros_like(img)
        if shift > 0:  # shift right
            shifted[:, shift:] = img[:, :width - shift]
        elif shift < 0:  # shift left
            shifted[:, :width + shift] = img[:, -shift:]
        else:
            shifted[...] = img
        return shifted

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if not self.p >= np.random.random():
            return {'noisy': noisy, 'clean': clean}

        h, w = clean.shape

        # Drawn here rather than once in __init__, otherwise every sample of
        # a run is moved by the same amount.
        shift = int(np.random.randint(-self.max_dx, self.max_dx + 1))
        # A shift wider than the image simply empties it.
        shift = max(-w, min(w, shift))
        self.dx = shift

        return {'noisy': self._shift_columns(noisy, shift),
                'clean': self._shift_columns(clean, shift)}
133 |
134 |
class Shift_Y(object):
    """Shift both images of a sample vertically with probability ``p``.

    A new integer shift in ``[-dy, dy]`` is drawn for every sample that is
    shifted (positive = up, negative = down). Pixels that move in from
    outside the image are zero.

    Args:
        p: Probability of applying the shift.
        dy: Largest shift in pixels.

    Attributes:
        max_dy: Largest shift magnitude that can be drawn.
        dy: Shift applied by the most recent shifted call (0 before any).
    """

    def __init__(self, p, dy=30):
        self.p = p
        self.max_dy = abs(int(dy))
        self.dy = 0

    @staticmethod
    def _shift_rows(img, shift):
        # Zero-filled copy of `img` moved `shift` rows upwards.
        height = img.shape[0]
        shifted = np.zeros_like(img)
        if shift > 0:  # shift up
            shifted[:height - shift, :] = img[shift:, :]
        elif shift < 0:  # shift down
            shifted[-shift:, :] = img[:height + shift, :]
        else:
            shifted[...] = img
        return shifted

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        if not self.p >= np.random.random():
            return {'noisy': noisy, 'clean': clean}

        h, w = clean.shape

        # Drawn here rather than once in __init__, otherwise every sample of
        # a run is moved by the same amount.
        shift = int(np.random.randint(-self.max_dy, self.max_dy + 1))
        # A shift taller than the image simply empties it.
        shift = max(-h, min(h, shift))
        self.dy = shift

        return {'noisy': self._shift_rows(noisy, shift),
                'clean': self._shift_rows(clean, shift)}
160 |
161 |
class Random_Crop(object):
    """Cut the same randomly placed patch out of both images of a sample.

    Args:
        patch_size: ``(height, width)`` of the patch.

    Raises:
        ValueError: When called on an image smaller than the patch.
    """

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, data):
        noisy, clean = data['noisy'], data['clean']

        h, w = clean.shape
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        if patch_h > h or patch_w > w:
            raise ValueError(
                f'patch size {(patch_h, patch_w)} does not fit into image of size {(h, w)}')

        # randint excludes its upper bound; the +1 makes the last valid
        # offset reachable and lets an image of exactly patch size through.
        top = np.random.randint(0, h - patch_h + 1)
        bottom = top + patch_h
        left = np.random.randint(0, w - patch_w + 1)
        right = left + patch_w

        noisy_patch = noisy[top:bottom, left:right]
        clean_patch = clean[top:bottom, left:right]

        data = {'noisy': noisy_patch, 'clean': clean_patch}

        return data
182 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from cmath import inf
2 | import os
3 | import random
4 | import argparse
5 | from random import shuffle
6 | from re import T
7 | from tqdm import tqdm
8 |
9 | import torch
10 | import torchvision
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from torchvision import transforms
14 | from torch.utils.data import DataLoader
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | from torchsummary import summary
18 |
19 | from model.RIDNet import RIDNet
20 | from util.dataset import *
21 | from util.loss import *
22 | from util.augment import *
23 |
def seed_everything(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` with the same value.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if use multi-GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Reproducible cuDNN kernels instead of auto-tuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
33 |
def train():
    """Train RIDNet on synthetic Gaussian noise and checkpoint the best epochs.

    Command-line options:
        --epochs      Number of epochs (default 300).
        --batch_size  Batch size for both loaders (default 8).
        --train_dir   Directory searched for training ``.jpg`` images.
        --val_dir     Directory searched for validation ``.jpg`` images.

    Learning rate and losses are logged to TensorBoard under ``runs/``. A
    checkpoint is written to ``weight/`` every time the validation loss
    improves.

    Raises:
        RuntimeError: If either image directory contains no ``.jpg`` file.
    """
    parser = argparse.ArgumentParser(description='argparse argument')
    parser.add_argument('--epochs',
                        type=int,
                        help='epoch',
                        default=300,
                        dest='epochs')

    parser.add_argument('--batch_size',
                        type=int,
                        help='batch_size',
                        default=8,
                        dest='batch_size')

    parser.add_argument('--train_dir',
                        type=str,
                        help='directory with training images',
                        default='your dataset path',
                        dest='train_dir')

    parser.add_argument('--val_dir',
                        type=str,
                        help='directory with validation images',
                        default='your dataset path',
                        dest='val_dir')

    args = parser.parse_args()

    # hyper parameters
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    # Fall back to the CPU so the script does not die on a missing `device`.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    train_transform = transforms.Compose([
        Random_Brightness(p=0.5,
                          sigma1=0.3),
        Horizontal_Flip(p=0.5),
        Vertical_Flip(p=0.5),
        Shift_X(p=0.5,
                dx=30),
        Shift_Y(p=0.5,
                dy=30),
        Rotation(p=0.5,
                 angle=(-30, 30)),
        Random_Crop(patch_size=(64, 64)),  # for patch-wise training
        Normalize(),
        ToTensor()
    ])

    train_dataset = Denoising_dataset(img_dir=args.train_dir,
                                      train_val='train',
                                      transform=train_transform)

    # NOTE(review): validation images are not cropped, so batching them needs
    # every image of a batch to have the same size - confirm for your data.
    val_transform = transforms.Compose([
        Normalize(),
        ToTensor()
    ])

    val_dataset = Denoising_dataset(img_dir=args.val_dir,
                                    train_val='val',
                                    transform=val_transform)

    # Fail early with a readable message instead of a sampler error or a
    # division by zero when averaging the loss.
    if len(train_dataset) == 0:
        raise RuntimeError(f'no .jpg images found under train_dir: {args.train_dir!r}')
    if len(val_dataset) == 0:
        raise RuntimeError(f'no .jpg images found under val_dir: {args.val_dir!r}')

    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=0)

    val_loader = DataLoader(val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=0)

    model = RIDNet(in_channels=1, out_channels=1, num_feautres=128)
    model.to(device)
    summary(model, (1, 512, 512), batch_size=BATCH_SIZE, device=device.type)

    criterion = L1_Loss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    # `verbose` is deprecated in PyTorch schedulers; the learning rate is
    # logged to TensorBoard every epoch below instead.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5)

    # tensorboard
    writer = SummaryWriter('runs/')

    # torch.save does not create missing directories.
    os.makedirs('weight', exist_ok=True)

    best_val_loss = float('inf')

    for epoch in range(1, EPOCHS+1):
        train_loss = 0.
        val_loss = 0.

        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)

        model.train()
        for i, data in loop:
            noisy = data['noisy'].to(device)
            clean = data['clean'].to(device)

            optimizer.zero_grad()
            pred = model(noisy)
            loss = criterion(pred, clean)  # pred, gt
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            loop.set_description(f'Epoch [{epoch}/{EPOCHS}]')

        # Log the rate that was used for this epoch, then advance the schedule.
        current_lr = scheduler.optimizer.param_groups[0]['lr']
        writer.add_scalar('lr', current_lr, epoch)
        scheduler.step()

        model.eval()
        with torch.no_grad():
            loop = tqdm(enumerate(val_loader), total=len(val_loader), leave=False)

            for j, data in loop:
                noisy = data['noisy'].to(device)
                clean = data['clean'].to(device)

                pred = model(noisy)

                loss = criterion(pred, clean)
                val_loss += loss.item()
                loop.set_description(f'valid')

        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)

        print(f'Epoch: {epoch}\t train_loss: {train_loss}\t val_loss: {val_loss}')

        if best_val_loss > val_loss:
            print('=' * 100)
            print(f'val_loss is improved from {best_val_loss:.8f} to {val_loss:.8f}\t saved current weight')
            print('=' * 100)
            best_val_loss = val_loss

            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': criterion},
                       f'weight/{str(criterion).split("()")[0]}_model_{epoch:05d}_valloss_{best_val_loss:.4f}.pth')

    writer.close()


if __name__ == '__main__':
    seed_everything(42)
    train()
178 |
--------------------------------------------------------------------------------