├── README.md ├── data_loader.py ├── network.py ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # motion_magnification_pytorch 2 | Reproducing Learning-based Video Motion Magnification in pytorch 3 | 4 | We write the code with reference to the original [tensorflow implementation](https://github.com/12dmodel/deep_motion_mag) from the authors 5 | 6 | This code is tested on python 3.5 and pytorch 0.4.1. 7 | 8 | # Data 9 | We use the dataset opened by the authors. 10 | Please refer to the [authors repository](https://github.com/12dmodel/deep_motion_mag). 11 | 12 | # Train the network 13 | python train.py [--additional option] 14 | 15 | ex) to run training on gpu number 0 16 | 17 | python train.py --gpu 0 18 | 19 | # reference 20 | 1) https://github.com/12dmodel/deep_motion_mag 21 | 2) Oh, Tae-Hyun, et al. "Learning-based Video Motion Magnification." arXiv preprint arXiv:1804.02684 (2018). 22 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.folder import * 2 | import numpy as np 3 | import torch 4 | 5 | class ImageFromFolder(ImageFolder): 6 | 7 | def __init__(self, root, num_data=100000, preprocessing=False, transform=None, target_transform=None, 8 | loader=default_loader): 9 | 10 | mag = np.loadtxt(os.path.join(root, 'train_mf.txt')) 11 | #print(mag[:10], mag.shape) 12 | 13 | imgs = [(os.path.join(root,'amplified','%06d.png'%(i)), 14 | os.path.join(root,'frameA','%06d.png'%(i)), 15 | os.path.join(root,'frameB','%06d.png'%(i)), 16 | os.path.join(root,'frameC','%06d.png'%(i)), 17 | mag[i]) for i in range(num_data)] 18 | 19 | 20 | self.root = root 21 | self.imgs = imgs 22 | self.samples = self.imgs 23 | self.transform = transform 24 | self.target_transform = target_transform 25 | self.loader = loader 26 | self.preproc = preprocessing 27 | 28 | def __getitem__(self, index): 29 | """ 30 | Args: 31 | index (int): Index 32 | Returns: 33 | tuple: (sample, target) where target is class_index of the target class. 34 | """ 35 | pathAmp, pathA, pathB, pathC, target = self.samples[index] 36 | sampleAmp, sampleA, sampleB, sampleC = np.array(self.loader(pathAmp)), np.array(self.loader(pathA)), np.array(self.loader(pathB)), np.array(self.loader(pathC)) 37 | 38 | # normalize 39 | sampleAmp = sampleAmp/127.5 - 1.0 40 | sampleA = sampleA/127.5 - 1.0 41 | sampleB = sampleB/127.5 - 1.0 42 | sampleC = sampleC/127.5 - 1.0 43 | 44 | # preprocessing 45 | if self.preproc: 46 | sampleAmp = preproc_poisson_noise(sampleAmp) 47 | sampleA = preproc_poisson_noise(sampleA) 48 | sampleB = preproc_poisson_noise(sampleB) 49 | sampleC = preproc_poisson_noise(sampleC) 50 | """ 51 | if self.transform is not None: 52 | sample = self.transform(sample) 53 | if self.target_transform is not None: 54 | target = self.target_transform(target) 55 | """ 56 | 57 | # to torch tensor 58 | sampleAmp, sampleA, sampleB, sampleC = torch.from_numpy(sampleAmp), torch.from_numpy(sampleA), torch.from_numpy(sampleB), torch.from_numpy(sampleC) 59 | sampleAmp = sampleAmp.float() 60 | sampleA = sampleA.float() 61 | sampleB = sampleB.float() 62 | sampleC = sampleC.float() 63 | 64 | target = torch.from_numpy(np.array(target)).float() 65 | 66 | # permute from HWC to CHW 67 | sampleAmp = sampleAmp.permute(2,0,1) 68 | sampleA = sampleA.permute(2,0,1) 69 | sampleB = sampleB.permute(2,0,1) 70 | sampleC = sampleC.permute(2,0,1) 71 | 72 | return sampleAmp, sampleA, sampleB, sampleC, target 73 | 74 | def preproc_poisson_noise(image): 75 | nn = np.random.uniform(0, 0.3) # 0.3 76 | n = np.random.normal(0.0, 1.0, image.shape) 77 | n_str = np.sqrt(image + 1.0) / np.sqrt(127.5) 78 | return image + nn * n * n_str 79 | 80 | class ImageFromFolderTest(ImageFolder): 81 | 82 | def __init__(self, root, mag=10.0, mode='static', num_data=300, preprocessing=False, transform=None, target_transform=None, loader=default_loader): 83 | 84 | if mode=='static' or mode=='temporal': 85 | imgs = [(root+'_%06d.png'%(1), 86 | root+'_%06d.png'%(i+2), 87 | mag) for i in range(num_data)] 88 | elif mode=='dynamic': 89 | imgs = [(root+'_%06d.png'%(i+1), 90 | root+'_%06d.png'%(i+2), 91 | mag) for i in range(num_data)] 92 | else: 93 | raise ValueError("Unsupported modes %s"%(mode)) 94 | 95 | 96 | self.root = root 97 | self.imgs = imgs 98 | self.samples = self.imgs 99 | self.transform = transform 100 | self.target_transform = target_transform 101 | self.loader = loader 102 | self.preproc = preprocessing 103 | 104 | def __getitem__(self, index): 105 | """ 106 | Args: 107 | index (int): Index 108 | Returns: 109 | tuple: (sample, target) where target is class_index of the target class. 110 | """ 111 | pathA, pathB, target = self.samples[index] 112 | sampleA, sampleB = np.array(self.loader(pathA)), np.array(self.loader(pathB)) 113 | 114 | # normalize 115 | sampleA = sampleA/127.5 - 1.0 116 | sampleB = sampleB/127.5 - 1.0 117 | 118 | # preprocessing 119 | if self.preproc: 120 | sampleA = preproc_poisson_noise(sampleA) 121 | sampleB = preproc_poisson_noise(sampleB) 122 | """ 123 | if self.transform is not None: 124 | sample = self.transform(sample) 125 | if self.target_transform is not None: 126 | target = self.target_transform(target) 127 | """ 128 | 129 | # to torch tensor 130 | sampleA, sampleB = torch.from_numpy(sampleA), torch.from_numpy(sampleB) 131 | sampleA = sampleA.float() 132 | sampleB = sampleB.float() 133 | 134 | target = torch.from_numpy(np.array(target)).float() 135 | 136 | # permute from HWC to CHW 137 | sampleA = sampleA.permute(2,0,1) 138 | sampleB = sampleB.permute(2,0,1) 139 | 140 | return sampleA, sampleB, target 141 | # Test 142 | if __name__ == '__main__': 143 | 144 | dataset = ImageFromFolder('./../data/train', num_data=100, preprocessing=True) 145 | 146 | imageAmp, imageA, imageB, imageC, mag = dataset.__getitem__(0) 147 | 148 | import matplotlib.pyplot as plt 149 | plt.imshow(((imageA+1.0)*127.5).astype(np.uint8)) 150 | plt.show() 151 | #print(imgAmp.shape, imgA.shape, imgB.shape, imgC.shape, mag) 152 | #print(img, target) 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | """ 2 | need to check 3 | 1) conv initializer (std) 4 | 2) bias 5 | """ 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | def _make_layer(block, in_planes, out_planes, num_layers, kernel_size=3, stride=1): 13 | layers = [] 14 | for i in range(num_layers): 15 | layers.append(block(in_planes, out_planes, kernel_size, stride)) 16 | return nn.Sequential(*layers) 17 | 18 | class ResBlock(nn.Module): 19 | def __init__(self, in_planes, output_planes, kernel_size=3, stride=1): 20 | super(ResBlock, self).__init__() 21 | p = (kernel_size-1)//2 22 | self.pad1 = nn.ReflectionPad2d(p) 23 | self.conv1 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size, 24 | stride=stride, bias=False) 25 | self.pad2 = nn.ReflectionPad2d(p) 26 | self.conv2 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size, 27 | stride=stride, bias=False) 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | def forward(self, x): 31 | y = self.relu(self.conv1(self.pad1(x))) 32 | y = self.conv2(self.pad2(y)) 33 | return y + x 34 | 35 | class ConvBlock(nn.Module): 36 | def __init__(self, in_planes, output_planes, kernel_size=7, stride=1): 37 | super(ConvBlock, self).__init__() 38 | p=3 39 | self.pad1 = nn.ReflectionPad2d(p) 40 | self.conv1 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size, 41 | stride=stride, bias=False) 42 | self.relu = nn.ReLU(inplace=True) 43 | 44 | def forward(self, x): 45 | return self.relu(self.conv1(self.pad1(x))) 46 | 47 | class ConvBlockAfter(nn.Module): 48 | def __init__(self, in_planes, output_planes, kernel_size=3, stride=1): 49 | super(ConvBlockAfter, self).__init__() 50 | p=1 51 | self.pad1 = nn.ReflectionPad2d(p) 52 | self.conv1 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size, 53 | stride=stride, bias=False) 54 | 55 | def forward(self, x): 56 | return self.conv1(self.pad1(x)) 57 | 58 | class Encoder(nn.Module): 59 | def __init__(self, num_resblk): 60 | super(Encoder, self).__init__() 61 | # common representation 62 | self.pad1 = nn.ReflectionPad2d(3) 63 | self.conv1 = nn.Conv2d(3, 16, kernel_size=7, stride=1, bias=False) 64 | self.pad2 = nn.ReflectionPad2d(1) 65 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, bias=False) 66 | self.resblks = _make_layer(ResBlock, 32, 32, num_resblk) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | # texture representation 70 | self.pad1_text = nn.ReflectionPad2d(1) 71 | self.conv1_text = nn.Conv2d(32, 32, kernel_size=3, stride=2, bias=False) 72 | self.resblks_text = _make_layer(ResBlock, 32, 32, 2) 73 | 74 | # shape representation 75 | self.pad1_shape = nn.ReflectionPad2d(1) 76 | self.conv1_shape = nn.Conv2d(32, 32, kernel_size=3, stride=1, bias=False) 77 | self.resblks_shape = _make_layer(ResBlock, 32, 32, 2) 78 | 79 | def forward(self, x): 80 | c = self.relu(self.conv1(self.pad1(x))) 81 | c = self.relu(self.conv2(self.pad2(c))) 82 | c = self.resblks(c) 83 | 84 | v = self.relu(self.conv1_text(self.pad1_text(c))) 85 | v = self.resblks_text(v) 86 | 87 | m = self.relu(self.conv1_shape(self.pad1_shape(c))) 88 | m = self.resblks_shape(m) 89 | 90 | return v, m # v: texture, m: shape 91 | 92 | class Manipulator(nn.Module): 93 | def __init__(self, num_resblk): 94 | super(Manipulator, self).__init__() 95 | self.convblks = _make_layer(ConvBlock, 32, 32, 1, kernel_size=7, stride=1) 96 | self.convblks_after = _make_layer(ConvBlockAfter, 32, 32, 1, kernel_size=3, stride=1) 97 | self.resblks = _make_layer(ResBlock, 32, 32, num_resblk, kernel_size=3, stride=1) 98 | 99 | def forward(self, x_a, x_b, amp): 100 | diff = x_b - x_a 101 | diff = self.convblks(diff) 102 | diff = (amp - 1.0) * diff 103 | diff = self.convblks_after(diff) 104 | diff = self.resblks(diff) 105 | 106 | return x_b + diff 107 | 108 | class Decoder(nn.Module): 109 | def __init__(self, num_resblk): 110 | super(Decoder, self).__init__() 111 | # texture 112 | self.upsample_text = nn.UpsamplingNearest2d(scale_factor=2) 113 | self.pad_text = nn.ReflectionPad2d(1) 114 | self.conv1_text = nn.Conv2d(32, 32, kernel_size=3, stride=1, bias=False) 115 | self.relu = nn.ReLU(inplace=True) 116 | # common blocks 117 | self.resblks = _make_layer(ResBlock, 64, 64, num_resblk) 118 | self.upsample = nn.UpsamplingNearest2d(scale_factor=2) 119 | self.pad1 = nn.ReflectionPad2d(1) 120 | self.conv1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, bias=False) 121 | self.pad2 = nn.ReflectionPad2d(3) 122 | self.conv2 = nn.Conv2d(32, 3, kernel_size=7, stride=1, bias=False) 123 | 124 | 125 | def forward(self, v, m): # v: texture, m: shape 126 | v = self.relu(self.conv1_text(self.pad_text(self.upsample_text(v)))) 127 | 128 | c = torch.cat([v,m],1) 129 | c = self.resblks(c) 130 | c = self.upsample(c) 131 | c = self.relu(self.conv1(self.pad1(c))) 132 | c = self.conv2(self.pad2(c)) 133 | 134 | return c 135 | 136 | class MagNet(nn.Module): 137 | def __init__(self, num_resblk_enc=3, num_resblk_man=1, num_resblk_dec=9): 138 | super(MagNet, self).__init__() 139 | self.encoder = Encoder(num_resblk=num_resblk_enc) 140 | self.manipulator = Manipulator(num_resblk=num_resblk_man) 141 | self.decoder = Decoder(num_resblk=num_resblk_dec) 142 | 143 | # initialize conv weights(xavier) 144 | #for m in self.modules(): 145 | # if isinstance(m, nn.Conv2d): 146 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 147 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 148 | 149 | def forward(self, x_a, x_b, x_c, amp): # v: texture, m: shape 150 | v_a, m_a = self.encoder(x_a) 151 | v_b, m_b = self.encoder(x_b) 152 | v_c, m_c = self.encoder(x_c) 153 | 154 | m_enc = self.manipulator(m_a, m_b, amp) 155 | 156 | y_hat = self.decoder(v_b, m_enc) 157 | 158 | return y_hat, (v_a, m_a), (v_b, m_b), (v_c, m_c) 159 | 160 | if __name__ == '__main__': 161 | model = MagNet() 162 | print(model) 163 | 164 | #model = torch.nn.DataParallel(model).cuda() 165 | #print(model) 166 | 167 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data as data 12 | #import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | 15 | from network import MagNet 16 | from data_loader import ImageFromFolderTest 17 | from utils import AverageMeter 18 | import numpy as np 19 | from PIL import Image 20 | from collections import OrderedDict 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch Deep Video Magnification') 23 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 24 | help='number of data loading workers (default: 0)') 25 | parser.add_argument('-b', '--batch-size', default=4, type=int, 26 | metavar='N', help='mini-batch size (default: 4)') 27 | parser.add_argument('--print-freq', '-p', default=100, type=int, 28 | metavar='N', help='print frequency (default: 100)') 29 | parser.add_argument('--load_ckpt', type=str, metavar='PATH', 30 | help='path to load checkpoint') 31 | parser.add_argument('--save_dir', default='demo', type=str, metavar='PATH', 32 | help='path to save generated frames (default: demo)') 33 | parser.add_argument('--gpu',default=0, type=str, help='cuda_visible_devices') 34 | 35 | parser.add_argument('-m', '--amp', default=20.0, type=float, 36 | help='amplification factor (default: 10.0)') 37 | parser.add_argument('--mode', default='static', type=str, choices=['static', 'dynamic','temporal'], 38 | help='amplification mode (static, dynamic, temporal)') 39 | parser.add_argument('--video_path', default='./../demo_video/baby', type=str, 40 | help='path to video frames') 41 | parser.add_argument('--num_data', default=300, type=int, 42 | help='number of frames') 43 | #for temporal filter 44 | parser.add_argument('--fh', default=0.4, type=float) 45 | parser.add_argument('--fl', default=0.04, type=float) 46 | #parser.add_argument('--fs', default=30, type=int) 47 | #parser.add_argument('--ntab', default=2, type=int) 48 | 49 | args = parser.parse_args() 50 | 51 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 52 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 53 | 54 | def main(): 55 | global args 56 | args = parser.parse_args() 57 | print(args) 58 | 59 | # create model 60 | model = MagNet().cuda() 61 | #model = torch.nn.DataParallel(model).cuda() 62 | print(model) 63 | 64 | # load checkpoint 65 | if os.path.isfile(args.load_ckpt): 66 | print("=> loading checkpoint '{}'".format(args.load_ckpt)) 67 | checkpoint = torch.load(args.load_ckpt) 68 | args.start_epoch = checkpoint['epoch'] 69 | 70 | # to load state_dict trained with DataParallel to model without DataParallel 71 | new_state_dict = OrderedDict() 72 | state_dict = checkpoint['state_dict'] 73 | for k, v in state_dict.items(): 74 | name = k[7:] 75 | new_state_dict[name]=v 76 | model.load_state_dict(new_state_dict) 77 | print("=> loaded checkpoint '{}' (epoch {})" 78 | .format(args.load_ckpt, checkpoint['epoch'])) 79 | else: 80 | print("=> no checkpoint found at '{}'".format(args.load_ckpt)) 81 | assert(False) 82 | 83 | 84 | # check saving directory 85 | save_dir = args.save_dir 86 | if not os.path.exists(save_dir): 87 | os.makedirs(save_dir) 88 | print(save_dir) 89 | 90 | # cudnn enable 91 | cudnn.benchmark = True 92 | 93 | # data loader 94 | dataset_mag = ImageFromFolderTest(args.video_path, mag=args.amp, mode=args.mode, num_data=args.num_data, preprocessing=False) 95 | data_loader = data.DataLoader(dataset_mag, 96 | batch_size=args.batch_size, 97 | shuffle=False, 98 | num_workers=args.workers, 99 | pin_memory=True) 100 | 101 | 102 | # generate frames 103 | mag_frames=[] 104 | model.eval() 105 | 106 | # static mode or dynamic mode 107 | if args.mode=='static' or args.mode=='dynamic': 108 | for i, (xa, xb, amp_factor) in enumerate(data_loader): 109 | if i%10==0: print('processing sample %d'%i) 110 | amp_factor = amp_factor.unsqueeze(1).unsqueeze(1).unsqueeze(1) 111 | 112 | xa=xa.cuda() 113 | xb=xb.cuda() 114 | amp_factor=amp_factor.cuda() 115 | 116 | y_hat, _, _, _ = model(xa, xb, xb, amp_factor) 117 | 118 | if i==0: 119 | # back to image scale (0-255) 120 | tmp = xa.permute(0,2,3,1).cpu().detach().numpy() 121 | tmp = np.clip(tmp, -1.0, 1.0) 122 | tmp = ((tmp + 1.0) * 127.5).astype(np.uint8) 123 | mag_frames.append(tmp) 124 | 125 | # back to image scale (0-255) 126 | y_hat = y_hat.permute(0,2,3,1).cpu().detach().numpy() 127 | y_hat = np.clip(y_hat, -1.0, 1.0) 128 | y_hat = ((y_hat + 1.0) * 127.5).astype(np.uint8) 129 | mag_frames.append(y_hat) 130 | 131 | else: 132 | # temporal mode (difference of IIR) 133 | # copy filter coefficients and follow codes from https://github.com/12dmodel/deep_motion_mag 134 | filter_b = [args.fh-args.fl, args.fl-args.fh] 135 | filter_a = [-1.0*(2.0 - args.fh - args.fl), (1.0 - args.fl) * (1.0 - args.fh)] 136 | 137 | x_state = [] 138 | y_state = [] 139 | for i, (xa, xb, amp_factor) in enumerate(data_loader): 140 | if i%10==0: print('processing sample %d'%i) 141 | amp_factor = amp_factor.unsqueeze(1).unsqueeze(1).unsqueeze(1) 142 | 143 | xa=xa.cuda() 144 | xb=xb.cuda() 145 | amp_factor=amp_factor.cuda() 146 | 147 | vb, mb = model.encoder(xb) 148 | x_state.insert(0,mb.detach()) 149 | while len(x_state)len(filter_b): 152 | x_state = x_state[:len(filter_b)] 153 | y = torch.zeros_like(mb) 154 | for i in range(len(x_state)): 155 | y += x_state[i] * filter_b[i] 156 | for i in range(len(y_state)): 157 | y -= y_state[i] * filter_a[i] 158 | 159 | y_state.insert(0,y.detach()) 160 | if len(y_state) > len(filter_a): 161 | y_state = y_state[:len(filter_a)] 162 | 163 | mb_m = model.manipulator(0.0, y, amp_factor) 164 | mb_m += mb - y 165 | 166 | y_hat = model.decoder(vb, mb_m) 167 | 168 | if i==0: 169 | # back to image scale (0-255) 170 | tmp = xa.permute(0,2,3,1).cpu().detach().numpy() 171 | tmp = np.clip(tmp, -1.0, 1.0) 172 | tmp = ((tmp + 1.0) * 127.5).astype(np.uint8) 173 | mag_frames.append(tmp) 174 | 175 | # back to image scale (0-255) 176 | y_hat = y_hat.permute(0,2,3,1).cpu().detach().numpy() 177 | y_hat = np.clip(y_hat, -1.0, 1.0) 178 | y_hat = ((y_hat + 1.0) * 127.5).astype(np.uint8) 179 | mag_frames.append(y_hat) 180 | 181 | 182 | 183 | # save frames 184 | mag_frames = np.concatenate(mag_frames, 0) 185 | for i, frame in enumerate(mag_frames): 186 | fn = os.path.join(save_dir, 'demo_%s_%06d.png'%(args.mode,i)) 187 | im = Image.fromarray(frame) 188 | im.save(fn) 189 | 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # need to check 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data as data 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | 17 | from network import MagNet 18 | from data_loader import ImageFromFolder 19 | from utils import AverageMeter 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch Deep Video Magnification') 22 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 23 | help='number of data loading workers (default: 0)') 24 | parser.add_argument('--epochs', default=12, type=int, metavar='N', 25 | help='number of total epochs to run') 26 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 27 | help='manual epoch number (useful on restarts)') 28 | parser.add_argument('-b', '--batch-size', default=4, type=int, 29 | metavar='N', help='mini-batch size (default: 4)') 30 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, 31 | metavar='LR', help='initial learning rate') 32 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 33 | help='momentum') 34 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, 35 | metavar='W', help='weight decay (default: 0.0)') 36 | parser.add_argument('--num_data', default=100000, type=int, 37 | help='number of total data sample used for training (default: 100000)') 38 | parser.add_argument('--print-freq', '-p', default=100, type=int, 39 | metavar='N', help='print frequency (default: 100)') 40 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 41 | help='path to latest checkpoint (default: none)') 42 | parser.add_argument('--ckpt', default='ckpt', type=str, metavar='PATH', 43 | help='path to save checkpoint (default: ckpt)') 44 | parser.add_argument('--gpu',default=0, type=str, help='cuda_visible_devices') 45 | parser.add_argument('--weight_reg1', default=1.0, type=float, 46 | help='weight texture regularization loss (default: 1.0)') 47 | parser.add_argument('--weight_reg2', default=1.0, type=float, 48 | help='weight shpae regularization loss (default: 1.0)') 49 | args = parser.parse_args() 50 | 51 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 52 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 53 | 54 | losses_recon, losses_reg1, losses_reg2, losses_reg3 = [],[],[],[] 55 | 56 | def main(): 57 | global args 58 | global losses_recon, losses_reg1, losses_reg2, losses_reg3 59 | args = parser.parse_args() 60 | print(args) 61 | 62 | # create model 63 | model = MagNet() 64 | model = torch.nn.DataParallel(model).cuda() 65 | print(model) 66 | 67 | # optionally resume from a checkpoint 68 | if args.resume: 69 | if os.path.isfile(args.resume): 70 | print("=> loading checkpoint '{}'".format(args.resume)) 71 | checkpoint = torch.load(args.resume) 72 | args.start_epoch = checkpoint['epoch'] 73 | 74 | model.load_state_dict(checkpoint['state_dict']) 75 | losses_recon = checkpoint['losses_recon'] 76 | losses_reg1 = checkpoint['losses_reg1'] 77 | losses_reg2 = checkpoint['losses_reg2'] 78 | losses_reg3 = checkpoint['losses_reg3'] 79 | print("=> loaded checkpoint '{}' (epoch {})" 80 | .format(args.resume, checkpoint['epoch'])) 81 | else: 82 | print("=> no checkpoint found at '{}'".format(args.resume)) 83 | 84 | # check saving directory 85 | ckpt_dir = args.ckpt 86 | if not os.path.exists(ckpt_dir): 87 | os.makedirs(ckpt_dir) 88 | print(ckpt_dir) 89 | 90 | # cudnn enable 91 | cudnn.benchmark = True 92 | 93 | # dataloader 94 | dataset_mag = ImageFromFolder('./../data/train', num_data=args.num_data, preprocessing=True) 95 | data_loader = data.DataLoader(dataset_mag, 96 | batch_size=args.batch_size, 97 | shuffle=True, 98 | num_workers=args.workers, 99 | pin_memory=True) 100 | 101 | # loss criterion 102 | criterion = nn.L1Loss(size_average=True).cuda() 103 | 104 | # optimizer 105 | optimizer = torch.optim.Adam(model.parameters(), args.lr, 106 | betas=(0.9,0.999), 107 | weight_decay=args.weight_decay) 108 | 109 | 110 | # train model 111 | for epoch in range(args.start_epoch, args.epochs): 112 | loss_recon, loss_reg1, loss_reg2, loss_reg3 = train(data_loader, model, criterion, optimizer, epoch, args) 113 | 114 | # stack losses 115 | losses_recon.append(loss_recon) 116 | losses_reg1.append(loss_reg1) 117 | losses_reg2.append(loss_reg2) 118 | losses_reg3.append(loss_reg3) 119 | 120 | dict_checkpoint = { 121 | 'epoch': epoch + 1, 122 | 'state_dict': model.state_dict(), 123 | 'losses_recon': losses_recon, 124 | 'losses_reg1': losses_reg1, 125 | 'losses_reg2': losses_reg2, 126 | 'losses_reg3': losses_reg3, 127 | } 128 | 129 | # save checkpoints 130 | fpath = os.path.join(ckpt_dir, 'ckpt_e%02d.pth.tar'%(epoch)) 131 | torch.save(dict_checkpoint, fpath) 132 | 133 | def train(loader, model, criterion, optimizer, epoch, args): 134 | batch_time = AverageMeter() 135 | data_time = AverageMeter() 136 | losses_recon = AverageMeter() 137 | losses_reg1 = AverageMeter() # texture loss 138 | losses_reg2 = AverageMeter() # shape loss 139 | losses_reg3 = AverageMeter() 140 | 141 | model.train() 142 | 143 | end = time.time() 144 | for i, (y, xa, xb, xc, amp_factor) in enumerate(loader): 145 | y = y.cuda(async=True) 146 | data_time.update(time.time() - end) 147 | 148 | # compute output 149 | amp_factor = amp_factor.unsqueeze(1).unsqueeze(1).unsqueeze(1) 150 | y_hat, rep_a, rep_b, rep_c = model(xa, xb, xc, amp_factor) 151 | #v_c, m_c = model.encoder(xc) 152 | v_a, m_a = rep_a # v: texture, m: shape 153 | v_b, m_b = rep_b 154 | v_c, m_c = rep_c 155 | 156 | # compute losses 157 | loss_recon = criterion(y_hat, y) 158 | loss_reg1 = args.weight_reg1 * L1_loss(v_c, v_a) 159 | loss_reg2 = args.weight_reg2 * L1_loss(m_c, m_b) 160 | loss_reg3 = 0.0 161 | loss = loss_recon + loss_reg1 + loss_reg2 + loss_reg3 162 | 163 | losses_recon.update(loss_recon.item()) 164 | losses_reg1.update(loss_reg1.item()) 165 | losses_reg2.update(loss_reg2.item()) 166 | losses_reg3.update(loss_reg3) 167 | 168 | # update model 169 | optimizer.zero_grad() 170 | loss.backward() 171 | optimizer.step() 172 | 173 | # measure elapsed time 174 | batch_time.update(time.time() - end) 175 | end = time.time() 176 | 177 | if i % args.print_freq == 0: 178 | print('Epoch: [{0}][{1}/{2}]\t' 179 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 180 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 181 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 182 | 'LossR1 {loss_reg1.val:.4f} ({loss_reg1.avg:.4f})\t' 183 | 'LossR2 {loss_reg2.val:.4f} ({loss_reg2.avg:.4f})\t' 184 | 'LossR3 {loss_reg3.val:.4f} ({loss_reg3.avg:.4f})'.format( 185 | epoch, i, len(loader), batch_time=batch_time, data_time=data_time, 186 | loss=losses_recon, loss_reg1=losses_reg1, 187 | loss_reg2=losses_reg2, loss_reg3=losses_reg3)) 188 | 189 | 190 | return losses_recon.avg, losses_reg1.avg, losses_reg2.avg, losses_reg3.avg 191 | 192 | 193 | def L1_loss(input, target): 194 | return torch.abs(input - target).mean() 195 | 196 | 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | """Computes and stores the average and current value""" 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | 20 | --------------------------------------------------------------------------------