├── README.md
├── _config.yml
├── dataset.py
├── finalloss.png
├── flowlib.py
├── infer.py
├── models
│   └── style
│       ├── autoportrait.jpg
│       ├── candy.jpg
│       ├── composition.jpg
│       ├── edtaonisl.jpg
│       └── udnie.jpg
├── network.png
├── network.py
├── requirements.txt
├── styles
│   ├── autoportrait.jpg
│   ├── candy.jpg
│   ├── composition.jpg
│   ├── edtaonisl.jpg
│   ├── init
│   └── udnie.jpg
├── test.py
├── testwarp.py
├── totaldata.py
├── train.py
└── utilities.py

/README.md:
--------------------------------------------------------------------------------
# ReCoNet-PyTorch
This repository contains a PyTorch implementation of the [ReCoNet paper](https://arxiv.org/pdf/1807.01197.pdf). It was built as a course project for CS763, IITB. The model has been trained, and the results are available in [this repo](https://github.com/liulai/reconet-torch) by [@liulai](https://github.com/liulai).

### Contributors:
- [Mohd Safwan](https://github.com/safwankdb)
- [Kushagra Juneja](https://github.com/kushagra1729)
- [Saksham Khandelwal](https://github.com/skq024)

### Abstract
We aim to build a generalisable neural style transfer network for videos that offers temporal consistency and efficient real-time style transfer on any modern GPU. Past techniques have fallen short of efficient real-time style transfer, lacking temporal consistency, good perceptual style quality, or fast processing. We use ReCoNet, which aims to mitigate all of these problems.

### ReCoNet
ReCoNet is a feed-forward neural network that stylises videos frame by frame through an encoder followed by a decoder, with a VGG loss network capturing the perceptual style of the transfer target. The temporal loss is guided by occlusion masks and optical flow. Only the encoder and decoder run during inference, which makes ReCoNet very efficient, running above real time (~200 fps) on modern GPUs.
The network is illustrated in the figure below:

![ReCoNet Structure](https://github.com/skq024/Real-time-Coherent-Style-Transfer-For-Videos/blob/master/network.png)

### Dataset
We use the MPI Sintel dataset (around 1,000 frames) and the FlyingChairs dataset (about 22,000 frames) for training, and a video clip from an animated movie for testing.
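Both datasets are consumed as `(img1, img2, mask, flow)` tuples at 360×640 resolution through `totaldata.py`. A minimal loading sketch, assuming the datasets sit in `../MPI_Data/` and `../FlyingChairs2/` as in `test.py` (a CUDA device is required, since `totaldata.py` hard-codes `device='cuda'`):

```python
from totaldata import ConsolidatedDataset

# paths are assumptions; point them at your local dataset copies
data = ConsolidatedDataset(MPI_path="../MPI_Data/", FC_path="../FlyingChairs2/")
img1, img2, mask, flow = data[230]   # consecutive frames, occlusion mask, optical flow
print(img1.shape, flow.shape)        # torch.Size([3, 360, 640]) torch.Size([2, 360, 640])
```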
We have tried style transfer on the following styles:
![autoportrait](styles/autoportrait.jpg)
![candy](styles/candy.jpg)
![composition](styles/composition.jpg)
![edtaonisl](styles/edtaonisl.jpg)

### Loss functions and optimisation
The network uses a multi-level temporal loss that enforces temporal coherence at both the high-level feature maps and the final stylised output. The feature-map temporal loss does not include a luminance term, whereas the output-level temporal loss does. The perceptual losses are computed with the VGG-16 network and comprise the content loss, the style loss, and a total variation regularizer. They are computed on each frame separately and then summed with the temporal losses for that frame.
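Schematically, the two-frame objective has the following shape (a paraphrase of the paper's final loss, with α, β, γ, λ_f, λ_o as weighting hyperparameters, not a verbatim copy):

```latex
\mathcal{L}(t-1, t) = \sum_{i \in \{t-1,\, t\}}
    \big( \alpha \mathcal{L}_{content}(i) + \beta \mathcal{L}_{style}(i) + \gamma \mathcal{L}_{tv}(i) \big)
    + \lambda_f \, \mathcal{L}_{temp,f}(t-1, t) + \lambda_o \, \mathcal{L}_{temp,o}(t-1, t)
```

where L_temp,f is the temporal loss on the encoder feature maps and L_temp,o is the luminance-aware temporal loss on the stylised output.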
The final loss function is:
![final loss](finalloss.png)
34 | 35 | ### Requirements 36 | ```bash 37 | numpy 38 | Pillow 39 | scikit-image 40 | opencv-python 41 | torch 42 | torchvision 43 | ``` 44 | ### References 45 | [1] Real Time Coherent Video Style Transfer Network : https://arxiv.org/pdf/1807.01197.pdf
[2] Gram matrix : https://pytorch.org/tutorials/advanced/neural_style_tutorial.html
[3] ReCoNet Model : https://github.com/irsisyphus/reconet/blob/master/network.py
[4] Optical flow warping : https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
[5] Optical flow I/O : https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import numpy as np
import PIL
from PIL import Image
from flowlib import read
from skimage import io, transform

device = 'cuda'

def toString(num):
    # zero-pad a frame number to 4 digits, e.g. 7 -> "0007"
    string = str(num)
    while(len(string) < 4):
        string = "0"+string
    return string


class MPIDataset(Dataset):

    def __init__(self, path, transform=None):
        """
        looking at the "clean" subfolder for images, might change to "final" later
        path -> location of the MPI folder that contains the "training" folder
        """
        self.path = path+"training/"
        self.transform = transform
        self.dirlist = os.listdir(self.path+"clean/")
        self.dirlist.sort()
        self.numlist = []
        for folder in self.dirlist:
            self.numlist.append(len(os.listdir(self.path+"clean/"+folder+"/")))

    def __len__(self):
        # one (frame_t, frame_t+1) pair fewer than the number of frames per folder
        return sum(self.numlist)-len(self.numlist)

    def __getitem__(self, idx):
        """
        idx must be between 0 to len-1
        assuming flow[0] contains flow in x direction and flow[1] contains flow in y
        """
        for i in range(0, len(self.numlist)):
            folder = self.dirlist[i]
            path = self.path+"clean/"+folder+"/"
            occpath = self.path+"occlusions/"+folder+"/"
            flowpath = self.path+"flow/"+folder+"/"
            if(idx < (self.numlist[i]-1)):
                num1 = toString(idx+1)
                num2 = toString(idx+2)
                img1 = io.imread(path+"frame_"+num1+".png")
                img2 = io.imread(path+"frame_"+num2+".png")
                mask = io.imread(occpath+"frame_"+num1+".png")
                img1 = torch.from_numpy(transform.resize(img1, (360, 640))).to(device).permute(2, 0, 1).float()
                img2 = torch.from_numpy(transform.resize(img2, (360, 640))).to(device).permute(2, 0, 1).float()
                mask = torch.from_numpy(transform.resize(mask, (360, 640))).to(device).float()
                flow = read(flowpath+"frame_"+num1+".flo")
                # bilinear interpolation is default
                originalflow = torch.from_numpy(flow)  # shape (H, W, 2)
                flow = torch.from_numpy(transform.resize(flow, (360, 640))).to(device).permute(2, 0, 1).float()
                # rescale displacements: x by the width ratio, y by the height ratio
                flow[0, :, :] *= float(flow.shape[2])/originalflow.shape[1]
                flow[1, :, :] *= float(flow.shape[1])/originalflow.shape[0]
                break

        idx -= self.numlist[i]-1

        if self.transform:
            # complete later
            pass
        # IMG2 is at time t if IMG1 is at time t-1
        return (img1, img2, mask, flow)
--------------------------------------------------------------------------------
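The two `*=` lines in `__getitem__` are the part that is easy to get wrong: `skimage.transform.resize` interpolates the flow values but leaves them expressed in old-image pixels, so the x-component must be rescaled by the width ratio and the y-component by the height ratio. A quick standalone sanity check of that convention, on a synthetic field at 436×1024 (the native MPI Sintel resolution):

```python
import numpy as np
from skimage import transform

# constant flow of +10 px in x on a 436x1024 grid
flow = np.zeros((436, 1024, 2), dtype=np.float32)
flow[..., 0] = 10.0

resized = transform.resize(flow, (360, 640))
# values are still ~10 "old" pixels; rescale x by the width ratio
x = resized[..., 0] * (640.0 / 1024.0)
print(round(float(x.mean()), 2))  # 6.25, i.e. 10 * 640/1024
```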
--------------------------------------------------------------------------------
/finalloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/finalloss.png
--------------------------------------------------------------------------------
/flowlib.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3.4
# Source: https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html

import os
import re
import numpy as np
import uuid
from scipy import misc
from PIL import Image
import sys

def read(file):
    # dispatch on extension; only the .flo path is exercised in this repo
    if file.endswith('.float3'): return readFloat(file)
    elif file.endswith('.flo'): return readFlow(file)
    elif file.endswith('.ppm'): return readImage(file)
    elif file.endswith('.pgm'): return readImage(file)
    elif file.endswith('.png'): return readImage(file)
    elif file.endswith('.jpg'): return readImage(file)
    elif file.endswith('.pfm'): return readPFM(file)[0]
    else: raise Exception('don\'t know how to read %s' % file)

def write(file, data):
    if file.endswith('.float3'): return writeFloat(file, data)
    elif file.endswith('.flo'): return writeFlow(file, data)
    elif file.endswith('.ppm'): return writeImage(file, data)
    elif file.endswith('.pgm'): return writeImage(file, data)
    elif file.endswith('.png'): return writeImage(file, data)
    elif file.endswith('.jpg'): return writeImage(file, data)
    elif file.endswith('.pfm'): return writePFM(file, data)
    else: raise Exception('don\'t know how to write %s' % file)

def readFlow(name):
    if name.endswith('.pfm') or name.endswith('.PFM'):
        return readPFM(name)[0][:,:,0:2]

    f = open(name, 'rb')

    header = f.read(4)
    if header.decode("utf-8") != 'PIEH':
        raise Exception('Flow file header does not contain PIEH')

    width = np.fromfile(f, np.int32, 1).squeeze()
    height = np.fromfile(f, np.int32, 1).squeeze()

    flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))

    return flow.astype(np.float32)

def writeFlow(name, flow):
    f = open(name, 'wb')
    f.write('PIEH'.encode('utf-8'))
    np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
    flow = flow.astype(np.float32)
    flow.tofile(f)
--------------------------------------------------------------------------------
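The `.flo` format is simply a 4-byte `PIEH` magic, int32 width and height, then row-major float32 (x, y) pairs. A tiny round-trip check of the two helpers above (the file name is arbitrary):

```python
import numpy as np
from flowlib import readFlow, writeFlow

flow = np.random.rand(4, 5, 2).astype(np.float32)  # H=4, W=5
writeFlow('tiny.flo', flow)
back = readFlow('tiny.flo')
assert back.shape == (4, 5, 2) and np.allclose(back, flow)
```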
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
import argparse
import torch
from skimage import transform
import cv2
import numpy as np
from network import ReCoNet

parser = argparse.ArgumentParser()
parser.add_argument('--source', required=True, help='Video file to process')
parser.add_argument('--model', required=True, help='Model state_dict file')
args = parser.parse_args()

device = 'cuda'
video_capture = cv2.VideoCapture(args.source)
model = ReCoNet().to(device)
model.load_state_dict(torch.load(args.model))

while(True):
    ret, frame = video_capture.read()
    if not ret:
        break  # end of stream
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = torch.from_numpy(transform.resize(frame, (360, 640))).to(device).permute(2, 0, 1).float()
    features, styled_frame = model(frame.unsqueeze(0))
    # ConvTanh output is roughly in [0, 255]; clamp and convert for display
    styled_frame = styled_frame.cpu().clamp(0, 255).data.squeeze(0).numpy().transpose(1, 2, 0).astype('uint8')
    styled_frame = np.ascontiguousarray(styled_frame[:, :, ::-1])  # RGB back to BGR for OpenCV
    cv2.imshow('frame', styled_frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
--------------------------------------------------------------------------------
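Inference runs straight from the command line; the script opens a preview window and quits on `q`. The file names below are placeholders (any video readable by OpenCV, plus a trained `state_dict`):

```bash
# hypothetical paths
python infer.py --source clip.mp4 --model model_weights.pth
```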
/models/style/autoportrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/autoportrait.jpg
--------------------------------------------------------------------------------
/models/style/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/candy.jpg
--------------------------------------------------------------------------------
/models/style/composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/composition.jpg
--------------------------------------------------------------------------------
/models/style/edtaonisl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/edtaonisl.jpg
--------------------------------------------------------------------------------
/models/style/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/udnie.jpg
--------------------------------------------------------------------------------
/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/network.png
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from torchvision.models import vgg16
from collections import namedtuple

# From https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/vgg.py
class Vgg16(torch.nn.Module):
    def __init__(self, device='cpu'):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x].to(device))
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x].to(device))
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x].to(device))
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x].to(device))

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
        return out


# Rest of the file based on https://github.com/irsisyphus/reconet

class SelectiveLoadModule(torch.nn.Module):
    """Only load layers in trained models with the same name."""
    def __init__(self):
        super(SelectiveLoadModule, self).__init__()

    def forward(self, x):
        return x

    def load_state_dict(self, state_dict):
        """Override the function to ignore redundant weights."""
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                own_state[name].copy_(param)


class ConvLayer(torch.nn.Module):
    """Reflection padded convolution layer."""
    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        super(ConvLayer, self).__init__()
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=bias)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out


class ConvTanh(ConvLayer):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvTanh, self).__init__(in_channels, out_channels, kernel_size, stride)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        out = super(ConvTanh, self).forward(x)
        return self.tanh(out/255) * 150 + 255/2


class ConvInstRelu(ConvLayer):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvInstRelu, self).__init__(in_channels, out_channels, kernel_size, stride)
        self.instance = torch.nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out = super(ConvInstRelu, self).forward(x)
        out = self.instance(out)
        out = self.relu(out)
        return out


class UpsampleConvLayer(torch.nn.Module):
    """Upsamples the input and then does a convolution.
    This method gives better results compared to ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out


class UpsampleConvInstRelu(UpsampleConvLayer):
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvInstRelu, self).__init__(in_channels, out_channels, kernel_size, stride, upsample)
        self.instance = torch.nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out = super(UpsampleConvInstRelu, self).forward(x)
        out = self.instance(out)
        out = self.relu(out)
        return out


class ResidualBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride)
        self.in1 = torch.nn.InstanceNorm2d(out_channels, affine=True)
        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, stride)
        self.in2 = torch.nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = out + residual
        return out


class ReCoNet(SelectiveLoadModule):
    def __init__(self):
        super(ReCoNet, self).__init__()

        self.conv1 = ConvInstRelu(3, 32, kernel_size=9, stride=1)
        self.conv2 = ConvInstRelu(32, 64, kernel_size=3, stride=2)
        self.conv3 = ConvInstRelu(64, 128, kernel_size=3, stride=2)

        self.res1 = ResidualBlock(128, 128)
        self.res2 = ResidualBlock(128, 128)
        self.res3 = ResidualBlock(128, 128)
        self.res4 = ResidualBlock(128, 128)
        self.res5 = ResidualBlock(128, 128)

        self.deconv1 = UpsampleConvInstRelu(128, 64, kernel_size=3, stride=1, upsample=2)
        self.deconv2 = UpsampleConvInstRelu(64, 32, kernel_size=3, stride=1, upsample=2)
        self.deconv3 = ConvTanh(32, 3, kernel_size=9, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.res5(x)

        features = x

        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)

        return (features, x)
--------------------------------------------------------------------------------
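A quick shape check of the encoder/decoder above (CPU is fine for this; the shapes follow from the two stride-2 convolutions and the two 2× upsampling stages):

```python
import torch
from network import ReCoNet

model = ReCoNet().eval()
frame = torch.rand(1, 3, 360, 640)          # one 360x640 RGB frame
with torch.no_grad():
    features, styled = model(frame)
print(features.shape)   # torch.Size([1, 128, 90, 160]) -- encoder feature maps
print(styled.shape)     # torch.Size([1, 3, 360, 640])  -- stylised frame
```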
/requirements.txt:
--------------------------------------------------------------------------------
numpy
Pillow
scikit-image
opencv-python
torch
torchvision
--------------------------------------------------------------------------------
/styles/autoportrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/autoportrait.jpg
--------------------------------------------------------------------------------
/styles/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/candy.jpg
--------------------------------------------------------------------------------
/styles/composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/composition.jpg
--------------------------------------------------------------------------------
/styles/edtaonisl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/edtaonisl.jpg
--------------------------------------------------------------------------------
/styles/init:
--------------------------------------------------------------------------------
.
--------------------------------------------------------------------------------
/styles/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/udnie.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from totaldata import *
from skimage import io, transform

data = ConsolidatedDataset(MPI_path="../MPI_Data/", FC_path="../FlyingChairs2/")
print(len(data))
img1, img2, mask, flow = data[230]
print(img1.size(), img2.size(), mask.size(), flow.size())
io.imsave("img1.png", img1.squeeze().permute(1, 2, 0).cpu().numpy())
io.imsave("img2.png", img2.squeeze().permute(1, 2, 0).cpu().numpy())
--------------------------------------------------------------------------------
/testwarp.py:
--------------------------------------------------------------------------------
from utilities import *
from dataset import MPIDataset
from skimage import io, transform
from torch.utils.data import DataLoader
import torch

path = "../MPI_Data/"
dataloader = DataLoader(MPIDataset(path), batch_size=1)

for itr, (img1, img2, mask, flow) in enumerate(dataloader):
    if(itr == 1):
        break
    warped = warp(img1, flow)
    print(img1.squeeze().permute(1, 2, 0).size())
    io.imsave("warped.png", warped.squeeze().permute(1, 2, 0).cpu().numpy())
    io.imsave("img1.png", img1.squeeze().permute(1, 2, 0).cpu().numpy())
    io.imsave("img2.png", img2.squeeze().permute(1, 2, 0).cpu().numpy())
--------------------------------------------------------------------------------
/totaldata.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import numpy as np
import PIL
from PIL import Image
from flowlib import read
from skimage import io, transform

device = 'cuda'

def toString4(num):
    # zero-pad to 4 digits (MPI Sintel frame numbering)
    string = str(num)
    while(len(string) < 4):
        string = "0"+string
    return string

def toString7(num):
    # zero-pad to 7 digits (FlyingChairs2 file numbering)
    string = str(num)
    while(len(string) < 7):
        string = "0"+string
    return string

class MPIDataset(Dataset):

    def __init__(self, path, transform=None):
        """
        looking at the "clean" subfolder for images, might change to "final" later
        path -> location of the MPI folder that contains the "training" folder
        """
        self.path = path+"training/"
        self.transform = transform
        self.dirlist = os.listdir(self.path+"clean/")
        self.dirlist.sort()
        self.numlist = []
        for folder in self.dirlist:
            self.numlist.append(len(os.listdir(self.path+"clean/"+folder+"/")))
        self.length = sum(self.numlist)-len(self.numlist)

    def __len__(self):

        return self.length

    def __getitem__(self, idx):
        """
        idx must be between 0 to len-1
        assuming flow[0] contains flow in x direction and flow[1] contains flow in y
        """
        for i in range(0, len(self.numlist)):
            folder = self.dirlist[i]
            path = self.path+"clean/"+folder+"/"
            occpath = self.path+"occlusions/"+folder+"/"
            flowpath = self.path+"flow/"+folder+"/"
            if(idx < (self.numlist[i]-1)):
                num1 = toString4(idx+1)
                num2 = toString4(idx+2)
                img1 = io.imread(path+"frame_"+num1+".png")
                img2 = io.imread(path+"frame_"+num2+".png")
                mask = io.imread(occpath+"frame_"+num1+".png")
                img1 = torch.from_numpy(transform.resize(img1, (360, 640))).to(device).permute(2, 0, 1).float()
                img2 = torch.from_numpy(transform.resize(img2, (360, 640))).to(device).permute(2, 0, 1).float()
                mask = torch.from_numpy(transform.resize(mask, (360, 640))).to(device).float()
                flow = read(flowpath+"frame_"+num1+".flo")
                # bilinear interpolation is default
                originalflow = torch.from_numpy(flow)  # shape (H, W, 2)
                flow = torch.from_numpy(transform.resize(flow, (360, 640))).to(device).permute(2, 0, 1).float()
                # rescale displacements: x by the width ratio, y by the height ratio
                flow[0, :, :] *= float(flow.shape[2])/originalflow.shape[1]
                flow[1, :, :] *= float(flow.shape[1])/originalflow.shape[0]
                break

        idx -= self.numlist[i]-1

        if self.transform:
            # complete later
            pass
        # IMG2 is at time t if IMG1 is at time t-1
        return (img1, img2, mask, flow)

class FlyingChairsDataset(Dataset):

    def __init__(self, path, transform=None):
        """
        path -> location of the folder that contains the FlyingChairs2 "train" folder
        """
        self.path = path+"train/"
        self.transform = transform
        self.length = 22230  # 14 files correspond to each image pair

    def __len__(self):

        return self.length

    def __getitem__(self, idx):
        """
        idx must be between 0 to len-1
        assuming flow[0] contains flow in x direction and flow[1] contains flow in y
        """
        num = toString7(idx)
        img1 = io.imread(self.path+num+"-img_0.png")
        img2 = io.imread(self.path+num+"-img_1.png")
        mask = io.imread(self.path+num+"-mb_01.png")
        img1 = torch.from_numpy(transform.resize(img1, (360, 640))).to(device).permute(2, 0, 1).float()
        img2 = torch.from_numpy(transform.resize(img2, (360, 640))).to(device).permute(2, 0, 1).float()
        mask = torch.from_numpy(transform.resize(mask, (360, 640))).to(device).float()
        flow = read(self.path+num+"-flow_01.flo")
        # bilinear interpolation is default
        originalflow = torch.from_numpy(flow)  # shape (H, W, 2)
        flow = torch.from_numpy(transform.resize(flow, (360, 640))).to(device).permute(2, 0, 1).float()
        # rescale displacements: x by the width ratio, y by the height ratio
        flow[0, :, :] *= float(flow.shape[2])/originalflow.shape[1]
        flow[1, :, :] *= float(flow.shape[1])/originalflow.shape[0]

        if self.transform:
            # complete later
            pass
        # IMG2 is at time t if IMG1 is at time t-1
        return (img1, img2, mask, flow)

class ConsolidatedDataset(Dataset):

    def __init__(self, MPI_path, FC_path, transform=None):
        self.mpi = MPIDataset(MPI_path)
        self.fc = FlyingChairsDataset(FC_path)

    def __len__(self):

        return len(self.mpi)+len(self.fc)

    def __getitem__(self, idx):
        """
        idx must be between 0 to len-1
        """
        # MPI Sintel indices come first, then FlyingChairs
        if(idx < len(self.mpi)):
            return self.mpi[idx]
        else:
            return self.fc[idx-len(self.mpi)]

>0] = 1

    return output*mask

--------------------------------------------------------------------------------
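Only the last lines of `utilities.py` survive above (the dump cuts off before `train.py` and most of `utilities.py`). For reference, a minimal sketch of a flow-warping function with that exact tail, based on the `grid_sample` warping in the PWC-Net code cited as reference [4]; this is a reconstruction under those assumptions, not the repo's original code:

```python
import torch
import torch.nn.functional as F

def warp(x, flo):
    """Backward-warp an image batch x (B,C,H,W) by optical flow flo (B,2,H,W)."""
    B, C, H, W = x.size()
    # build a pixel-coordinate mesh grid
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    grid = torch.stack((xx, yy)).unsqueeze(0).repeat(B, 1, 1, 1).float().to(x.device)
    vgrid = grid + flo
    # normalise coordinates to [-1, 1] as expected by grid_sample
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
    output = F.grid_sample(x, vgrid.permute(0, 2, 3, 1))
    # zero out positions that sampled from outside the frame
    mask = F.grid_sample(torch.ones_like(x), vgrid.permute(0, 2, 3, 1))
    mask[mask < 0.9999] = 0
    mask[mask > 0] = 1

    return output*mask
```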