├── .gitignore
├── o1.png
├── o2.png
├── model2.pyc
├── dataloader.pyc
├── __pycache__
│   ├── model2.cpython-36.pyc
│   ├── dataloader.cpython-36.pyc
│   └── model_orig.cpython-36.pyc
├── generateVideoFromImages.py
├── data
│   ├── README.md
│   └── make_train_test_data.py
├── README.md
├── dataloader.py
├── eval.py
├── eval_videos.py
├── stackImageFrames.py
├── generate_depth_comparison.py
├── getdepthmap.py
├── train.py
├── model2.py
└── model_orig.py

/.gitignore:
--------------------------------------------------------------------------------
data/*/
logs/*/
outputs/*/
*.pth
results/*/
--------------------------------------------------------------------------------
/o1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/o1.png
--------------------------------------------------------------------------------
/o2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/o2.png
--------------------------------------------------------------------------------
/model2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/model2.pyc
--------------------------------------------------------------------------------
/dataloader.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/dataloader.pyc
--------------------------------------------------------------------------------
/__pycache__/model2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/__pycache__/model2.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/__pycache__/dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model_orig.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anishmadan23/deep3d-pytorch/HEAD/__pycache__/model_orig.cpython-36.pyc
--------------------------------------------------------------------------------
/generateVideoFromImages.py:
--------------------------------------------------------------------------------
import cv2
import os
import numpy as np

my_dir = '../india_gate_left_genR_dir_orig_size/'

# Sort frames numerically by the integer prefix of the filename
all_imgs = os.listdir(my_dir)
all_imgs.sort(key=lambda x: int(x[:x.find('.')]))

imgs_for_video = []
for img_name in all_imgs:
    img = cv2.imread(os.path.join(my_dir, img_name))
    height, width, layers = img.shape
    size = (width, height)
    imgs_for_video.append(img)


fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # FourCC codes are case-sensitive; be sure to use lower case
out = cv2.VideoWriter('india_gate_long_orig_3d_vid.avi', fourcc, 10.0, (width, height))
# out = cv2.VideoWriter('long_orig_3d_vid.mp4', cv2.VideoWriter_fourcc(*'DIVX'), 15, size)

for i in range(len(imgs_for_video)):
    print(i)
    out.write(imgs_for_video[i])
out.release()
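# Optional sanity check (a minimal sketch, assuming the output file written
# above): re-open the video with cv2.VideoCapture and count the decoded
# frames, which should match len(imgs_for_video).
cap = cv2.VideoCapture('india_gate_long_orig_3d_vid.avi')
n_frames = 0
while True:
    ok, _ = cap.read()
    if not ok:
        break
    n_frames += 1
cap.release()
print('decoded {} of {} frames'.format(n_frames, len(imgs_for_video)))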
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Handling KITTI RAW Data
The training/test data consists of stereo pairs from the [KITTI RAW Data](http://www.cvlibs.net/datasets/kitti/raw_data.php).

**Train+Val Data** comprises stereo pairs from the following files:
- 2011_09_26_drive_0011_sync.zip
- 2011_09_26_drive_0022_sync.zip
- 2011_09_26_drive_0059_sync.zip
- 2011_09_26_drive_0084_sync.zip
- 2011_09_26_drive_0093_sync.zip
- 2011_09_26_drive_0095_sync.zip
- 2011_09_26_drive_0096_sync.zip

**Test Data** comprises stereo pairs from the following files:
- 2011_09_26_drive_0019_sync.zip
- 2011_09_26_drive_0091_sync.zip

## Creating train, val and test data
- Place all these files in this (data/) folder.
  **NOTE**: Remove the outermost folder after extracting each of the zip files, i.e., if 2011_09_26_drive_0096_sync.zip is extracted, the outermost folder is 2011_09_26/. Remove this folder so that the outermost folder is now 2011_09_26_drive_0096_sync/.

- Run make_train_test_data.py to get the train, val and test data.
```
python3 make_train_test_data.py
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep3D-pytorch
## Team Members: Anish Madan, Apoorv Khattar, Yash Tomar

## About the project
Estimating a right view from a monocular image (to form a stereo pair) that respects the geometry of the scene is an important problem in computer vision. This repository aims to achieve this by implementing [Deep3D](https://arxiv.org/abs/1604.03650) ([see original repo](https://github.com/piiswrong/deep3d)) in PyTorch to generate the right view of an image. The generated pair of images can then be used to estimate depth, convert 2D video to 3D, etc.

## Dataset
We used the KITTI Stereo 2015 dataset. The dataset consists of 200 training scenes and 200 test scenes, with 4 color images per scene in lossless PNG format. This gives us 400 left-right image pairs for training.

## Model Weights
The pretrained model can be downloaded from [here](https://drive.google.com/drive/folders/1txjqUjCcEvEkVS8QvNn1icrP34eh-crJ?usp=sharing).

## Results
The following are some results of our approach (from left to right: input left image, ground-truth right image of the stereo pair, generated right image, and depth maps computed with OpenCV from the image pairs):

![s1](https://github.com/anishmadan23/deep3d-pytorch/blob/master/o1.png)

![s2](https://github.com/anishmadan23/deep3d-pytorch/blob/master/o2.png)

###### This work was done as part of our project for CSE344: Computer Vision course at IIIT Delhi.
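## Data Preparation
The folder flattening described in data/README.md can be scripted. A minimal sketch (assuming the zips listed there sit in `data/` and each contains a single outer date folder; file names here are examples):

```python
import os
import shutil
import zipfile

# Extract each KITTI RAW zip and strip the outer date folder (e.g. 2011_09_26/)
# so that the drive folder sits directly under data/.
for name in ['2011_09_26_drive_0011_sync.zip', '2011_09_26_drive_0019_sync.zip']:
    with zipfile.ZipFile(name) as zf:
        zf.extractall('.')
    drive = name[:-len('.zip')]       # e.g. 2011_09_26_drive_0011_sync
    date = drive[:len('2011_09_26')]  # outer date folder created by the zip
    shutil.move(os.path.join(date, drive), drive)
    shutil.rmtree(date)               # drop the now-empty date folder
```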
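## Quick Start (Inference)
A minimal inference sketch based on eval.py (`Deep3d` and `MyDataset` come from model_orig.py and dataloader.py; the weight path is a placeholder for the downloaded checkpoint):

```python
import torch
from model_orig import Deep3d
from dataloader import MyDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Deep3d(device=device).to(device)
model.load_state_dict(torch.load('99_20_view_syn_weights_l1with_scheduler.pth',
                                 map_location=device))
model.eval()

# MyDataset returns (left_orig, left_small, right) tensors in CHW layout
dataset = MyDataset('./data/test/')
left_orig, left_small, right = [t.unsqueeze(0).to(device).float() for t in dataset[0]]
with torch.no_grad():
    generated_right = model(left_orig, left_small)  # synthesized right view
```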
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import os
import torch
import torchvision

import numpy as np
import skimage.io as io
import torch.utils.data as data

from skimage.transform import resize

class MyDataset(data.Dataset):
    def __init__(self, root, in_transforms=None, orig_size=(384, 1280), small_size=(96, 320)):
        self.leftpath = os.path.join(root, 'left')
        self.leftimg = os.listdir(self.leftpath)

        self.rightpath = os.path.join(root, 'right')
        self.rightimg = os.listdir(self.rightpath)

        self.leftimg.sort()
        self.rightimg.sort()

        self.orig_size = orig_size
        self.small_size = small_size

    def __len__(self):
        return len(self.leftimg)

    def __getitem__(self, index):
        leftImage = io.imread(os.path.join(self.leftpath, self.leftimg[index]))
        # NOTE: skimage's resize already returns floats in [0, 1]; the extra
        # division by 255 scales the values further and is kept because the
        # released weights were trained with this scaling.
        leftImage_orig = resize(leftImage, self.orig_size) / 255.0

        leftImage_small = resize(leftImage, self.small_size) / 255.0

        rightImage_orig = io.imread(os.path.join(self.rightpath, self.rightimg[index]))
        rightImage_orig = resize(rightImage_orig, self.orig_size) / 255.0

        # HWC -> CHW for PyTorch
        left_orig = torch.from_numpy(leftImage_orig)
        left_orig = left_orig.permute([-1, 0, 1])

        left_small = torch.from_numpy(leftImage_small)
        left_small = left_small.permute([-1, 0, 1])

        right_orig = torch.from_numpy(rightImage_orig)
        right_orig = right_orig.permute([-1, 0, 1])

        return left_orig, left_small, right_orig


# data_obj = MyDataset('./data/train/')
# train_dataloader = data.DataLoader(data_obj, batch_size=4, shuffle=True)
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import sys
import math
import time
import datetime
import os
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F

from torchvision.utils import make_grid, save_image
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from tensorboardX import SummaryWriter

from model_orig import *
from dataloader import *

# Results go into a timestamped subfolder of ./results/
RES_DIR = './results/'
now = str(datetime.datetime.now())

if not os.path.exists(RES_DIR):
    os.makedirs(RES_DIR)
if not os.path.exists(RES_DIR + now):
    os.makedirs(RES_DIR + now)
RES_DIR = RES_DIR + now + '/'

device = torch.device('cuda')
print(device)
dataroot = './data/test/'
weight_file = './data/99_20_view_syn_weights_l1with_scheduler.pth'
batch = 1
img_size = (96, 320)

model = Deep3d(device=device).to(device)
model.load_state_dict(torch.load(weight_file))

test_dataset = MyDataset(dataroot, in_transforms=None)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, shuffle=False)

print(len(test_dataloader))

model.eval()
for i, data in enumerate(test_dataloader):
    with torch.no_grad():
        left_orig = data[0].to(device).float()
        left = data[1].to(device).float()
        right = data[2].to(device).float()
        output = model(left_orig, left)

        save_image(left_orig, RES_DIR + '{}_left.png'.format(i))
        save_image(right, RES_DIR + '{}_right.png'.format(i))
        save_image(output, RES_DIR + '{}_genR.png'.format(i))
        print(output.shape, left.shape, right.shape)
        # sys.exit()
--------------------------------------------------------------------------------
/eval_videos.py:
--------------------------------------------------------------------------------
import sys
import math
import time
import datetime
import os
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F

from torchvision.utils import make_grid, save_image
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from tensorboardX import SummaryWriter

from model2 import *
from dataloader_raw import *

# RES_DIR = './results/'
RES_DIR = '../2011_09_26_long/2011_09_26_drive_0014_sync/'
RES_DIR_LEFT = RES_DIR + 'left/'
RES_DIR_RIGHT = RES_DIR + 'right/'
RES_DIR_GENR = RES_DIR + 'genR/'

if not os.path.exists(RES_DIR_GENR):
    os.makedirs(RES_DIR_GENR)

if not os.path.exists(RES_DIR_LEFT):
    os.makedirs(RES_DIR_LEFT)

if not os.path.exists(RES_DIR_RIGHT):
    os.makedirs(RES_DIR_RIGHT)

device = torch.device('cuda')
print(device)
# dataroot = '/home/apoorv/Documents/Practice/CV/Project/data_scene_flow/training/'
dataroot = '../2011_09_26_long/2011_09_26_drive_0014_sync/'
# dataroot = '../india_gate_frames/'

weight_file = '../99_20_view_syn_weights_l1with_scheduler.pth'
batch = 1
img_size = (96, 320)

model = Deep3d(device=device).to(device)
model.load_state_dict(torch.load(weight_file))

test_dataset = MyDataset(dataroot, in_transforms=None, size=img_size)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, shuffle=False)

print(len(test_dataloader))

model.eval()
for i, data in enumerate(test_dataloader):
    with torch.no_grad():
        left = data[0].to(device).float()
        right = data[1].to(device).float()

        output = model(left)

        save_image(left, RES_DIR_LEFT + '{}.png'.format(i))
        save_image(right, RES_DIR_RIGHT + '{}.png'.format(i))
        save_image(output, RES_DIR_GENR + '{}.png'.format(i))

        print(output.shape, left.shape)

        # sys.exit()
--------------------------------------------------------------------------------
/stackImageFrames.py:
--------------------------------------------------------------------------------
import cv2
import os
import numpy as np


root_dir = '../2011_09_26_long/2011_09_26_drive_0014_sync/'
left_dir = root_dir + 'left/'
right_dir = root_dir + 'right/'
genR_dir = root_dir + 'genR/'


left_imgs = os.listdir(left_dir)
right_imgs = os.listdir(right_dir)
genR_imgs = os.listdir(genR_dir)

# Sort frames numerically by the integer prefix of the filename
left_imgs.sort(key=lambda x: int(x[:x.find('.')]))
right_imgs.sort(key=lambda x: int(x[:x.find('.')]))
genR_imgs.sort(key=lambda x: int(x[:x.find('.')]))
all_img_types = [left_imgs, right_imgs, genR_imgs]
all_img_dirs = [left_dir, right_dir, genR_dir]

# Replace each filename in the lists with the loaded image
for idx, img_type in enumerate(all_img_types):
    for k, img_name in enumerate(img_type):
        img = cv2.imread(os.path.join(all_img_dirs[idx], img_name))
        # img = cv2.resize(img, (img.shape[1]*2, img.shape[0]*2), interpolation=cv2.INTER_CUBIC)
        all_img_types[idx][k] = img


new_left_right_dir = '../final_left_right_dir_orig_size/'
if not os.path.exists(new_left_right_dir):
    os.makedirs(new_left_right_dir)

# Stack each left frame on top of the corresponding ground-truth right frame
for i in range(len(left_imgs)):
    new_img = np.zeros((left_imgs[i].shape[0] * 2, left_imgs[i].shape[1], left_imgs[i].shape[2]))
    new_img[:left_imgs[i].shape[0], :, :] = left_imgs[i]
    new_img[left_imgs[i].shape[0]:, :, :] = right_imgs[i]
    save_name = str(i) + '.png'
    cv2.imwrite(os.path.join(new_left_right_dir, save_name), new_img)

new_left_genR_dir = '../final_left_genR_dir_orig_size/'
if not os.path.exists(new_left_genR_dir):
    os.makedirs(new_left_genR_dir)

# Stack each left frame on top of the corresponding generated right frame
for i in range(len(left_imgs)):
    new_img = np.zeros((left_imgs[i].shape[0] * 2, left_imgs[i].shape[1], left_imgs[i].shape[2]))
    new_img[:left_imgs[i].shape[0], :, :] = left_imgs[i]
    new_img[left_imgs[i].shape[0]:, :, :] = genR_imgs[i]
    save_name = str(i) + '.png'
    cv2.imwrite(os.path.join(new_left_genR_dir, save_name), new_img)
--------------------------------------------------------------------------------
/generate_depth_comparison.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import sys
import os
import math

from sklearn.preprocessing import normalize
from cv2 import ximgproc, StereoSGBM_create

def generate_depth(imgL, imgR):

    window_size = 7

    # Semi-global block matching on the left view, with a right matcher and a
    # weighted-least-squares filter to smooth the disparity map
    lm = StereoSGBM_create(minDisparity=0, numDisparities=16, blockSize=5,
                           P1=8 * 3 * window_size ** 2,
                           P2=32 * 3 * window_size ** 2,
                           disp12MaxDiff=1)

    rm = ximgproc.createRightMatcher(lm)

    wlf = ximgproc.createDisparityWLSFilter(matcher_left=lm)
    wlf.setLambda(80000)
    wlf.setSigmaColor(1.2)
    displ = lm.compute(imgL, imgR)
    dispr = rm.compute(imgR, imgL)
    final = wlf.filter(displ, imgL, None, dispr)
    # Shift and rescale the filtered disparity into the 0-255 range
    final = final - 2 * np.min(final)
    final = final * 255 / np.max(final)
    final = np.uint8(final)
    return final

INPUT_DIR = './results/2019-04-29 12:29:15.506219/'
OUTPUT_DIR = './depth/'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
list_img = os.listdir(INPUT_DIR)
list_img.sort()

mae = []
rmse = []
# eval.py saves three images per index, so the sorted list is processed in triplets
for i in range(0, len(list_img), 3):
    print(list_img[i], list_img[i+1], list_img[i+2])
    imgL = cv2.imread(INPUT_DIR + list_img[i+2])
    imgR = cv2.imread(INPUT_DIR + list_img[i+1])
    depth_map_out = generate_depth(imgL, imgR)

    imgRG = cv2.imread(INPUT_DIR + list_img[i])
    depth_map_ground = generate_depth(imgL, imgRG)

    # Map 8-bit disparity values to a rough depth range of [1, 66]
    if np.max(depth_map_ground) > 1:
        depth_map_ground = depth_map_ground / 255.0 * 65.0 + 1.0
    if np.max(depth_map_out) > 1:
        depth_map_out = depth_map_out / 255.0 * 65.0 + 1.0

    cv2.imwrite(OUTPUT_DIR + '{}_g.png'.format(i), depth_map_ground.astype(np.uint8))
    cv2.imwrite(OUTPUT_DIR + '{}_o.png'.format(i), depth_map_out.astype(np.uint8))

    # Per-image mean absolute error and mean squared error over all pixels
    diff = np.abs(depth_map_out - depth_map_ground)
    mae.append(np.sum(diff) / (depth_map_out.shape[0] * depth_map_out.shape[1]))
    rmse.append(np.sum(diff ** 2) / (depth_map_out.shape[0] * depth_map_out.shape[1]))

    # sys.exit()

rmse = np.array(rmse)
rmse = np.mean(rmse ** 0.5)
mae = np.array(mae)
mae = np.mean(mae)

print('RMSE = {}, MAE = {}'.format(rmse, mae))
--------------------------------------------------------------------------------
/getdepthmap.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2 as cv
import cv2

# ASCII PLY header; the vertex count is spliced in between the two parts
ply_header = '''ply
format ascii 1.0
element vertex '''
str2 = '''
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
end_header
'''

def write(verts, colors):
    verts = verts.reshape(-1, 3)
    colors = colors.reshape(-1, 3)
    vertex = np.zeros((colors.shape[0], 6))

    for i in range(colors.shape[0]):
        vertex[i][0] = verts[i][0]
        vertex[i][1] = verts[i][1]
        vertex[i][2] = verts[i][2]
        vertex[i][3] = colors[i][0]
        vertex[i][4] = colors[i][1]
        vertex[i][5] = colors[i][2]
    with open('depth.ply', 'wb') as f:
        header = ply_header + str(colors.shape[0]) + str2
        f.write(header.encode('utf-8'))
        # One vertex per line, as the ASCII PLY format requires
        for i in range(colors.shape[0]):
            string = str(vertex[i][0]) + " " + str(vertex[i][1]) + " " + str(vertex[i][2]) + " " + str(vertex[i][3]) + " " + str(vertex[i][4]) + " " + str(vertex[i][5]) + "\n"
            f.write(string.encode('utf-8'))


def main(filteredImg, imgL, mini):
    disp = filteredImg
    (h, w, _) = imgL.shape
    mask = disp > mini
    # Reprojection matrix mapping (x, y, disparity) to 3D, with the principal
    # point at the image centre and a focal length of 0.8*w
    Q = np.zeros((4, 4))
    Q[0][0] = 1
    Q[1][1] = -1
    Q[3][2] = 1
    Q[0][3] = -0.5 * w
    Q[1][3] = 0.5 * h
    Q[2][3] = 0.8 * w
    Q = np.float32(Q)
    points = cv.reprojectImageTo3D(disp, Q)
    colors = cv.cvtColor(imgL, cv.COLOR_BGR2RGB)

    # Keep only pixels whose disparity exceeds the minimum
    out_p = points[mask]
    out_c = colors[mask]

    write(out_p, out_c)


print('loading images...')
imgL = cv2.imread('l.jpg')
imgR = cv2.imread('R.jpg')


lm = cv2.StereoSGBM_create(minDisparity=0, numDisparities=16)

rm = cv2.ximgproc.createRightMatcher(lm)

wlf = cv2.ximgproc.createDisparityWLSFilter(matcher_left=lm)
wlf.setLambda(80000)
wlf.setSigmaColor(1.2)
displ = lm.compute(imgL, imgR)
dispr = rm.compute(imgR, imgL)
displ = np.int16(displ)
dispr = np.int16(dispr)
final = wlf.filter(displ, imgL, None, dispr)
final = np.uint8(final)
main(final, imgL, final.min())
cv.destroyAllWindows()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import sys
import math
import time
import datetime
import os
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F

from torchvision.utils import make_grid, save_image
from torchvision import datasets, models, transforms
from tensorboardX import SummaryWriter

from model2 import *
from dataloader import *

########### TensorboardX ###########
LOG_DIR = './logs/'

now = str(datetime.datetime.now())
OUTPUTS_DIR = './outputs/'

if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)

if not os.path.exists(OUTPUTS_DIR):
    os.makedirs(OUTPUTS_DIR)
OUTPUTS_DIR = OUTPUTS_DIR + now + '/'
if not os.path.exists(OUTPUTS_DIR):
    os.makedirs(OUTPUTS_DIR)
if not os.path.exists(LOG_DIR + now):
    os.makedirs(LOG_DIR + now)

writer = SummaryWriter(LOG_DIR + now)

########### Arguments ###########
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
print(device)
max_epoch = 100

train_dataroot = './data/train/'
val_dataroot = './data/val/'

batch = 2
save_after = 2
lr = 0.0004
save_file = 'view_syn_weights_l1with_scheduler.pth'
img_size = (96, 320)


momentum = 0.95
weight_decay = 1.0e-4
resume = False
log = "error_log.txt"

########### Model ###########
model = Deep3d(device=device).to(device)
if resume:
    model.load_state_dict(torch.load(save_file))
print(model)

########### Dataloader ###########
train_dataset = MyDataset(train_dataroot, in_transforms=None)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)

val_dataset = MyDataset(val_dataroot, in_transforms=None)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch, shuffle=False)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

########### Criterion ###########
# Biases get twice the base learning rate and no weight decay
optimizer = optim.Adam([
    {'params': [param for name, param in model.named_parameters() if name[-4:] == 'bias'],
     'lr': 2 * lr},
    {'params': [param for name, param in model.named_parameters() if name[-4:] != 'bias'],
     'lr': lr, 'weight_decay': weight_decay}
], betas=(momentum, 0.999))

lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

criterion = nn.L1Loss().to(device)

########### Begin Training ###########
epoch = 0
best_loss = 100000
while epoch < max_epoch:
    since = time.time()

    print('Epoch {}/{}'.format(epoch, max_epoch - 1))
    print('-' * 10)

    for phase in ['train', 'val']:
        if phase == 'train':
            lr_scheduler.step()
            model.train()
        else:
            model.eval()

        running_loss = 0.0

        for iteration, data in enumerate(dataloaders[phase]):
            left_orig = data[0].to(device).float()
            left = data[1].to(device).float()
            right = data[2].to(device).float()

            optimizer.zero_grad()

            # Track gradients only during the training phase
            with torch.set_grad_enabled(phase == 'train'):
                output = model(left_orig, left)
                loss = criterion(output, right)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            writer.add_scalar('loss', loss.item(), epoch * len(dataloaders[phase]) + iteration)

            print('Epoch {}, Iteration: {}, Loss: {}'.format(epoch, iteration, loss.item()))
            running_loss += loss.item() * left.size(0)

            if iteration % 200 == 0 and phase == 'val':
                # print(left.shape)
                save_image(left_orig, OUTPUTS_DIR + '{}_{}_scan.png'.format(epoch, iteration))
                save_image(right, OUTPUTS_DIR + '{}_{}_out.png'.format(epoch, iteration))
                save_image(output, OUTPUTS_DIR + '{}_{}_rgb.png'.format(epoch, iteration))


        epoch_loss = running_loss / dataset_sizes[phase]

        print('{} Loss: {:.4f}'.format(phase, epoch_loss))

        if phase == 'val' and running_loss