├── .gitignore ├── 00091.png ├── README.md ├── data.py ├── losses.py ├── model.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *png 2 | *jpg 3 | *pth 4 | *pt 5 | *pyc 6 | -------------------------------------------------------------------------------- /00091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AakashKT/pytorch-recurrent-ae-siggraph17/5fd5aa445fc6782aab62365fb65c387391c90e39/00091.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interactive Reconstruction of Monte Carlo Image Sequences using a Recurrent Denoising Autoencoder , PyTorch implementation 2 | Link to original paper (SIGGRAPH '17) : https://research.nvidia.com/publication/interactive-reconstruction-monte-carlo-image-sequences-using-recurrent-denoising 3 | 4 | This is the unofficial PyTorch implementation of the above paper. 5 | 6 | # Dataset Preparation 7 | The input to this network is : 8 | * Albedo demodulated 1spp RGB image 9 | * Render a scene using any standard path tracer, and divide it by the albedo image of the same scene. Albedo image is usually obtained by adding the 'DiffuseColor' and 'GlossyColor' passes from the renderer. 10 | * Depth (Z-buffer) 11 | * Normal Map 12 | * Material roughness Map 13 | * Add the 'GlossyDirect' and 'GlossyIndirect' components from the renderer. Take the inverse of this image. 14 | 15 | The output of this network is : 16 | * Albedo demodulated 250spp RGB image 17 | 18 | Construct the input as one image, as follows :
19 | 20 |
21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |
| Column 1 | Column 2 |
| --- | --- |
| 1 spp input | 250 spp output |
| Albedo Image | Normal Map |
| Depth Map | Roughness Map |
39 |
40 |

Directory structure

41 | Group consecutive frames into sequences of 7 frames each, and put each sequence in its own directory (seq_0, seq_1, ...). Do this for all frames. 42 | Split the resulting directories into test and train. 43 | 44 | Note : The data directory must contain 'train' and 'test' directories, and each of these must contain the directories where the sequences are stored. 45 | 46 | * [DATA_DIRECTORY]/train 47 | * [DATA_DIRECTORY]/train/seq_0/ 48 | * [DATA_DIRECTORY]/train/seq_1/ 49 | * .... 50 | * [DATA_DIRECTORY]/test 51 | * [DATA_DIRECTORY]/test/seq_0/ 52 | * [DATA_DIRECTORY]/test/seq_1/ 53 | * .... 54 | 55 | # Training the network 56 | To train the network, run the following command : 57 | 58 | ``` 59 | python train.py --data_dir [PATH_TO_DATA_DIRECTORY] --name [EXP_NAME] --save_dir [PATH_TO_SAVE_CHECKPOINTS] --epochs 500 60 | ``` 61 | 62 | # Running a trained network 63 | To test the network, run the following command : 64 | 65 | ``` 66 | python test.py --data_dir [PATH_TO_DATA_DIRECTORY] --output_dir [PATH_TO_SAVE_RESULTS] --checkpoint [PATH_TO_CHECKPOINT].pt 67 | ``` 68 | 69 | # TODO / Possible faults 70 | * Can we get the demodulated albedo directly from the renderer? 71 | * The network denoises the demodulated albedo perfectly. But while reconstructing the textured image, by multiplication with albedo, it loses a lot of detail. FIX THIS. 
# ------------------------------------------------------------------------------
# /data.py
# ------------------------------------------------------------------------------
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image

import torchvision.transforms as transforms
import numpy as np
import torch


class RAEData(Dataset):
    """Dataset of rendered frame sequences for the recurrent denoiser.

    ``input_dir`` holds one sub-directory per sequence (``seq_0``, ``seq_1``,
    ...); the frames inside each are read in sorted filename order. Every
    frame image is a 2-column x 3-row mosaic of equally sized tiles::

        noisy shading | reference shading
        albedo        | normal
        depth         | roughness

    and each tile is resized to ``size`` = ``(width, height)``.
    """

    # Frames per sequence. The per-sequence arrays have exactly this many
    # slots, so a sequence directory must not contain more frames than this.
    SEQ_LEN = 7

    def __init__(self, input_dir, size):
        super(RAEData, self).__init__()

        self.input_dir = input_dir
        self.images = sorted(os.listdir(self.input_dir))  # sequence directories

        self.width = size[0]
        self.height = size[1]

    def __getitem__(self, index):
        """Load the sequence at ``index``.

        Returns a dict of float32 tensors on ``cuda:0``, each indexed
        (frame, channel, H, W):

        * ``'A'``: network input, 8 channels -- noisy shading (3),
          normal (3), depth (1), roughness (1);
        * ``'B'``: reference shading, 3 channels;
        * ``'ALBEDO'``: albedo, 3 channels.

        Pixel values are divided by 255. Slots for frames that a short
        sequence does not provide stay zero.

        Raises:
            IOError: if a frame file cannot be decoded by OpenCV.
        """
        h, w = self.height, self.width
        n_frames = self.SEQ_LEN

        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        A = np.zeros((n_frames, h, w, 8), dtype=np.float64)
        B = np.zeros((n_frames, h, w, 3), dtype=np.float64)
        ALBEDO = np.zeros((n_frames, h, w, 3), dtype=np.float64)

        seq_dir = '%s/%s' % (self.input_dir, self.images[index])
        for i, item in enumerate(sorted(os.listdir(seq_dir))):
            path = '%s/%s' % (seq_dir, item)
            img = cv2.imread(path)
            if img is None:  # cv2.imread reports failure by returning None
                raise IOError('Could not read image %s' % path)
            img = cv2.resize(img, (w * 2, h * 3))
            # Convert to floats in [0, 1] BEFORE any arithmetic: averaging the
            # depth/roughness channels on raw uint8 data wraps around at 256.
            img = img.astype(np.float64) / 255.0

            A[i, :, :, :3] = img[:h, :w, :]                                   # noisy shading
            A[i, :, :, 3:6] = img[h:h * 2, w:, :]                             # normal
            A[i, :, :, 6:7] = img[h * 2:, :w, :].mean(axis=2, keepdims=True)  # depth
            A[i, :, :, 7:8] = img[h * 2:, w:, :].mean(axis=2, keepdims=True)  # roughness

            B[i, :, :, :] = img[:h, w:, :]             # reference shading
            ALBEDO[i, :, :, :] = img[h:h * 2, :w, :]   # albedo

        def to_tensor(arr):
            # (frame, H, W, C) numpy array -> (frame, C, H, W) float32 CUDA tensor
            return torch.from_numpy(arr).permute((0, 3, 1, 2)).type(torch.float).to('cuda:0')

        return {
            'A': to_tensor(A),
            'B': to_tensor(B),
            'ALBEDO': to_tensor(ALBEDO),
        }

    def __len__(self):
        """Number of sequence directories."""
        return len(self.images)

    def np_normalize(self, img):
        """Min-max rescale ``img`` to [0, 1]; a constant image maps to zeros
        instead of dividing by zero."""
        lo, hi = img.min(), img.max()
        if hi == lo:
            return np.zeros_like(img, dtype=np.float64)
        return (img - lo) / (hi - lo)

    def save_image(self, img, img_name):
        """Write an image tensor to ``img_name`` with OpenCV.

        ``img`` is squeezed on dim 0 and then permuted as (C, H, W), so it is
        expected as (1, C, H, W) or (C, H, W), presumably with values in
        [0, 1]. Values are clamped to [0, 255] before the uint8 cast, which
        would otherwise wrap out-of-range values around.
        """
        img = torch.squeeze(img.detach(), dim=0) * 255.0
        img = img.clamp(0.0, 255.0).permute((1, 2, 0))
        img = img.cpu().numpy().astype(np.uint8)

        cv2.imwrite(img_name, img)


# ------------------------------------------------------------------------------
# /losses.py
# ------------------------------------------------------------------------------
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image

import torchvision.transforms as transforms
import numpy as np
import torch


def LoG(img):
    """Filter ``img`` (B, C, H, W) with a 5x5 Laplacian-of-Gaussian kernel.

    The kernel is applied to every input channel and the responses are summed
    over channels. The weight is repeated over ``B`` output channels, so the
    result is (B, B, H, W) with identical channels; the HFEN ratio is
    unaffected by that duplication. ``padding=2`` (half the kernel size)
    keeps H and W unchanged -- a padding of 1 would crop a one-pixel border.
    The kernel is built on ``img``'s device and dtype.
    """
    kernel = torch.tensor([
        [0, 0, 1, 0, 0],
        [0, 1, 2, 1, 0],
        [1, 2, -16, 2, 1],
        [0, 1, 2, 1, 0],
        [0, 0, 1, 0, 0],
    ], dtype=img.dtype, device=img.device)
    weight = kernel.view(1, 1, 5, 5).repeat(img.shape[0], img.shape[1], 1, 1)

    return func.conv2d(img, weight, padding=2)


def HFEN(output, target, eps=1e-12):
    """High-frequency error norm: ||LoG(output) - LoG(target)||^2 / ||LoG(target)||^2.

    ``eps`` bounds the denominator away from zero for targets with no
    high-frequency content (e.g. a flat image).
    """
    log_target = LoG(target)
    numerator = torch.sum(torch.pow(LoG(output) - log_target, 2))
    denominator = torch.clamp(torch.sum(torch.pow(log_target, 2)), min=eps)
    return numerator / denominator


def l1_norm(output, target):
    """Mean absolute error over all elements."""
    return torch.sum(torch.abs(output - target)) / torch.numel(output)


def get_temporal_data(output, target):
    """Return frame-to-frame differences of (batch, frame, C, H, W) tensors.

    Frame ``i`` of each result holds ``x[:, i] - x[:, i - 1]``; frame 0 is
    zero. Works for any number of frames.
    """
    final_output = torch.zeros_like(output)
    final_target = torch.zeros_like(target)

    final_output[:, 1:] = output[:, 1:] - output[:, :-1]
    final_target[:, 1:] = target[:, 1:] - target[:, :-1]

    return final_output, final_target


def temporal_norm(output, target):
    """Mean absolute error between temporal differences."""
    return torch.sum(torch.abs(output - target)) / torch.numel(output)


def loss_func(output, temporal_output, target, temporal_target):
    """Weighted training loss for one frame.

    Returns ``(total, l1, hfen, temporal)`` with
    ``total = 0.8 * l1 + 0.1 * hfen + 0.1 * temporal``.
    """
    ls = l1_norm(output, target)
    lg = HFEN(output, target)
    lt = temporal_norm(temporal_output, temporal_target)

    return 0.8 * ls + 0.1 * lg + 0.1 * lt, ls, lg, lt


# ------------------------------------------------------------------------------
# /model.py
# ------------------------------------------------------------------------------
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image

import torchvision.transforms as transforms
import numpy as np
import torch


class RecurrentBlock(nn.Module):
    """One encoder, bottleneck or decoder stage of the recurrent autoencoder.

    * Encoder (``downsampling=True``) and bottleneck (``bottleneck=True``)
      stages are recurrent: ``l1`` convolves the input, the result is
      concatenated with this stage's output from the previous frame
      (``self.hidden``), and ``l2`` applies two more convolutions. The result
      becomes the new hidden state. Spatial pooling is done by the caller.
    * Decoder (``upsampling=True``) stages are feed-forward: 2x nearest
      upsampling then two convolutions over an input of ``2 * input_nc``
      channels (decoder features concatenated with a skip connection).

    If several flags are set, the precedence is downsampling, then
    upsampling, then bottleneck.
    """

    def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False):
        super(RecurrentBlock, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc

        self.downsampling = downsampling
        self.upsampling = upsampling
        self.bottleneck = bottleneck

        self.hidden = None

        if self._is_recurrent():
            self.l1 = nn.Sequential(
                nn.Conv2d(input_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            )
            self.l2 = nn.Sequential(
                nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2d(output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
            )
        elif self.upsampling:
            self.l1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(2 * input_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2d(output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
            )

    def _is_recurrent(self):
        # Encoder and bottleneck stages carry state between frames; decoder
        # stages do not. Mirrors the flag precedence documented on the class.
        return self.downsampling or (self.bottleneck and not self.upsampling)

    def forward(self, inp):
        """Process one frame; recurrent stages also update ``self.hidden``.

        Returns None if no stage type was selected at construction.
        """
        if self._is_recurrent():
            op1 = self.l1(inp)
            op2 = self.l2(torch.cat((op1, self.hidden), dim=1))

            self.hidden = op2

            return op2
        if self.upsampling:
            return self.l1(inp)
        return None

    def reset_hidden(self, inp, dfac):
        """Reset the recurrent state at the start of a sequence.

        The state has ``output_nc`` channels and ``inp``'s spatial size
        integer-divided by ``dfac`` (``torch.zeros`` requires integer sizes).
        It is allocated as zeros on ``inp``'s device and dtype for recurrent
        stages; decoder stages never read it, so they only record the size.
        """
        size = list(inp.size())
        size[1] = self.output_nc
        size[2] //= dfac
        size[3] //= dfac

        self.hidden_size = size
        if self._is_recurrent():
            self.hidden = torch.zeros(*size, dtype=inp.dtype, device=inp.device)
        else:
            self.hidden = None


class RecurrentAE(nn.Module):
    """Recurrent denoising autoencoder (SIGGRAPH '17).

    Five recurrent encoder stages, each followed by 2x max-pooling, a
    recurrent bottleneck, and five upsampling decoder stages with U-Net style
    skip connections. Input height and width must be divisible by 32.

    Per frame: ``set_input({'A': x, ...})``; call ``reset_hidden()`` on the
    first frame of each sequence; then call the module.
    """

    def __init__(self, input_nc):
        super(RecurrentAE, self).__init__()

        self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True)
        self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True)
        self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True)
        self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True)
        self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True)

        self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True)

        self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True)
        self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True)
        self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True)
        self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True)
        self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True)

    def set_input(self, inp):
        """Store ``inp['A']`` as the current frame; other keys are ignored."""
        self.inp = inp['A']

    def forward(self):
        """Denoise the frame given to ``set_input``; returns 3 channels."""
        d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2)
        d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2)
        d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2)
        d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2)
        d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2)

        b = self.bottleneck(d5)

        u5 = self.u5(torch.cat((b, d5), dim=1))
        u4 = self.u4(torch.cat((u5, d4), dim=1))
        u3 = self.u3(torch.cat((u4, d3), dim=1))
        u2 = self.u2(torch.cat((u3, d2), dim=1))
        u1 = self.u1(torch.cat((u2, d1), dim=1))

        return u1

    def reset_hidden(self):
        """Zero all recurrent state; call after ``set_input`` on a sequence's
        first frame. Each stage's ``dfac`` is its output scale relative to the
        input (encoder stage k runs at 1/2^(k-1), the bottleneck at 1/32)."""
        self.d1.reset_hidden(self.inp, dfac=1)
        self.d2.reset_hidden(self.inp, dfac=2)
        self.d3.reset_hidden(self.inp, dfac=4)
        self.d4.reset_hidden(self.inp, dfac=8)
        self.d5.reset_hidden(self.inp, dfac=16)

        self.bottleneck.reset_hidden(self.inp, dfac=32)

        # Decoder outputs run at 1/16 (u5) up to full resolution (u1).
        self.u5.reset_hidden(self.inp, dfac=16)
        self.u4.reset_hidden(self.inp, dfac=8)
        self.u3.reset_hidden(self.inp, dfac=4)
        self.u2.reset_hidden(self.inp, dfac=2)
        self.u1.reset_hidden(self.inp, dfac=1)


