├── README.md
├── _config.yml
├── dataset.py
├── finalloss.png
├── flowlib.py
├── infer.py
├── models
│   └── style
│       ├── autoportrait.jpg
│       ├── candy.jpg
│       ├── composition.jpg
│       ├── edtaonisl.jpg
│       └── udnie.jpg
├── network.png
├── network.py
├── requirements.txt
├── styles
│   ├── autoportrait.jpg
│   ├── candy.jpg
│   ├── composition.jpg
│   ├── edtaonisl.jpg
│   ├── init
│   └── udnie.jpg
├── test.py
├── testwarp.py
├── totaldata.py
├── train.py
└── utilities.py
/README.md:
--------------------------------------------------------------------------------
1 | # ReCoNet-PyTorch
2 | This repository contains a PyTorch implementation of the [ReCoNet paper](https://arxiv.org/pdf/1807.01197.pdf). It was developed as a course project for CS763 at IIT Bombay. The model has been trained and the results uploaded in [this repo](https://github.com/liulai/reconet-torch) by [@liulai](https://github.com/liulai).
3 |
4 | ### Contributors:
5 | - [Mohd Safwan](https://github.com/safwankdb)
6 | - [Kushagra Juneja](https://github.com/kushagra1729)
7 | - [Saksham Khandelwal](https://github.com/skq024)
8 |
9 | ### Abstract
10 | We aim to build a generalisable neural style transfer network for videos that is temporally consistent and performs efficient real-time style transfer on any modern GPU. Previous techniques have not achieved efficient real-time style transfer, lacking either temporal consistency, good perceptual style quality, or fast processing. We use ReCoNet, which mitigates all of these problems.
11 |
12 | ### ReCoNet
13 | ReCoNet is a feed-forward neural network that stylises videos frame by frame through an encoder and a decoder, with a VGG loss network to capture the perceptual style of the transfer target. The temporal loss is guided by occlusion masks and optical flow. Only the encoder and decoder run during inference, which makes ReCoNet very efficient: it runs above real time (~200 fps) on modern GPUs.
14 | The network is illustrated in the figure below, followed by a minimal inference sketch:
15 |
16 | 
17 |
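18 | A minimal single-frame inference sketch with this repo's `ReCoNet` (the checkpoint path is hypothetical, and the frame preprocessing done in `infer.py` is simplified away here):
19 | 
20 | ```python
21 | import torch
22 | from network import ReCoNet
23 | 
24 | model = ReCoNet().to('cuda')
25 | model.load_state_dict(torch.load('reconet.pth'))  # hypothetical checkpoint path
26 | model.eval()
27 | 
28 | frame = torch.rand(1, 3, 360, 640, device='cuda')  # stand-in for a preprocessed video frame
29 | with torch.no_grad():
30 |     features, styled = model(frame)  # encoder features and the stylised frame
31 | ```
32 | 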
18 | ### Dataset
19 | For training we use the MPI Sintel dataset (around 1,000 frames) and the FlyingChairs dataset (about 22,000 frames); for testing we use a clip from an animated movie. A minimal data-loading sketch follows the style list below.
20 | We have tried style transfer over the following styles:
21 |
27 |
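28 | Both datasets are wrapped by `ConsolidatedDataset` in `totaldata.py`, which concatenates the MPI Sintel and FlyingChairs samples. A minimal loading sketch (the directory paths match those used in `test.py`):
29 | 
30 | ```python
31 | from torch.utils.data import DataLoader
32 | from totaldata import ConsolidatedDataset
33 | 
34 | # each sample is (img1, img2, occlusion_mask, flow) for a consecutive frame pair
35 | data = ConsolidatedDataset(MPI_path="../MPI_Data/", FC_path="../FlyingChairs2/")
36 | loader = DataLoader(data, batch_size=2)
37 | img1, img2, mask, flow = next(iter(loader))
38 | ```
39 | 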
28 | ### Loss functions and optimisation
29 | The network uses a multi-level temporal loss that enforces temporal coherence both on high-level feature maps and on the final stylised output. The feature-level term does not account for luminance, whereas the output-level term includes a luminance term. The perceptual losses are computed with the VGG-16 network and comprise the content loss, the style loss, and a total variation regularizer. They are calculated on each frame separately and then summed with the temporal losses for that frame. A sketch of the output-level temporal term follows the loss figure below.
30 | The final loss function is:
31 | 
32 | ![Final Loss](finalloss.png)
33 | 
34 |
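35 | A sketch of the output-level temporal term, assuming the `warp` helper from `utilities.py` (used the same way in `testwarp.py`); the function name here is illustrative, and the luminance, feature-level, content, style, and total-variation terms are omitted:
36 | 
37 | ```python
38 | import torch
39 | from utilities import warp
40 | 
41 | def output_temporal_loss(styled_prev, styled_cur, flow, occ_mask):
42 |     # warp the stylised frame at t-1 forward to t with the ground-truth flow,
43 |     # then penalise differences in non-occluded regions
44 |     warped = warp(styled_prev, flow)
45 |     return torch.mean(occ_mask.unsqueeze(1) * (styled_cur - warped) ** 2)
46 | ```
47 | 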
35 | ### Requirements
36 | ```bash
37 | numpy
38 | Pillow
39 | scikit-image
40 | opencv-python
41 | torch
42 | torchvision
43 | ```
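44 | Install them with:
45 | 
46 | ```bash
47 | pip install -r requirements.txt
48 | ```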
44 | ### References
45 | [1] ReCoNet: Real-time Coherent Video Style Transfer Network : https://arxiv.org/pdf/1807.01197.pdf
46 | [2] Gram matrix : https://pytorch.org/tutorials/advanced/neural_style_tutorial.html
47 | [3] ReCoNet Model : https://github.com/irsisyphus/reconet/blob/master/network.py
48 | [4] Optical flow warping : https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
49 | [5] Optical flow I/O : https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html
50 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
2 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from torchvision import transforms, utils
5 | import numpy as np
6 | import PIL
7 | from PIL import Image
8 | from flowlib import read
9 | from skimage import io, transform
10 |
11 | device='cuda'
12 |
13 | def toString(num):
14 |     # zero-pad the frame number to 4 digits
15 |     return str(num).zfill(4)
18 |
19 |
20 | class MPIDataset(Dataset):
21 |
22 | def __init__(self, path, transform=None):
23 | """
24 | looking at the "clean" subfolder for images, might change to "final" later
25 | root_dir -> path to the location where the "training" folder is kept inside the MPI folder
26 | """
27 | self.path = path+"training/"
28 | self.transform = transform
29 | self.dirlist = os.listdir(self.path+"clean/")
30 | self.dirlist.sort()
31 | # print(self.dirlist)
32 | self.numlist = []
33 | for folder in self.dirlist:
34 | self.numlist.append(len(os.listdir(self.path+"clean/"+folder+"/")))
35 |
36 | def __len__(self):
37 |
38 | return sum(self.numlist)-len(self.numlist)
39 |
40 | def __getitem__(self, idx):
41 | """
42 |         idx must be between 0 and len-1
43 | assuming flow[0] contains flow in x direction and flow[1] contains flow in y
44 | """
45 | for i in range(0, len(self.numlist)):
46 | folder = self.dirlist[i]
47 | path = self.path+"clean/"+folder+"/"
48 | occpath = self.path+"occlusions/"+folder+"/"
49 | flowpath = self.path+"flow/"+folder+"/"
50 | if(idx < (self.numlist[i]-1)):
51 | num1 = toString(idx+1)
52 | num2 = toString(idx+2)
53 | img1 = io.imread(path+"frame_"+num1+".png")
54 | img2 = io.imread(path+"frame_"+num2+".png")
55 | mask = io.imread(occpath+"frame_"+num1+".png")
56 | img1 = torch.from_numpy(transform.resize(img1, (360, 640))).to(device).permute(2, 0, 1).float()
57 | img2 = torch.from_numpy(transform.resize(img2, (360, 640))).to(device).permute(2, 0, 1).float()
58 | mask = torch.from_numpy(transform.resize(mask, (360, 640))).to(device).float()
59 | flow = read(flowpath+"frame_"+num1+".flo")
60 | # bilinear interpolation is default
61 | originalflow=torch.from_numpy(flow)
62 | flow = torch.from_numpy(transform.resize(flow, (360, 640))).to(device).permute(2,0,1).float()
63 |                 flow[0, :, :] *= float(flow.shape[2])/originalflow.shape[1]  # x-flow: new_width/old_width
64 |                 flow[1, :, :] *= float(flow.shape[1])/originalflow.shape[0]  # y-flow: new_height/old_height
65 | # print(flow.shape) #y,x,2
66 | # print(img1.shape)
67 | break
68 |
69 | idx -= self.numlist[i]-1
70 |
71 |         if self.transform:
72 |             # complete later
73 |             pass
74 |         # img2 is the frame at time t; img1 is the frame at time t-1
75 |         return (img1, img2, mask, flow)
76 |
--------------------------------------------------------------------------------
/finalloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/finalloss.png
--------------------------------------------------------------------------------
/flowlib.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3.4
2 | #Source: https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html
3 |
4 | import os
5 | import re
6 | import numpy as np
7 | import uuid
10 | from PIL import Image
11 | import sys
12 |
13 | def read(file):
14 | if file.endswith('.float3'): return readFloat(file)
15 | elif file.endswith('.flo'): return readFlow(file)
16 | elif file.endswith('.ppm'): return readImage(file)
17 | elif file.endswith('.pgm'): return readImage(file)
18 | elif file.endswith('.png'): return readImage(file)
19 | elif file.endswith('.jpg'): return readImage(file)
20 | elif file.endswith('.pfm'): return readPFM(file)[0]
21 | else: raise Exception('don\'t know how to read %s' % file)
22 |
23 | def write(file, data):
24 | if file.endswith('.float3'): return writeFloat(file, data)
25 | elif file.endswith('.flo'): return writeFlow(file, data)
26 | elif file.endswith('.ppm'): return writeImage(file, data)
27 | elif file.endswith('.pgm'): return writeImage(file, data)
28 | elif file.endswith('.png'): return writeImage(file, data)
29 | elif file.endswith('.jpg'): return writeImage(file, data)
30 | elif file.endswith('.pfm'): return writePFM(file, data)
31 | else: raise Exception('don\'t know how to write %s' % file)
32 |
33 | def readFlow(name):
34 | if name.endswith('.pfm') or name.endswith('.PFM'):
35 | return readPFM(name)[0][:,:,0:2]
36 |
37 |     with open(name, 'rb') as f:
38 |         header = f.read(4)
39 |         if header.decode("utf-8") != 'PIEH':
40 |             raise Exception('Flow file header does not contain PIEH')
41 | 
42 |         width = np.fromfile(f, np.int32, 1).squeeze()
43 |         height = np.fromfile(f, np.int32, 1).squeeze()
44 | 
45 |         flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
46 | 
47 |     return flow.astype(np.float32)
49 |
50 | def writeFlow(name, flow):
51 |     with open(name, 'wb') as f:
52 |         f.write('PIEH'.encode('utf-8'))
53 |         np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
54 |         flow.astype(np.float32).tofile(f)
56 |
57 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from torchvision import transforms, utils
5 | from skimage import io, transform
6 | from PIL import Image
7 | import cv2
8 | import numpy as np
9 | from network import ReCoNet
10 | from utilities import *
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--source', required=True, help='Video file to process')
14 | # parser.add_argument('--target', required=True, help='Output file')
15 | parser.add_argument('--model', required=True, help='Model state_dict file')
16 | args = parser.parse_args()
17 | device = 'cuda'
18 | video_capture = cv2.VideoCapture(args.source)
19 | model = ReCoNet().to(device)
20 | model.load_state_dict(torch.load(args.model))
21 |
22 | model.eval()
23 | 
24 | while True:
25 |     ret, frame = video_capture.read()
26 |     if not ret:  # stop at the end of the video
27 |         break
28 |     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
29 | frame = torch.from_numpy(transform.resize(frame, (360, 640))).to(device).permute(2, 0, 1).float()
30 |     with torch.no_grad():
31 |         features, styled_frame = model(frame.unsqueeze(0))
32 |     styled_frame = styled_frame[0].clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
33 |     styled_frame = np.ascontiguousarray(styled_frame[:, :, ::-1])  # RGB -> BGR for cv2.imshow
37 | cv2.imshow('frame', styled_frame)
38 | if cv2.waitKey(1) & 0xFF == ord('q'):
39 | break
40 | 
41 | video_capture.release()
42 | cv2.destroyAllWindows()
--------------------------------------------------------------------------------
/models/style/autoportrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/autoportrait.jpg
--------------------------------------------------------------------------------
/models/style/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/candy.jpg
--------------------------------------------------------------------------------
/models/style/composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/composition.jpg
--------------------------------------------------------------------------------
/models/style/edtaonisl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/edtaonisl.jpg
--------------------------------------------------------------------------------
/models/style/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/models/style/udnie.jpg
--------------------------------------------------------------------------------
/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/network.png
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torchvision.models import vgg16
5 | from collections import namedtuple
6 |
7 | # From https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/vgg.py
8 | class Vgg16(torch.nn.Module):
9 | def __init__(self, device='cpu'):
10 | super(Vgg16, self).__init__()
11 | vgg_pretrained_features = vgg16(pretrained=True).features
12 | self.slice1 = torch.nn.Sequential()
13 | self.slice2 = torch.nn.Sequential()
14 | self.slice3 = torch.nn.Sequential()
15 | self.slice4 = torch.nn.Sequential()
16 | for x in range(4):
17 | self.slice1.add_module(str(x), vgg_pretrained_features[x].to(device))
18 | for x in range(4, 9):
19 | self.slice2.add_module(str(x), vgg_pretrained_features[x].to(device))
20 | for x in range(9, 16):
21 | self.slice3.add_module(str(x), vgg_pretrained_features[x].to(device))
22 | for x in range(16, 23):
23 | self.slice4.add_module(str(x), vgg_pretrained_features[x].to(device))
24 |
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def forward(self, X):
29 | h = self.slice1(X)
30 | h_relu1_2 = h
31 | h = self.slice2(h)
32 | h_relu2_2 = h
33 | h = self.slice3(h)
34 | h_relu3_3 = h
35 | h = self.slice4(h)
36 | h_relu4_3 = h
37 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
38 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
39 | return out
40 |
41 |
42 | # Rest of the file based on https://github.com/irsisyphus/reconet
43 |
44 | class SelectiveLoadModule(torch.nn.Module):
45 | """Only load layers in trained models with the same name."""
46 | def __init__(self):
47 | super(SelectiveLoadModule, self).__init__()
48 |
49 | def forward(self, x):
50 | return x
51 |
52 | def load_state_dict(self, state_dict):
53 | """Override the function to ignore redundant weights."""
54 | own_state = self.state_dict()
55 | for name, param in state_dict.items():
56 | if name in own_state:
57 | own_state[name].copy_(param)
58 |
59 |
60 | class ConvLayer(torch.nn.Module):
61 | """Reflection padded convolution layer."""
62 | def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
63 | super(ConvLayer, self).__init__()
64 | reflection_padding = int(np.floor(kernel_size / 2))
65 | self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
66 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=bias)
67 |
68 | def forward(self, x):
69 | out = self.reflection_pad(x)
70 | out = self.conv2d(out)
71 | return out
72 |
73 |
74 | class ConvTanh(ConvLayer):
75 | def __init__(self, in_channels, out_channels, kernel_size, stride):
76 | super(ConvTanh, self).__init__(in_channels, out_channels, kernel_size, stride)
77 | self.tanh = torch.nn.Tanh()
78 |
79 | def forward(self, x):
80 | out = super(ConvTanh, self).forward(x)
81 |         return self.tanh(out/255) * 150 + 255/2  # scale tanh output into roughly the [0, 255] pixel range
82 |
83 |
84 | class ConvInstRelu(ConvLayer):
85 | def __init__(self, in_channels, out_channels, kernel_size, stride):
86 | super(ConvInstRelu, self).__init__(in_channels, out_channels, kernel_size, stride)
87 | self.instance = torch.nn.InstanceNorm2d(out_channels, affine=True)
88 | self.relu = torch.nn.ReLU()
89 |
90 | def forward(self, x):
91 | out = super(ConvInstRelu, self).forward(x)
92 | out = self.instance(out)
93 | out = self.relu(out)
94 | return out
95 |
96 |
97 | class UpsampleConvLayer(torch.nn.Module):
98 | """Upsamples the input and then does a convolution.
99 | This method gives better results compared to ConvTranspose2d.
100 | ref: http://distill.pub/2016/deconv-checkerboard/
101 | """
102 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
103 | super(UpsampleConvLayer, self).__init__()
104 | self.upsample = upsample
105 | reflection_padding = int(np.floor(kernel_size / 2))
106 | self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
107 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
108 |
109 | def forward(self, x):
110 | x_in = x
111 | if self.upsample:
112 | x_in = torch.nn.functional.interpolate(x_in, scale_factor=self.upsample)
113 | out = self.reflection_pad(x_in)
114 | out = self.conv2d(out)
115 | return out
116 |
117 |
118 | class UpsampleConvInstRelu(UpsampleConvLayer):
119 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
120 | super(UpsampleConvInstRelu, self).__init__(in_channels, out_channels, kernel_size, stride, upsample)
121 | self.instance = torch.nn.InstanceNorm2d(out_channels, affine=True)
122 | self.relu = torch.nn.ReLU()
123 |
124 | def forward(self, x):
125 | out = super(UpsampleConvInstRelu, self).forward(x)
126 | out = self.instance(out)
127 | out = self.relu(out)
128 | return out
129 |
130 |
131 | class ResidualBlock(torch.nn.Module):
132 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
133 | super(ResidualBlock, self).__init__()
134 | self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride)
135 | self.in1 = torch.nn.InstanceNorm2d(out_channels, affine=True)
136 | self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, stride)
137 | self.in2 = torch.nn.InstanceNorm2d(out_channels, affine=True)
138 | self.relu = torch.nn.ReLU()
139 |
140 | def forward(self, x):
141 | residual = x
142 | out = self.relu(self.in1(self.conv1(x)))
143 | out = self.in2(self.conv2(out))
144 | out = out + residual
145 | return out
146 |
147 |
148 | class ReCoNet(SelectiveLoadModule):
149 | def __init__(self):
150 | super(ReCoNet, self).__init__()
151 |
152 | self.conv1 = ConvInstRelu(3, 32, kernel_size=9, stride=1)
153 | self.conv2 = ConvInstRelu(32, 64, kernel_size=3, stride=2)
154 | self.conv3 = ConvInstRelu(64, 128, kernel_size=3, stride=2)
155 |
156 | self.res1 = ResidualBlock(128, 128)
157 | self.res2 = ResidualBlock(128, 128)
158 | self.res3 = ResidualBlock(128, 128)
159 | self.res4 = ResidualBlock(128, 128)
160 | self.res5 = ResidualBlock(128, 128)
161 |
162 | self.deconv1 = UpsampleConvInstRelu(128, 64, kernel_size=3, stride=1, upsample=2)
163 | self.deconv2 = UpsampleConvInstRelu(64, 32, kernel_size=3, stride=1, upsample=2)
164 | self.deconv3 = ConvTanh(32, 3, kernel_size=9, stride=1)
165 |
166 | def forward(self, x):
167 | x = self.conv1(x)
168 | x = self.conv2(x)
169 | x = self.conv3(x)
170 |
171 | x = self.res1(x)
172 | x = self.res2(x)
173 | x = self.res3(x)
174 | x = self.res4(x)
175 | x = self.res5(x)
176 |
177 | features = x
178 |
179 | x = self.deconv1(x)
180 | x = self.deconv2(x)
181 | x = self.deconv3(x)
182 |
183 | return (features, x)
184 |
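185 | # Shape check (illustrative): for a 1x3x360x640 input, `features` is
186 | # 1x128x90x160 after the two stride-2 convolutions, and the stylised
187 | # output is restored to 1x3x360x640 by the two 2x upsampling layers.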
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | scikit-image
4 | opencv-python
5 | torch
6 | torchvision
7 |
--------------------------------------------------------------------------------
/styles/autoportrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/autoportrait.jpg
--------------------------------------------------------------------------------
/styles/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/candy.jpg
--------------------------------------------------------------------------------
/styles/composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/composition.jpg
--------------------------------------------------------------------------------
/styles/edtaonisl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/edtaonisl.jpg
--------------------------------------------------------------------------------
/styles/init:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/styles/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/safwankdb/ReCoNet-PyTorch/8d6d98a32968d87c5ae26aed0aae0adb3dfa94f6/styles/udnie.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | import numpy as np
5 | from totaldata import *
6 | from skimage import io, transform
7 |
8 | data = ConsolidatedDataset(MPI_path="../MPI_Data/", FC_path="../FlyingChairs2/")
9 | print(len(data))
10 | img1, img2, mask, flow = data[230]
11 | print(img1.size(), img2.size(), mask.size(), flow.size())
12 | io.imsave("img1.png", img1.squeeze().permute(1,2,0).cpu().numpy())
13 | io.imsave("img2.png", img2.squeeze().permute(1,2,0).cpu().numpy())
14 | # io.imsave("mask.png", mask.squeeze().permute(1,2,0).cpu().numpy())
15 |
--------------------------------------------------------------------------------
/testwarp.py:
--------------------------------------------------------------------------------
1 | from utilities import *
2 | from dataset import MPIDataset
3 | from skimage import io, transform
4 | from torch.utils.data import DataLoader
5 | import torch
6 |
7 | path = "../MPI_Data/"
8 | dataloader = DataLoader(MPIDataset(path), batch_size=1)
9 |
10 | for itr, (img1, img2, mask, flow) in enumerate(dataloader):
11 | if(itr==1):
12 | break
13 | warped=warp(img1,flow)
14 | print(img1.squeeze().permute(1,2,0).size())
15 | io.imsave("warped.png", warped.squeeze().permute(1,2,0).cpu().numpy())
16 | io.imsave("img1.png", img1.squeeze().permute(1,2,0).cpu().numpy())
17 | io.imsave("img2.png", img2.squeeze().permute(1,2,0).cpu().numpy())
18 |
--------------------------------------------------------------------------------
/totaldata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from torchvision import transforms, utils
5 | import numpy as np
6 | import PIL
7 | from PIL import Image
8 | from flowlib import read
9 | from skimage import io, transform
10 |
11 | device='cuda'
12 |
13 | def toString4(num):
14 |     # zero-pad to 4 digits (MPI Sintel frame naming)
15 |     return str(num).zfill(4)
16 | 
17 | def toString7(num):
18 |     # zero-pad to 7 digits (FlyingChairs file naming)
19 |     return str(num).zfill(7)
24 |
25 | class MPIDataset(Dataset):
26 |
27 | def __init__(self, path, transform=None):
28 | """
29 | looking at the "clean" subfolder for images, might change to "final" later
30 | root_dir -> path to the location where the "training" folder is kept inside the MPI folder
31 | """
32 | self.path = path+"training/"
33 | self.transform = transform
34 | self.dirlist = os.listdir(self.path+"clean/")
35 | self.dirlist.sort()
36 | # print(self.dirlist)
37 | self.numlist = []
38 | for folder in self.dirlist:
39 | self.numlist.append(len(os.listdir(self.path+"clean/"+folder+"/")))
40 | self.length=sum(self.numlist)-len(self.numlist)
41 |
42 | def __len__(self):
43 |
44 | return self.length
45 |
46 | def __getitem__(self, idx):
47 | """
48 |         idx must be between 0 and len-1
49 | assuming flow[0] contains flow in x direction and flow[1] contains flow in y
50 | """
51 | for i in range(0, len(self.numlist)):
52 | folder = self.dirlist[i]
53 | path = self.path+"clean/"+folder+"/"
54 | occpath = self.path+"occlusions/"+folder+"/"
55 | flowpath = self.path+"flow/"+folder+"/"
56 | if(idx < (self.numlist[i]-1)):
57 | num1 = toString4(idx+1)
58 | num2 = toString4(idx+2)
59 | img1 = io.imread(path+"frame_"+num1+".png")
60 | img2 = io.imread(path+"frame_"+num2+".png")
61 | mask = io.imread(occpath+"frame_"+num1+".png")
62 | img1 = torch.from_numpy(transform.resize(img1, (360, 640))).to(device).permute(2, 0, 1).float()
63 | img2 = torch.from_numpy(transform.resize(img2, (360, 640))).to(device).permute(2, 0, 1).float()
64 | mask = torch.from_numpy(transform.resize(mask, (360, 640))).to(device).float()
65 | flow = read(flowpath+"frame_"+num1+".flo")
66 | # bilinear interpolation is default
67 | originalflow=torch.from_numpy(flow)
68 | flow = torch.from_numpy(transform.resize(flow, (360, 640))).to(device).permute(2,0,1).float()
69 |                 flow[0, :, :] *= float(flow.shape[2])/originalflow.shape[1]  # x-flow: new_width/old_width
70 |                 flow[1, :, :] *= float(flow.shape[1])/originalflow.shape[0]  # y-flow: new_height/old_height
71 | # print(flow.shape) #y,x,2
72 | # print(img1.shape)
73 | break
74 |
75 | idx -= self.numlist[i]-1
76 |
77 |         if self.transform:
78 |             # complete later
79 |             pass
80 |         # img2 is the frame at time t; img1 is the frame at time t-1
81 |         return (img1, img2, mask, flow)
82 |
83 | class FlyingChairsDataset(Dataset):
84 |
85 | def __init__(self, path, transform=None):
86 | """
87 |         path -> location of the FlyingChairs folder that contains the "train" subfolder
89 | """
90 | self.path = path+"train/"
91 | self.transform = transform
92 | self.length=22230 #14 files corresponding to each image pair
93 |
94 | def __len__(self):
95 |
96 | return self.length
97 |
98 | def __getitem__(self, idx):
99 | """
100 |         idx must be between 0 and len-1
101 | assuming flow[0] contains flow in x direction and flow[1] contains flow in y
102 | """
103 | num = toString7(idx)
104 | img1 = io.imread(self.path+num+"-img_0.png")
105 | img2 = io.imread(self.path+num+"-img_1.png")
106 | mask = io.imread(self.path+num+"-mb_01.png")
107 | img1 = torch.from_numpy(transform.resize(img1, (360, 640))).to(device).permute(2, 0, 1).float()
108 | img2 = torch.from_numpy(transform.resize(img2, (360, 640))).to(device).permute(2, 0, 1).float()
109 | mask = torch.from_numpy(transform.resize(mask, (360, 640))).to(device).float()
110 | flow = read(self.path+num+"-flow_01.flo")
111 | # bilinear interpolation is default
112 | originalflow=torch.from_numpy(flow)
113 | flow = torch.from_numpy(transform.resize(flow, (360, 640))).to(device).permute(2,0,1).float()
114 |         flow[0, :, :] *= float(flow.shape[2])/originalflow.shape[1]  # x-flow: new_width/old_width
115 |         flow[1, :, :] *= float(flow.shape[1])/originalflow.shape[0]  # y-flow: new_height/old_height
116 |
117 |         if self.transform:
118 |             # complete later
119 |             pass
120 |         # img2 is the frame at time t; img1 is the frame at time t-1
121 |         return (img1, img2, mask, flow)
122 |
123 | class ConsolidatedDataset(Dataset):
124 |
125 | def __init__(self, MPI_path, FC_path, transform=None):
126 | self.mpi=MPIDataset(MPI_path)
127 | self.fc=FlyingChairsDataset(FC_path)
128 |
129 | def __len__(self):
130 |
131 | return len(self.mpi)+len(self.fc)
132 |
133 | def __getitem__(self, idx):
134 | """
135 |         idx must be between 0 and len-1
136 |         """
137 |         if(idx < len(self.mpi)):
138 |             return self.mpi[idx]
139 |         return self.fc[idx-len(self.mpi)]
140 | 
--------------------------------------------------------------------------------
/utilities.py:
--------------------------------------------------------------------------------
89 |     mask[mask > 0] = 1
90 | 
91 |     return output*mask
92 | 
--------------------------------------------------------------------------------