# ------------------------------------------------------------------------------
# /test.py
# ------------------------------------------------------------------------------
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image

import torchvision.transforms as transforms
import numpy as np
import torch, argparse, pdb

from model import *
from data import *
from losses import *


def load_checkpoint(filename):
    """Rebuild the model and Adam optimizer from a checkpoint written by train.py.

    Returns ``(model, optimizer, epoch)``; tensors are mapped to ``cuda:0``,
    where the model lives.
    """
    chkpoint = torch.load(filename, map_location='cuda:0')
    model = RecurrentAE(8)
    model.to('cuda:0')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))

    epoch = chkpoint['epoch']
    model.load_state_dict(chkpoint['state_dict'])
    optimizer.load_state_dict(chkpoint['optimizer'])

    return model, optimizer, int(epoch)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='RecurentAE, SIGGRAPH \'17')
    parser.add_argument('--data_dir', type=str, required=True, help='Data directory')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output')
    parser.add_argument('--checkpoint', type=str, required=True, help='Checkpoint to load')

    args = parser.parse_args()

    model, optimizer, epoch = load_checkpoint(args.checkpoint)
    model.eval()

    size = (256, 256)

    data_loader = RAEData('%s/test' % args.data_dir, size)
    dataset = DataLoader(data_loader, batch_size=1, num_workers=0, shuffle=False)

    def to_image(tensor):
        # (1, C, H, W) tensor in [0, 1] -> (H, W, C) float array in [0, 255]
        return torch.squeeze(tensor.detach(), dim=0).permute((1, 2, 0)).cpu().numpy() * 255.0

    # Inference only: no_grad stops autograd from keeping every frame's graph alive.
    with torch.no_grad():
        for i, item in enumerate(dataset):
            seq_dir = '%s/seq_%s' % (args.output_dir, i)
            # os.makedirs instead of shelling out to `mkdir -p` with an unquoted path.
            os.makedirs(seq_dir, exist_ok=True)

            for j in range(item['A'].shape[1]):
                inp = item['A'][:, j, :, :, :]
                gt = item['B'][:, j, :, :, :]

                model.set_input({'A': inp, 'B': gt})
                if j == 0:
                    model.reset_hidden()

                output = model()

                og = to_image(inp)[:, :, :3]   # noisy shading channels only
                albedo = to_image(item['ALBEDO'][:, j, :, :, :])
                ray = to_image(gt)[:, :, :3]   # reference shading
                output = to_image(output)

                # Side by side: noisy input | albedo | denoised | reference.
                final = np.concatenate([og, albedo, output, ray], axis=1)
                # Explicit saturating 8-bit conversion rather than relying on
                # cv2.imwrite's implicit handling of float64 images.
                final = np.clip(np.rint(final), 0, 255).astype(np.uint8)

                cv2.imwrite('%s/%s.png' % (seq_dir, j), final)


# ------------------------------------------------------------------------------
# /train.py
# ------------------------------------------------------------------------------
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image

import torchvision.transforms as transforms
import numpy as np
import torch, argparse, pdb

from model import *
from data import *
from losses import *


def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` with ``torch.save``."""
    torch.save(state, filename)


def train_sequence(model, sequence):
    """Run ``model`` over one sequence and return its losses summed over frames.

    ``sequence`` is a batch from RAEData: ``'A'`` is the input and ``'B'`` the
    reference, both indexed (batch, frame, C, H, W). Frames are fed in order
    so the recurrent state carries over; the state is reset at frame 0.
    Returns ``(loss, l1, hfen, temporal)``.
    """
    inp = sequence['A']
    target = sequence['B']
    n_frames = inp.shape[1]

    # All frame outputs go into one tensor so the temporal loss can
    # difference neighbouring frames.
    output_final = torch.zeros_like(target)

    for j in range(n_frames):
        model.set_input({'A': inp[:, j, :, :, :], 'B': target[:, j, :, :, :]})
        if j == 0:
            model.reset_hidden()  # new sequence: start from a zero state

        output_final[:, j, :, :, :] = model()

    temporal_output, temporal_target = get_temporal_data(output_final, target)

    loss_final = 0
    ls_final = 0
    lg_final = 0
    lt_final = 0

    for j in range(n_frames):
        l, ls, lg, lt = loss_func(output_final[:, j, :, :, :], temporal_output[:, j, :, :, :],
                                  target[:, j, :, :, :], temporal_target[:, j, :, :, :])
        loss_final += l
        ls_final += ls
        lg_final += lg
        lt_final += lt

    return loss_final, ls_final, lg_final, lt_final


def train(model, dataset, optimizer, epoch):
    """Run one training epoch over the DataLoader ``dataset``.

    Prints progress every 50 sequences and returns the mean loss (0 for an
    empty dataset).
    """
    total_loss = 0
    total_loss_num = 0

    for i, item in enumerate(dataset):
        optimizer.zero_grad()
        loss_final, ls_final, lg_final, lt_final = train_sequence(model, item)

        loss_final.backward()
        optimizer.step()

        total_loss += loss_final.item()
        total_loss_num += 1

        if i % 50 == 0:
            # len(dataset), not the script-level `data_loader` global, so this
            # function also works when imported from another module.
            print('[Epoch : %s] [%s/%s] Loss => %s , L1 => %s , HFEN => %s , TEMPORAL => %s' %
                  (epoch+1, (i+1), len(dataset), loss_final.item(), ls_final.item(),
                   lg_final.item(), lt_final.item()))
            sys.stdout.flush()

    return total_loss / max(total_loss_num, 1)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='RecurentAE, SIGGRAPH \'17')
    parser.add_argument('--data_dir', type=str, required=True, help='Data directory')
    parser.add_argument('--save_dir', type=str, required=True, help='Model chekpoint saving directory')
    parser.add_argument('--name', type=str, required=True, help='Experiment Name')
    parser.add_argument('--epochs', type=int, required=True, help='Number of epochs to train')

    args = parser.parse_args()

    data_loader = RAEData('%s/train' % args.data_dir, (256, 256))
    dataset = DataLoader(data_loader, batch_size=1, num_workers=0, shuffle=True)

    model = RecurrentAE(8)
    model.to('cuda:0')
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))

    # Make sure checkpoints can be written before spending time on training.
    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):
        print('\nEpoch %s' % (epoch+1))

        total_loss = train(model, dataset, optimizer, epoch)

        print('Epoch %s loss => %s' % (epoch+1, total_loss))
        sys.stdout.flush()

        if epoch % 100 == 0:
            print('SAVING MODEL AT EPOCH %s' % (epoch+1))
            save_checkpoint({
                'epoch': epoch+1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, '%s/%s_%s.pt' % (args.save_dir, args.name, epoch+1))

    save_checkpoint({
        'epoch': args.epochs,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, '%s/%s_%s.pt' % (args.save_dir, args.name, args.epochs))
# ------------------------------------------------------------------------------