├── .gitignore ├── README.md ├── demo.py ├── depth_refiner.py ├── egl_depthRenderer ├── Makefile ├── README.md ├── shaders │ ├── BerycentricGeometryShader.geometryshader │ ├── DepthRTT.fragmentshader │ └── DepthRTT.vertexshader └── src │ ├── common │ ├── shader.cpp │ └── shader.hpp │ ├── mLibSource.cpp │ └── main.cpp ├── model_epoch_99.pth ├── raw_depth0195.png ├── requirements.txt └── ulapsrn.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.bak 3 | 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # depth_refine_reconstruct 2 | 3 | We currently provide only the inference code, but the components needed for training (the network architecture and the discontinuity loss) are now available in ulapsrn.py. 4 | 5 | ### (NEW) clean depth image rendering [code](egl_depthRenderer/) updated (2019.05.09) 6 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import io 2 | import cv2 3 | from PIL import Image 4 | from depth_refiner import DepthPredictor 5 | 6 | pil_img = Image.open('raw_depth0195.png') 7 | d_predictor = DepthPredictor(gpu_id=0) 8 | 9 | e1 = cv2.getTickCount() 10 | res = d_predictor(pil_img) 11 | e2 = cv2.getTickCount() 12 | 13 | t = (e2 - e1)/cv2.getTickFrequency() 14 | print('time elapsed:', t) 15 | 16 | cv2.imwrite('result.png', res) 17 | -------------------------------------------------------------------------------- /depth_refiner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io 3 | import torch 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import math 7 | import cv2 8 | import scipy 9 | import scipy.misc 10 | from PIL import Image 11 | from ulapsrn import Net_Quarter_Half_Original 12 | 13 | 14 | def generate_grid(h, w, fov): 15 | x = (torch.arange(1, w + 1) - (w + 1) / 2) / (w / 2) * math.tan(fov / 2 / 180 * math.pi) 16 | y = -(torch.arange(1, h + 1) - (h + 1) / 2) / (h / 2) * math.tan(fov / 2 / 180 * math.pi) * (h / w) 17 | grid = torch.stack([x.repeat(h, 1), y.repeat(w, 1).t(), torch.ones(h, w)], 0)  # ones must be floating point so all three stacked tensors share one dtype 18 | return grid.type(torch.FloatTensor) 19 | 20 | 21 | def get_normal(x): 22 | [b, c, h, w] = x.size() 23 | grid = generate_grid(482, 642, 60) 24 | ph = (482 - h) // 2 25 | pw = (642 - w) // 2 26 | grid = grid.narrow(1, ph + 1, h).narrow(2, pw + 1, w) 27 | padding = torch.nn.ReflectionPad2d((1, 1, 1, 1)) 28 | v = x.repeat(1, 3, 1, 1) 29 | pv = padding(v * grid) 30 | gx = pv.narrow(3, 0, w).narrow(2, 0, h) / 2 - pv.narrow(3, 2, w).narrow(2, 0, h) / 2 31 | gy = pv.narrow(2, 2, h).narrow(3, 0, w) / 2 - pv.narrow(2, 0, h).narrow(3, 0, w) / 2 32 | crs = gx.cross(gy, 1) 33 | norm = crs.norm(2, 1, keepdim=True).repeat(1, 3, 1, 1) 34 | n = -crs / (norm.clamp(min=1e-8)) 35 | return n 36 | 37 | 38 | class DepthPredictor: 39 | def __init__(self, gpu_id=0): 40 | self.gpu_id = gpu_id 41 | print("===> Building model") 42 | self.model = Net_Quarter_Half_Original() 43 | print("===> Setting model") 44 | self.model = self.model.cuda(self.gpu_id) 45 | # weights = torch.load('model_original_deep_mask_normal_ssim_tv_upconv/model_epoch_99.pth') 46 | weights = torch.load('./model_epoch_99.pth') 47 | self.model.load_state_dict(weights['model'].state_dict())  # the checkpoint stores the whole model object under 'model', so we pull out its state_dict 48 |
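# switch to inference mode and run one dummy forward pass below, so the first real call is not slowed by one-time CUDA/model setup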
self.model.eval() 49 | print('=> warming up') 50 | data_bytes = open('raw_depth0195.png', 'rb').read() 51 | pil_img = Image.open(io.BytesIO(data_bytes)) 52 | self.__call__(pil_img) 53 | print('=> loading done') 54 | 55 | def __call__(self, pil_input): 56 | with torch.no_grad(): 57 | print("===> Loading Input Depth") 58 | # pil_input = pil_input.crop((40, 60, 600, 460)) 59 | img_numpy = np.array(pil_input).astype(np.uint16) 60 | h, w = img_numpy.shape 61 | img_numpy_quarter = np.array(pil_input.resize((w // 4, h // 4), Image.NEAREST)).astype(np.uint16) 62 | input = Variable(torch.from_numpy(img_numpy.astype(np.int32)).unsqueeze(0).unsqueeze(0)).float()/1000.0 63 | input_quarter = Variable(torch.from_numpy(img_numpy_quarter.astype(np.int32)).unsqueeze(0).unsqueeze(0)).float() / 1000.0 64 | input = input.cuda(self.gpu_id) 65 | input_quarter = input_quarter.cuda(self.gpu_id) 66 | print("===> Testing") 67 | pred_original, pred_half, pred_quarter = self.model((input, input_quarter)) 68 | res = pred_original 69 | 70 | depth = res.data.squeeze().cpu().numpy() 71 | res_med = torch.from_numpy(cv2.medianBlur(depth, 3)).unsqueeze(0).unsqueeze(0) 72 | res_med_img = (res_med[0][0]*1000).numpy().astype(np.uint16) 73 | 74 | # auxiliary process for surface normal estimation 75 | # normal_input = (get_normal(input.cpu()) + 1) / 2.0 76 | # normal_output_med = (get_normal(res_med.cpu()) + 1) / 2.0 77 | 78 | # normal_input = scipy.misc.toimage(normal_input.squeeze().data.cpu().numpy().transpose((1, 2, 0))) 79 | # normal_output_med = scipy.misc.toimage(normal_output_med.squeeze().data.cpu().numpy().transpose((1, 2, 0))) 80 | return res_med_img 81 | 82 | def get_input(self, pil_input): 83 | with torch.no_grad(): 84 | print("===> Loading Input Depth") 85 | # pil_input = pil_input.crop((40, 60, 600, 460)) 86 | img_numpy = np.array(pil_input).astype(np.uint16) 87 | h, w = img_numpy.shape 88 | input = Variable(torch.from_numpy(img_numpy.astype(np.int32)).unsqueeze(0).unsqueeze(0)).float() / 1000.0 89 | normal_input = (get_normal(input) + 1) / 2.0 90 | normal_input = scipy.misc.toimage(normal_input.squeeze().data.cpu().numpy().transpose((1, 2, 0)))  # scipy.misc.toimage needs scipy<1.2; PIL.Image.fromarray is the modern replacement 91 | return normal_input 92 | -------------------------------------------------------------------------------- /egl_depthRenderer/Makefile: -------------------------------------------------------------------------------- 1 | .SUFFIXES: 2 | 3 | 4 | CXX = g++ 5 | FLAGS = -g -std=c++11 6 | FLAGS += -I "src" 7 | FLAGS += -I "../mLib/include" 8 | FLAGS += -I "../mLib/src" 9 | 10 | OPENCV = `pkg-config opencv --cflags --libs` 11 | 12 | LFLAGS = -g 13 | LFLAGS += -lpthread -lglut -lGLU -lGL -lglfw -lGLEW -lEGL 14 | LFLAGS += $(OPENCV) 15 | 16 | SRC = main.cpp mLibSource.cpp shader.cpp 17 | OBJS = $(SRC:.cpp=.o) 18 | EXECUTABLE = depthRenderer 19 | 20 | .PHONY: all purge clean 21 | 22 | all: $(EXECUTABLE) 23 | 24 | build/%.o: src/%.cpp 25 | $(CXX) $(FLAGS) -MP -MD -MF $(@:.o=.d) $< -c -o $@ 26 | 27 | build/%.o: src/common/%.cpp 28 | $(CXX) $(FLAGS) -MP -MD -MF $(@:.o=.d) $< -c -o $@ 29 | 30 | $(EXECUTABLE): $(addprefix build/, $(OBJS)) 31 | $(CXX) $^ -o $@ $(LFLAGS) 32 | 33 | clean: 34 | rm -rf build/*.o build/*.d 35 | rm -rf $(EXECUTABLE) 36 | 37 | purge: clean 38 | rm -rf build/* 39 | 40 | # dependency rules 41 | include $(wildcard build/*.d) 42 | -------------------------------------------------------------------------------- /egl_depthRenderer/README.md: -------------------------------------------------------------------------------- 1 | # Depth renderer for ScanNet reconstructed
models 2 | 3 | You can use this code to render a reconstructed ScanNet 3D model into clean depth images for the corresponding camera poses. 4 | 5 | --- 6 | ## Requirements 7 | 8 | opencv, mLib, glut, glew, glfw, EGL 9 | 10 | You may already have GL and its extensions on Ubuntu. 11 | 12 | By default, [mLib](https://github.com/niessner/mLib) should be located at the same level as the renderer directory. 13 | 14 | You can change this by modifying the Makefile. 15 | 16 | ### Desired folder hierarchy 17 | 18 | this_repo/ --- egl_depthRenderer/ --- src/ -------- common/ 19 | |- mLib/ |- Makefile |- main.cpp 20 | |- ... |- mLibSource.cpp 21 | |- README.md 22 | 23 | 24 | 25 | --- 26 | 27 | ## How to use it 28 | 29 | The code needs a reconstructed .ply model and camera poses from the [ScanNet dataset](http://www.scan-net.org). 30 | 31 | Camera pose files (frame-XXXXXX.pose.txt) contain a 4x4 camera-to-world transformation matrix; the renderer inverts it to map world coordinates into the camera frame (see the back-projection sketch at the end of this README). 32 | 33 | $ cat frame-000000.pose.txt 34 | -0.955421 0.119616 -0.269932 2.65583 35 | 0.295248 0.388339 -0.872939 2.9816 36 | 0.000407581 -0.91372 -0.406343 1.36865 37 | 0 0 0 1 38 | $ _ 39 | 40 | You can extract the camera pose files of each scan from its \*.sens file in the ScanNet dataset (see [here](https://github.com/ScanNet/ScanNet/tree/master/SensReader)). 41 | 42 | $ make 43 | $ ./depthRenderer <mesh.ply> <poses_dir> <out_dir> <interval> 44 | 45 | For example: 46 | $ ./depthRenderer scene0000_00_vh_clean.ply poses depth_out 100 47 | Render 'scene0000_00_vh_clean.ply' to 'depth_out' with interval 100 48 | Loaded a mesh with 1990518 vertices 49 | Compiling shader : DepthRTT.vertexshader 50 | Compiling shader : DepthRTT.fragmentshader 51 | Compiling shader : BerycentricGeometryShader.geometryshader 52 | Linking program 53 | Processing 5500.. 54 | 0.333942s 55 | $ _ 56 | 57 | ## Quality of rendered depth images 58 | A rendered depth image can differ structurally from the raw depth image captured at the same camera pose. This is caused by deformation of the 3D model during global error minimization in the reconstruction process, by camera pose estimation errors, and so on. 59 | 60 | To use the rendered clean depth as ground truth (GT) for supervised learning tasks such as single-view depth estimation or depth refinement, several pre/post-processing techniques are worth considering. 61 | 62 | For example, before rendering the depth, you can locally re-adjust the camera pose using ICP or similar methods to better align the rendered clean depth image with the raw depth image (and color image). Alternatively, you can crop only the local patches whose structural similarity to the raw depth patch is above a threshold to build a patch-wise dataset (as in our ECCV paper). 63 | 64 | ## FAQ 65 | - Does it only work with the ScanNet dataset? 66 | - With a little modification, you can use this code to render any 3D model (mesh) with arbitrary camera poses and parameters. 67 | - Why did it take so long to publish this code? 68 | - Preparing for graduation.. and the first renderer was written in Unity scripts, but Unity cannot handle such huge 3D models with millions of vertices well. So I re-implemented the depth renderer. 69 | 70 | ## TODO 71 | - ~~Support headless rendering through SSH without X-window~~ (Done) 72 | - discard the opencv dependency (may use stb_image) 73 | - support direct pose extraction from *.sens files.
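## Example: back-projecting rendered depth (sketch)

As a minimal illustration of the conventions above (16-bit millimeter depth PNGs, camera-to-world pose files), the Python sketch below lifts a rendered depth map into a world-space point cloud. The intrinsics here are assumptions derived from the hard-coded projection matrix in src/main.cpp (640x480, fx = fy ≈ 579 px); substitute your own, and verify the sign conventions against the axis flip applied in main.cpp before relying on the result.

    import numpy as np
    import cv2

    # rendered depth: 16-bit PNG in millimeters -> meters
    depth = cv2.imread('depth_out/frame-000000.depth.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 1000.0
    pose = np.loadtxt('poses/frame-000000.pose.txt')  # 4x4 camera-to-world matrix

    fx = fy = 579.4            # assumed focal length (pixels)
    cx, cy = 320.0, 240.0      # assumed principal point for 640x480

    v, u = np.mgrid[0:depth.shape[0], 0:depth.shape[1]]
    z = depth
    x = (u - cx) / fx * z
    y = (v - cy) / fy * z
    pts_cam = np.stack([x, y, z, np.ones_like(z)], axis=-1).reshape(-1, 4)
    pts_world = (pose @ pts_cam.T).T[:, :3]   # N x 3 points in world coordinates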
74 | 75 | ## Contributors 76 | Junho Jeon (zwitterion27@gmail.com) 77 | 78 | Jinwoong Jung (jinwoong.jung@postech.ac.kr) 79 | -------------------------------------------------------------------------------- /egl_depthRenderer/shaders/BerycentricGeometryShader.geometryshader: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | layout(triangles) in; 4 | layout(triangle_strip, max_vertices = 3) out; 5 | 6 | flat in vec4 vcolor[]; 7 | 8 | flat out vec4 colors[3]; 9 | out vec3 coord; 10 | 11 | void main(){ 12 | for (int i = 0; i < 3; ++i) 13 | colors[i] = vcolor[i]; 14 | for (int i = 0; i < 3; ++i) 15 | { 16 | coord = vec3(0.0); 17 | coord[i] = 1.0; 18 | gl_Position = gl_in[i].gl_Position; 19 | EmitVertex(); 20 | } 21 | EndPrimitive(); 22 | } -------------------------------------------------------------------------------- /egl_depthRenderer/shaders/DepthRTT.fragmentshader: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | // Output data 4 | layout(location = 0) out float fragmentdepth; 5 | 6 | //uniform mat4 gl_ProjectionMatrix; 7 | uniform float zNear; 8 | uniform float zFar; 9 | 10 | void main(){ 11 | //fragmentdepth = (2.0 * zNear) / (zFar + zNear - gl_FragCoord.z * (zFar - zNear)); // unused: depth is written via the fixed-function depth attachment (gl_FragCoord.z) 12 | } 13 | -------------------------------------------------------------------------------- /egl_depthRenderer/shaders/DepthRTT.vertexshader: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | // Input vertex data, different for all executions of this shader. 4 | layout(location = 0) in vec3 vertexPosition_modelspace; 5 | 6 | // Values that stay constant for the whole mesh. 7 | uniform mat4 depthMVP; 8 | 9 | void main(){ 10 | gl_Position = depthMVP * vec4(vertexPosition_modelspace,1); 11 | } 12 | 13 | -------------------------------------------------------------------------------- /egl_depthRenderer/src/common/shader.cpp: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <string> 3 | #include <vector> 4 | #include <iostream> 5 | #include <fstream> 6 | #include <algorithm> 7 | using namespace std; 8 | 9 | #include <stdlib.h> 10 | #include <string.h> 11 | 12 | #include <GL/glew.h> 13 | 14 | #include "shader.hpp" 15 | 16 | GLuint LoadShaders(const char* vertex_file_path, const char* fragment_file_path, const char* geometry_file_path) 17 | { 18 | // Create the shaders 19 | GLuint VertexShaderID = glCreateShader(GL_VERTEX_SHADER); 20 | GLuint FragmentShaderID = glCreateShader(GL_FRAGMENT_SHADER); 21 | GLuint GeometryShaderID = glCreateShader(GL_GEOMETRY_SHADER); 22 | 23 | // Read the Vertex Shader code from the file 24 | std::string VertexShaderCode; 25 | std::ifstream VertexShaderStream(vertex_file_path, std::ios::in); 26 | if(VertexShaderStream.is_open()){ 27 | std::string Line = ""; 28 | while(getline(VertexShaderStream, Line)) 29 | VertexShaderCode += "\n" + Line; 30 | VertexShaderStream.close(); 31 | }else{ 32 | printf("Impossible to open %s. Are you in the right directory ?
Don't forget to read the FAQ !\n", vertex_file_path); 33 | getchar(); 34 | return 0; 35 | } 36 | 37 | // Read the Fragment Shader code from the file 38 | std::string FragmentShaderCode; 39 | std::ifstream FragmentShaderStream(fragment_file_path, std::ios::in); 40 | if(FragmentShaderStream.is_open()){ 41 | std::string Line = ""; 42 | while(getline(FragmentShaderStream, Line)) 43 | FragmentShaderCode += "\n" + Line; 44 | FragmentShaderStream.close(); 45 | } 46 | 47 | // Read the Geometry Shader code from the file 48 | std::string GeometryShaderCode; 49 | std::ifstream GeometryShaderStream(geometry_file_path, std::ios::in); 50 | if (GeometryShaderStream.is_open()){ 51 | std::string Line = ""; 52 | while (getline(GeometryShaderStream, Line)) 53 | GeometryShaderCode += "\n" + Line; 54 | GeometryShaderStream.close(); 55 | } 56 | 57 | GLint Result = GL_FALSE; 58 | int InfoLogLength; 59 | 60 | 61 | // Compile Vertex Shader 62 | printf("Compiling shader : %s\n", vertex_file_path); 63 | char const * VertexSourcePointer = VertexShaderCode.c_str(); 64 | glShaderSource(VertexShaderID, 1, &VertexSourcePointer , NULL); 65 | glCompileShader(VertexShaderID); 66 | 67 | // Check Vertex Shader 68 | glGetShaderiv(VertexShaderID, GL_COMPILE_STATUS, &Result); 69 | glGetShaderiv(VertexShaderID, GL_INFO_LOG_LENGTH, &InfoLogLength); 70 | if ( InfoLogLength > 0 ){ 71 | std::vector<char> VertexShaderErrorMessage(InfoLogLength+1); 72 | glGetShaderInfoLog(VertexShaderID, InfoLogLength, NULL, &VertexShaderErrorMessage[0]); 73 | printf("%s\n", &VertexShaderErrorMessage[0]); 74 | } 75 | 76 | 77 | 78 | // Compile Fragment Shader 79 | printf("Compiling shader : %s\n", fragment_file_path); 80 | char const * FragmentSourcePointer = FragmentShaderCode.c_str(); 81 | glShaderSource(FragmentShaderID, 1, &FragmentSourcePointer , NULL); 82 | glCompileShader(FragmentShaderID); 83 | 84 | // Check Fragment Shader 85 | glGetShaderiv(FragmentShaderID, GL_COMPILE_STATUS, &Result); 86 | glGetShaderiv(FragmentShaderID, GL_INFO_LOG_LENGTH, &InfoLogLength); 87 | if ( InfoLogLength > 0 ){ 88 | std::vector<char> FragmentShaderErrorMessage(InfoLogLength+1); 89 | glGetShaderInfoLog(FragmentShaderID, InfoLogLength, NULL, &FragmentShaderErrorMessage[0]); 90 | printf("%s\n", &FragmentShaderErrorMessage[0]); 91 | } 92 | 93 | // Compile Geometry Shader 94 | printf("Compiling shader : %s\n", geometry_file_path); 95 | char const * GeometrySourcePointer = GeometryShaderCode.c_str(); 96 | glShaderSource(GeometryShaderID, 1, &GeometrySourcePointer, NULL); 97 | glCompileShader(GeometryShaderID); 98 | 99 | // Check Geometry Shader 100 | glGetShaderiv(GeometryShaderID, GL_COMPILE_STATUS, &Result); 101 | glGetShaderiv(GeometryShaderID, GL_INFO_LOG_LENGTH, &InfoLogLength); 102 | if (InfoLogLength > 0){ 103 | std::vector<char> GeometryShaderErrorMessage(InfoLogLength + 1); 104 | glGetShaderInfoLog(GeometryShaderID, InfoLogLength, NULL, &GeometryShaderErrorMessage[0]); 105 | printf("%s\n", &GeometryShaderErrorMessage[0]); 106 | } 107 | 108 | 109 | // Link the program 110 | printf("Linking program\n"); 111 | GLuint ProgramID = glCreateProgram(); 112 | glAttachShader(ProgramID, VertexShaderID); 113 | glAttachShader(ProgramID, FragmentShaderID); 114 | glAttachShader(ProgramID, GeometryShaderID); 115 | glLinkProgram(ProgramID); 116 | 117 | // Check the program 118 | glGetProgramiv(ProgramID, GL_LINK_STATUS, &Result); 119 | glGetProgramiv(ProgramID, GL_INFO_LOG_LENGTH, &InfoLogLength); 120 | if ( InfoLogLength > 0 ){ 121 | std::vector<char> ProgramErrorMessage(InfoLogLength+1); 122 |
glGetProgramInfoLog(ProgramID, InfoLogLength, NULL, &ProgramErrorMessage[0]); 123 | printf("%s\n", &ProgramErrorMessage[0]); 124 | } 125 | 126 | 127 | glDetachShader(ProgramID, VertexShaderID); 128 | glDetachShader(ProgramID, FragmentShaderID); 129 | glDetachShader(ProgramID, GeometryShaderID); 130 | 131 | glDeleteShader(VertexShaderID); 132 | glDeleteShader(FragmentShaderID); 133 | glDeleteShader(GeometryShaderID); 134 | 135 | return ProgramID; 136 | } 137 | 138 | 139 | -------------------------------------------------------------------------------- /egl_depthRenderer/src/common/shader.hpp: -------------------------------------------------------------------------------- 1 | #ifndef SHADER_HPP 2 | #define SHADER_HPP 3 | 4 | GLuint LoadShaders(const char* vertex_file_path, const char* fragment_file_path, const char* geometry_file_path); 5 | 6 | #endif 7 | -------------------------------------------------------------------------------- /egl_depthRenderer/src/mLibSource.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "mLibCore.h" 3 | #include "mLibLodePNG.h" 4 | 5 | #include "mLibCore.cpp" 6 | #include "mLibLodePNG.cpp" 7 | -------------------------------------------------------------------------------- /egl_depthRenderer/src/main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Depth rendering code for ScanNet 3D models with estimated poses. 4 | 5 | Created on: May. 8, 2019 6 | Author: Junho Jeon, formerly at POSTECH, now with NaverLabs Corp. 7 | 8 | You may need the following OpenGL packages to compile it. 9 | sudo apt-get install libglfw3-dev libglfw3 libglew1.5 libglew1.5-dev 10 | 11 | */ 12 | 13 | #include "mLibCore.h" 14 | #include "mLibLodePNG.h" 15 | #include <stdio.h> 16 | #include <iostream> 17 | #include <sstream> 18 | #include <iomanip> 19 | #include <opencv2/opencv.hpp> 20 | 21 | #include <GL/glew.h> 22 | #include <GLFW/glfw3.h> 23 | #include <EGL/egl.h> 24 | 25 | #include <common/shader.hpp> 26 | using namespace ml; 27 | using namespace std; 28 | using namespace cv; 29 | 30 | static const EGLint configAttribs[] = { 31 | EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, 32 | EGL_BLUE_SIZE, 8, 33 | EGL_GREEN_SIZE, 8, 34 | EGL_RED_SIZE, 8, 35 | EGL_DEPTH_SIZE, 8, 36 | EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, 37 | EGL_NONE 38 | }; 39 | 40 | static const int pbufferWidth = 9; 41 | static const int pbufferHeight = 9; 42 | 43 | static const EGLint pbufferAttribs[] = { 44 | EGL_WIDTH, pbufferWidth, 45 | EGL_HEIGHT, pbufferHeight, 46 | EGL_NONE, 47 | }; 48 | 49 | int RenderDepthFromMesh(MeshDataf &mesh, std::string poses_path, std::string out_path, int interval) 50 | { 51 | GLFWwindow* window; 52 | 53 | if (!glfwInit()) 54 | { 55 | fprintf(stderr, "Failed to initialize GLFW\n"); 56 | getchar(); 57 | return -1; 58 | } 59 | int windowWidth = 640; 60 | int windowHeight = 480; 61 | 62 | glfwWindowHint(GLFW_SAMPLES, 4); 63 | glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3); 64 | glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3); 65 | glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE); // To make MacOS happy; should not be needed 66 | glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE); 67 | glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); // Make invisible window since we render to FBO directly 68 | 69 | // Open a window and create its OpenGL context 70 | window = glfwCreateWindow(windowWidth, windowHeight, "Depth Renderer", NULL, NULL); 71 | if (window == NULL){ 72 | fprintf(stderr, "Failed to open GLFW window. If you have an Intel GPU, they are not 3.3 compatible.
Try the 2.1 version of the tutorials.\n"); 73 | getchar(); 74 | glfwTerminate(); 75 | return -1; 76 | } 77 | glfwMakeContextCurrent(window); 78 | 79 | // Initialize GLEW 80 | glewExperimental = true; // Needed for core profile 81 | if (glewInit() != GLEW_OK) { 82 | fprintf(stderr, "Failed to initialize GLEW\n"); 83 | getchar(); 84 | glfwTerminate(); 85 | return -1; 86 | } 87 | glViewport(0, 0, windowWidth, windowHeight); 88 | 89 | // Red clear color (the color buffer is unused; only depth is read back) 90 | glClearColor(1.0f, 0.0f, 0.0f, 1.0f); 91 | glEnable(GL_DEPTH_TEST); 92 | glDepthFunc(GL_LESS); 93 | glDisable(GL_MULTISAMPLE); 94 | 95 | GLuint VertexArrayID; 96 | glGenVertexArrays(1, &VertexArrayID); 97 | glBindVertexArray(VertexArrayID); 98 | 99 | GLuint depthProgramID = LoadShaders("shaders/DepthRTT.vertexshader", 100 | "shaders/DepthRTT.fragmentshader", "shaders/BerycentricGeometryShader.geometryshader"); 101 | 102 | // Get a handle for our "MVP" uniform 103 | GLuint depthMatrixID = glGetUniformLocation(depthProgramID, "depthMVP"); 104 | 105 | // Load it into a VBO 106 | GLuint vertexbuffer, indexbuffer; 107 | glGenBuffers(1, &vertexbuffer); 108 | glBindBuffer(GL_ARRAY_BUFFER, vertexbuffer); 109 | glBufferData(GL_ARRAY_BUFFER, mesh.m_Vertices.size()*sizeof(mesh.m_Vertices[0]), &mesh.m_Vertices[0], GL_STATIC_DRAW); 110 | 111 | std::vector<unsigned int> indexvec; 112 | for (int i = 0; i < mesh.m_FaceIndicesVertices.size(); i++) 113 | for (int j = 0; j < 3; j++) 114 | indexvec.push_back(mesh.m_FaceIndicesVertices[i][j]); 115 | glGenBuffers(1, &indexbuffer); 116 | glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, indexbuffer); 117 | glBufferData(GL_ELEMENT_ARRAY_BUFFER, indexvec.size()*sizeof(unsigned int), &indexvec[0], GL_STATIC_DRAW); 118 | 119 | // zNear: 0.01, zFar: 30.0 (in meters) 120 | float zNear = 0.01; 121 | float zFar = 30.0; 122 | ml::Matrix4x4 camera, proj; 123 | { 124 | float tmp[] = { // hard-coded pinhole projection for 640x480 (fx = fy ≈ 579 px) 125 | -1.81066, 0.00000, 0.00000, 0.00000, 126 | 0.00000, -2.41421, 0.00000, 0.00000, 127 | 0.00000, 0.00000, -(zFar + zNear) / (zFar - zNear), -2 * zFar*zNear / (zFar-zNear), 128 | 0.00000, 0.00000, -1.00000, 0.00000 129 | }; 130 | proj = ml::Matrix4x4(tmp); 131 | } 132 | 133 | GLuint FramebufferName = 0; 134 | glGenFramebuffers(1, &FramebufferName); 135 | glBindFramebuffer(GL_FRAMEBUFFER, FramebufferName); 136 | 137 | GLuint depthTexture; 138 | glGenTextures(1, &depthTexture); 139 | glBindTexture(GL_TEXTURE_2D, depthTexture); 140 | glTexImage2D(GL_TEXTURE_2D, 0, GL_DEPTH_COMPONENT32, windowWidth, windowHeight, 0, GL_DEPTH_COMPONENT, GL_FLOAT, 0); 141 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); 142 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); 143 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); 144 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); 145 | 146 | glFramebufferTexture(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, depthTexture, 0); 147 | // No color output in the bound framebuffer, only depth.
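// (the fragment shader writes no color; gl_FragCoord.z fills the depth texture through the depth attachment)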
148 | glDrawBuffer(GL_NONE); 149 | 150 | // Always check that our framebuffer is ok 151 | if (glCheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) { 152 | printf("something wrong in FramebufferStatus check\n"); 153 | return false; 154 | } 155 | 156 | // Use our shader 157 | glUseProgram(depthProgramID); 158 | 159 | int frame = 0; 160 | do{ 161 | // Clear the screen 162 | glBindFramebuffer(GL_FRAMEBUFFER, FramebufferName); 163 | glViewport(0, 0, windowWidth, windowHeight); 164 | glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); 165 | 166 | { 167 | std::stringstream tmpFrame; 168 | tmpFrame << std::setw(4) << std::setfill('0') << frame; 169 | // FILE *in = fopen((std::string("") + target_scene + std::string("\\pose\\frame-00") + tmpFrame.str() + ".pose.txt").c_str(), "r"); 170 | FILE *in = fopen((poses_path + std::string("/frame-00") + tmpFrame.str() + ".pose.txt").c_str(), "r"); 171 | if (in == NULL) 172 | break; 173 | float tmp[4][4]; 174 | for (int i = 0; i < 4; i++) 175 | for (int j = 0; j < 4; j++) 176 | fscanf(in, "%f", &tmp[i][j]); 177 | fclose(in); 178 | if (tmp[0][0] != -1.0) { 179 | camera = ml::Matrix4x4((float*)&tmp[0][0]); 180 | } 181 | else { // pose marked invalid; skip this frame 182 | frame += interval; 183 | continue; 184 | } 185 | camera._m00 *= -1.0; camera._m01 *= -1.0; camera._m02 *= -1.0; // flip the camera axes to OpenGL's viewing convention 186 | camera._m10 *= -1.0; camera._m11 *= -1.0; camera._m12 *= -1.0; 187 | camera._m20 *= -1.0; camera._m21 *= -1.0; camera._m22 *= -1.0; 188 | } 189 | ml::Matrix4x4 MVP = proj*camera.getInverse(); // pose is camera-to-world, so invert to get world-to-camera 190 | MVP.transpose(); 191 | glUniformMatrix4fv(depthMatrixID, 1, GL_FALSE, &MVP.matrix[0]); 192 | glShadeModel(GL_FLAT); 193 | 194 | // 1st attribute buffer : vertices 195 | glEnableVertexAttribArray(0); 196 | glBindBuffer(GL_ARRAY_BUFFER, vertexbuffer); 197 | glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, 0, (void*)0); 198 | 199 | // Index buffer 200 | glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, indexbuffer); 201 | // Draw the triangles!
202 | glDrawElements(GL_TRIANGLES, indexvec.size(), GL_UNSIGNED_INT, (void*)0); 203 | glDisableVertexAttribArray(0); 204 | 205 | // Swap buffers 206 | DenseMatrixf ones(windowHeight, windowWidth, 0.0f); 207 | glfwSwapBuffers(window); 208 | { 209 | printf("Processing %d..\r", frame); 210 | fflush(stdout); 211 | cv::Mat1f tmp(windowHeight, windowWidth); 212 | glReadPixels(0, 0, windowWidth, windowHeight, GL_DEPTH_COMPONENT, GL_FLOAT, tmp.data); 213 | tmp = tmp * 2.0f - 1.0f; // window-space depth [0,1] -> NDC [-1,1] 214 | tmp = (2.0 * zNear * zFar) / (zFar + zNear - tmp * (zFar - zNear)); // invert the projection: NDC depth -> metric depth in meters 215 | cv::Mat1w out(windowHeight, windowWidth); 216 | cv::Mat1w maskedOut(windowHeight, windowWidth); 217 | maskedOut.setTo(0); 218 | tmp.convertTo(out, CV_16UC1, 1000.0); // meters -> 16-bit millimeters 219 | cv::Mat1b mask = out < ushort(zFar*1000-1); 220 | out.copyTo(maskedOut, mask); 221 | std::stringstream tmpFrame; 222 | tmpFrame << std::setw(4) << std::setfill('0') << frame; 223 | cv::imwrite((out_path + std::string("/frame-00") + tmpFrame.str() + ".depth.png").c_str(), maskedOut); 224 | frame += interval; 225 | } 226 | } // Check if the ESC key was pressed or the window was closed 227 | while (glfwGetKey(window, GLFW_KEY_ESCAPE) != GLFW_PRESS && glfwWindowShouldClose(window) == 0); 228 | 229 | printf("\n"); 230 | // Cleanup VBO and shader 231 | glDeleteBuffers(1, &vertexbuffer); 232 | glDeleteBuffers(1, &indexbuffer); 233 | glDeleteProgram(depthProgramID); 234 | glDeleteVertexArrays(1, &VertexArrayID); 235 | 236 | // Close OpenGL window and terminate GLFW 237 | glfwTerminate(); return 0; 238 | } 239 | 240 | 241 | int HeadlessRenderDepthFromMesh(MeshDataf &mesh, std::string poses_path, std::string out_path, int interval) 242 | { 243 | GLFWwindow* window = NULL; // never created in headless mode; without glfwInit the GLFW calls below are no-ops 244 | 245 | // if (!glfwInit()) 246 | // { 247 | // fprintf(stderr, "Failed to initialize GLFW\n"); 248 | // getchar(); 249 | // return -1; 250 | // } 251 | int windowWidth = 640; 252 | int windowHeight = 480; 253 | 254 | glfwWindowHint(GLFW_SAMPLES, 4); 255 | glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3); 256 | glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3); 257 | glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE); // To make MacOS happy; should not be needed 258 | glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE); 259 | glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); // Make invisible window since we render to FBO directly 260 | 261 | // // Open a window and create its OpenGL context 262 | // window = glfwCreateWindow(windowWidth, windowHeight, "Depth Renderer", NULL, NULL); 263 | // if (window == NULL){ 264 | // fprintf(stderr, "Failed to open GLFW window. If you have an Intel GPU, they are not 3.3 compatible.
Try the 2.1 version of the tutorials.\n"); 265 | // getchar(); 266 | // glfwTerminate(); 267 | // return -1; 268 | // } 269 | // glfwMakeContextCurrent(window); 270 | 271 | // Initialize GLEW 272 | glewExperimental = true; // Needed for core profile 273 | if (glewInit() != GLEW_OK) { 274 | fprintf(stderr, "Failed to initialize GLEW\n"); 275 | getchar(); 276 | glfwTerminate(); 277 | return -1; 278 | } 279 | glViewport(0, 0, windowWidth, windowHeight); 280 | 281 | // Red clear color (the color buffer is unused; only depth is read back) 282 | glClearColor(1.0f, 0.0f, 0.0f, 1.0f); 283 | glEnable(GL_DEPTH_TEST); 284 | glDepthFunc(GL_LESS); 285 | glDisable(GL_MULTISAMPLE); 286 | 287 | GLuint VertexArrayID; 288 | glGenVertexArrays(1, &VertexArrayID); 289 | glBindVertexArray(VertexArrayID); 290 | 291 | GLuint depthProgramID = LoadShaders("shaders/DepthRTT.vertexshader", 292 | "shaders/DepthRTT.fragmentshader", "shaders/BerycentricGeometryShader.geometryshader"); 293 | 294 | // Get a handle for our "MVP" uniform 295 | GLuint depthMatrixID = glGetUniformLocation(depthProgramID, "depthMVP"); 296 | 297 | // Load it into a VBO 298 | GLuint vertexbuffer, indexbuffer; 299 | glGenBuffers(1, &vertexbuffer); 300 | glBindBuffer(GL_ARRAY_BUFFER, vertexbuffer); 301 | glBufferData(GL_ARRAY_BUFFER, mesh.m_Vertices.size()*sizeof(mesh.m_Vertices[0]), &mesh.m_Vertices[0], GL_STATIC_DRAW); 302 | 303 | std::vector<unsigned int> indexvec; 304 | for (int i = 0; i < mesh.m_FaceIndicesVertices.size(); i++) 305 | for (int j = 0; j < 3; j++) 306 | indexvec.push_back(mesh.m_FaceIndicesVertices[i][j]); 307 | glGenBuffers(1, &indexbuffer); 308 | glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, indexbuffer); 309 | glBufferData(GL_ELEMENT_ARRAY_BUFFER, indexvec.size()*sizeof(unsigned int), &indexvec[0], GL_STATIC_DRAW); 310 | 311 | // zNear: 0.01, zFar: 30.0 (in meters) 312 | float zNear = 0.01; 313 | float zFar = 30.0; 314 | ml::Matrix4x4 camera, proj; 315 | { 316 | float tmp[] = { // hard-coded pinhole projection for 640x480 (fx = fy ≈ 579 px) 317 | -1.81066, 0.00000, 0.00000, 0.00000, 318 | 0.00000, -2.41421, 0.00000, 0.00000, 319 | 0.00000, 0.00000, -(zFar + zNear) / (zFar - zNear), -2 * zFar*zNear / (zFar-zNear), 320 | 0.00000, 0.00000, -1.00000, 0.00000 321 | }; 322 | proj = ml::Matrix4x4(tmp); 323 | } 324 | 325 | GLuint FramebufferName = 0; 326 | glGenFramebuffers(1, &FramebufferName); 327 | glBindFramebuffer(GL_FRAMEBUFFER, FramebufferName); 328 | 329 | GLuint depthTexture; 330 | glGenTextures(1, &depthTexture); 331 | glBindTexture(GL_TEXTURE_2D, depthTexture); 332 | glTexImage2D(GL_TEXTURE_2D, 0, GL_DEPTH_COMPONENT32, windowWidth, windowHeight, 0, GL_DEPTH_COMPONENT, GL_FLOAT, 0); 333 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); 334 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); 335 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); 336 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); 337 | 338 | glFramebufferTexture(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, depthTexture, 0); 339 | // No color output in the bound framebuffer, only depth.
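// (same depth-only framebuffer setup as in RenderDepthFromMesh above)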
340 | glDrawBuffer(GL_NONE); 341 | 342 | // Always check that our framebuffer is ok 343 | if (glCheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) { 344 | printf("something wrong in FramebufferStatus check\n"); 345 | return false; 346 | } 347 | 348 | // Use our shader 349 | glUseProgram(depthProgramID); 350 | 351 | int frame = 0; 352 | do{ 353 | // Clear the screen 354 | glBindFramebuffer(GL_FRAMEBUFFER, FramebufferName); 355 | glViewport(0, 0, windowWidth, windowHeight); 356 | glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); 357 | 358 | { 359 | std::stringstream tmpFrame; 360 | tmpFrame << std::setw(4) << std::setfill('0') << frame; 361 | // FILE *in = fopen((std::string("") + target_scene + std::string("\\pose\\frame-00") + tmpFrame.str() + ".pose.txt").c_str(), "r"); 362 | FILE *in = fopen((poses_path + std::string("/frame-00") + tmpFrame.str() + ".pose.txt").c_str(), "r"); 363 | if (in == NULL) 364 | break; 365 | float tmp[4][4]; 366 | for (int i = 0; i < 4; i++) 367 | for (int j = 0; j < 4; j++) 368 | fscanf(in, "%f", &tmp[i][j]); 369 | fclose(in); 370 | if (tmp[0][0] != -1.0) { 371 | camera = ml::Matrix4x4((float*)&tmp[0][0]); 372 | } 373 | else { // pose marked invalid; skip this frame 374 | frame += interval; 375 | continue; 376 | } 377 | camera._m00 *= -1.0; camera._m01 *= -1.0; camera._m02 *= -1.0; // flip the camera axes to OpenGL's viewing convention 378 | camera._m10 *= -1.0; camera._m11 *= -1.0; camera._m12 *= -1.0; 379 | camera._m20 *= -1.0; camera._m21 *= -1.0; camera._m22 *= -1.0; 380 | } 381 | ml::Matrix4x4 MVP = proj*camera.getInverse(); // pose is camera-to-world, so invert to get world-to-camera 382 | MVP.transpose(); 383 | glUniformMatrix4fv(depthMatrixID, 1, GL_FALSE, &MVP.matrix[0]); 384 | glShadeModel(GL_FLAT); 385 | 386 | // 1st attribute buffer : vertices 387 | glEnableVertexAttribArray(0); 388 | glBindBuffer(GL_ARRAY_BUFFER, vertexbuffer); 389 | glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, 0, (void*)0); 390 | 391 | // Index buffer 392 | glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, indexbuffer); 393 | // Draw the triangles!
394 | glDrawElements(GL_TRIANGLES, indexvec.size(), GL_UNSIGNED_INT, (void*)0); 395 | glDisableVertexAttribArray(0); 396 | 397 | // Swap buffers 398 | DenseMatrixf ones(windowHeight, windowWidth, 0.0f); 399 | glfwSwapBuffers(window); // no-op without glfwInit 400 | { 401 | printf("Processing %d..\r", frame); 402 | fflush(stdout); 403 | cv::Mat1f tmp(windowHeight, windowWidth); 404 | glReadPixels(0, 0, windowWidth, windowHeight, GL_DEPTH_COMPONENT, GL_FLOAT, tmp.data); 405 | tmp = tmp * 2.0f - 1.0f; // window-space depth [0,1] -> NDC [-1,1] 406 | tmp = (2.0 * zNear * zFar) / (zFar + zNear - tmp * (zFar - zNear)); // invert the projection: NDC depth -> metric depth in meters 407 | cv::Mat1w out(windowHeight, windowWidth); 408 | cv::Mat1w maskedOut(windowHeight, windowWidth); 409 | maskedOut.setTo(0); 410 | tmp.convertTo(out, CV_16UC1, 1000.0); // meters -> 16-bit millimeters 411 | cv::Mat1b mask = out < ushort(zFar*1000-1); 412 | out.copyTo(maskedOut, mask); 413 | std::stringstream tmpFrame; 414 | tmpFrame << std::setw(4) << std::setfill('0') << frame; 415 | cv::imwrite((out_path + std::string("/frame-00") + tmpFrame.str() + ".depth.png").c_str(), maskedOut); 416 | frame += interval; 417 | } 418 | } // Check if the ESC key was pressed or the window was closed 419 | while (glfwGetKey(window, GLFW_KEY_ESCAPE) != GLFW_PRESS && glfwWindowShouldClose(window) == 0); 420 | 421 | printf("\n"); 422 | // Cleanup VBO and shader 423 | glDeleteBuffers(1, &vertexbuffer); 424 | glDeleteBuffers(1, &indexbuffer); 425 | glDeleteProgram(depthProgramID); 426 | glDeleteVertexArrays(1, &VertexArrayID); 427 | 428 | // Close OpenGL window and terminate GLFW 429 | glfwTerminate(); return 0; 430 | } 431 | 432 | int main(int argc, char** argv) 433 | { 434 | bool headless_render = true; 435 | if (argc < 5) 436 | { 437 | printf("Usage: ./depthRenderer <mesh.ply> <poses_dir> <out_dir> <interval>\n"); 438 | exit(-1); 439 | } 440 | 441 | std::string mesh_path(argv[1]); 442 | std::string poses_path(argv[2]); 443 | std::string out_path(argv[3]); 444 | int interval = atoi(argv[4]); 445 | printf("Render '%s' to '%s' with interval %d\n", mesh_path.c_str(), out_path.c_str(), interval); 446 | MeshDataf mesh; 447 | MeshIOf::loadFromFile(mesh_path.c_str(), mesh); 448 | std::printf("Loaded a mesh with %d vertices\n", (int)mesh.m_Vertices.size()); 449 | Timer t; 450 | 451 | if (headless_render) 452 | { 453 | // 1. Initialize EGL 454 | EGLDisplay eglDpy = eglGetDisplay(EGL_DEFAULT_DISPLAY); 455 | EGLint major, minor; 456 | eglInitialize(eglDpy, &major, &minor); 457 | 458 | // 2. Select an appropriate configuration 459 | EGLint numConfigs; 460 | EGLConfig eglCfg; 461 | eglChooseConfig(eglDpy, configAttribs, &eglCfg, 1, &numConfigs); 462 | 463 | // 3. Create a surface 464 | EGLSurface eglSurf = eglCreatePbufferSurface(eglDpy, eglCfg, pbufferAttribs); 465 | 466 | // 4. Bind the API 467 | eglBindAPI(EGL_OPENGL_API); 468 | 469 | // 5. Create a context and make it current 470 | EGLContext eglCtx = eglCreateContext(eglDpy, eglCfg, EGL_NO_CONTEXT, NULL); 471 | eglMakeCurrent(eglDpy, eglSurf, eglSurf, eglCtx); 472 | 473 | // from now on use your OpenGL context 474 | HeadlessRenderDepthFromMesh(mesh, poses_path, out_path, interval); 475 | 476 | // 6.
Terminate EGL when finished 477 | eglTerminate(eglDpy); 478 | } 479 | else 480 | { 481 | RenderDepthFromMesh(mesh, poses_path, out_path, interval); 482 | } 483 | cout << t.getElapsedTime() << "s" << endl; 484 | return 0; 485 | } 486 | -------------------------------------------------------------------------------- /model_epoch_99.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunhoJeon/depth_refine_reconstruct/452eb9b4791e5d06c37908ba192d31157cb28198/model_epoch_99.pth -------------------------------------------------------------------------------- /raw_depth0195.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunhoJeon/depth_refine_reconstruct/452eb9b4791e5d06c37908ba192d31157cb28198/raw_depth0195.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy<1.2  # scipy.misc.toimage (used in depth_refiner.py) was removed in SciPy 1.2 3 | Pillow 4 | opencv-python 5 | torch 6 | -------------------------------------------------------------------------------- /ulapsrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | def get_upsample_filter(size): 9 | """Make a 2D bilinear kernel suitable for upsampling""" 10 | factor = (size + 1) // 2 11 | if size % 2 == 1: 12 | center = factor - 1 13 | else: 14 | center = factor - 0.5 15 | og = np.ogrid[:size, :size] 16 | filter = (1 - abs(og[0] - center) / factor) * \ 17 | (1 - abs(og[1] - center) / factor) 18 | return torch.from_numpy(filter).float() 19 | 20 | class _Shared_Source_Residual_Block(nn.Module): 21 | def __init__(self, D=5, R=2): 22 | super(_Shared_Source_Residual_Block, self).__init__() 23 | self.D = D 24 | self.R = R 25 | conv_block = [] 26 | for i in range(0, self.D): 27 | conv_block.append(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)) 28 | if i < self.D-1: 29 | conv_block.append(nn.LeakyReLU(0.2, inplace=True)) 30 | self.cov_block = nn.Sequential(*conv_block) 31 | 32 | def forward(self, x): 33 | output = x 34 | for i in range(0, self.R): 35 | output = x + self.cov_block(output) 36 | return output 37 | 38 | class _Distinct_Source_Residual_Block(nn.Module): 39 | def __init__(self, D=10, R=1): 40 | super(_Distinct_Source_Residual_Block, self).__init__() 41 | self.D = D 42 | self.R = R 43 | conv_block = [] 44 | for i in range(0, self.D): 45 | conv_block.append(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)) 46 | if i < self.D-1: 47 | conv_block.append(nn.LeakyReLU(0.2, inplace=True)) 48 | self.cov_block = nn.Sequential(*conv_block) 49 | 50 | def forward(self, x): 51 | output = x 52 | for i in range(0, self.R): 53 | output = output + self.cov_block(output) 54 | return output 55 | 56 | class _Downsample_Block(nn.Module): 57 | def __init__(self): 58 | super(_Downsample_Block, self).__init__() 59 | self.down_block = nn.Sequential( 60 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False), # down-sampling 61 | nn.LeakyReLU(0.2, inplace=True), 62 | # nn.MaxPool2d(2) 63 | ) 64 | 65 | def forward(self, x): 66 | output = self.down_block(x) 67 | return output 68 | 69 | class _Upsample_Block(nn.Module): 70 | def __init__(self): 71 |
super(_Upsample_Block, self).__init__() 72 | 73 | self.cov_block = nn.Sequential( 74 | nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False), 75 | nn.LeakyReLU(0.2, inplace=True), 76 | ) 77 | 78 | def forward(self, x): 79 | output = self.cov_block(x) 80 | return output 81 | 82 | class Net_Quarter_Deep_Mask(nn.Module): 83 | def __init__(self): 84 | super(Net_Quarter_Deep_Mask, self).__init__() 85 | 86 | recursive_block = _Distinct_Source_Residual_Block 87 | 88 | self.conv_input = nn.Conv2d(in_channels=2, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False) 89 | self.relu = nn.LeakyReLU(0.2, inplace=True) 90 | self.en_feature_original = self.make_layer(recursive_block) 91 | self.en_downsample_half = self.make_layer(_Downsample_Block) 92 | self.en_feature_half = self.make_layer(recursive_block) 93 | self.en_downsample_quarter = self.make_layer(_Downsample_Block) 94 | self.en_feature_quarter = self.make_layer(recursive_block) 95 | self.conv_R_quarter = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | m.weight.data.normal_(0, math.sqrt(2. / n)) 101 | if m.bias is not None: 102 | m.bias.data.zero_() 103 | if isinstance(m, nn.ConvTranspose2d): 104 | c1, c2, h, w = m.weight.data.size() 105 | weight = get_upsample_filter(h) 106 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 107 | if m.bias is not None: 108 | m.bias.data.zero_() 109 | 110 | def make_layer(self, block): 111 | layers = [] 112 | layers.append(block()) 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | input_mask = (x[0] == 0).float() 117 | data_quarter = x[1] 118 | input = torch.cat((x[0], input_mask), dim=1) 119 | en_feature_original = self.en_feature_original(self.relu(self.conv_input(input))) 120 | en_downsample_half = self.en_downsample_half(en_feature_original) 121 | en_feature_half = self.en_feature_half(en_downsample_half) 122 | en_downsample_quarter = self.en_downsample_quarter(en_feature_half) 123 | en_feature_quarter = self.en_feature_quarter(en_downsample_quarter) 124 | refined_quarter = data_quarter + self.conv_R_quarter(en_feature_quarter) 125 | return refined_quarter, en_feature_quarter, en_feature_half, en_feature_original 126 | 127 | class Net_Quarter_Half(nn.Module): 128 | def __init__(self, pretrained=None): 129 | super(Net_Quarter_Half, self).__init__() 130 | 131 | recursive_block = _Distinct_Source_Residual_Block 132 | # self.QuarterNet = Net_Quarter() 133 | self.QuarterNet = Net_Quarter_Deep_Mask() 134 | self.de_upsample_half = self.make_layer(_Upsample_Block) 135 | self.de_feature_half = self.make_layer(recursive_block, 20, 1) # this is long decoder for large receptive field 136 | self.conv_R_half = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 137 | self.upsample_img_half = nn.Upsample(scale_factor=2, mode='bilinear') 138 | 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d): 141 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 142 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 143 | if m.bias is not None: 144 | m.bias.data.zero_() 145 | if isinstance(m, nn.ConvTranspose2d): 146 | c1, c2, h, w = m.weight.data.size() 147 | weight = get_upsample_filter(h) 148 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 149 | if m.bias is not None: 150 | m.bias.data.zero_() 151 | 152 | def make_layer(self, block, *args): 153 | layers = [] 154 | layers.append(block(*args)) 155 | return nn.Sequential(*layers) 156 | 157 | def forward(self, x): 158 | input, data_quarter = x 159 | refined_quarter, en_feature_quarter, en_feature_half, en_feature_original = self.QuarterNet((input, data_quarter)) 160 | de_upsample_half = self.de_upsample_half(en_feature_quarter) + en_feature_half 161 | de_feature_half = self.de_feature_half(de_upsample_half) 162 | refined_half = self.upsample_img_half(refined_quarter) + self.conv_R_half(de_feature_half) 163 | return refined_half, refined_quarter, de_feature_half, en_feature_original 164 | 165 | class Net_Quarter_Half_Mapping(nn.Module): 166 | def __init__(self, pretrained=None): 167 | super(Net_Quarter_Half_Mapping, self).__init__() 168 | 169 | recursive_block = _Distinct_Source_Residual_Block 170 | # self.QuarterNet = Net_Quarter() 171 | self.QuarterNet = Net_Quarter_Deep_Mask() 172 | self.de_upsample_half = self.make_layer(_Upsample_Block) 173 | self.de_feature_half = self.make_layer(recursive_block, 10, 1) # this is long decoder for large receptive field 174 | self.mapping_feature_half = self.make_layer(recursive_block, 10, 1) 175 | self.conv_R_half = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 176 | self.upsample_img_half = nn.Upsample(scale_factor=2, mode='bilinear') 177 | 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 181 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 182 | if m.bias is not None: 183 | m.bias.data.zero_() 184 | if isinstance(m, nn.ConvTranspose2d): 185 | c1, c2, h, w = m.weight.data.size() 186 | weight = get_upsample_filter(h) 187 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 188 | if m.bias is not None: 189 | m.bias.data.zero_() 190 | 191 | def make_layer(self, block, *args): 192 | layers = [] 193 | layers.append(block(*args)) 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | input, data_quarter = x 198 | refined_quarter, en_feature_quarter, en_feature_half, en_feature_original = self.QuarterNet((input, data_quarter)) 199 | de_upsample_half = self.de_upsample_half(en_feature_quarter) + self.mapping_feature_half(en_feature_half) 200 | de_feature_half = self.de_feature_half(de_upsample_half) 201 | refined_half = self.upsample_img_half(refined_quarter) + self.conv_R_half(de_feature_half) 202 | return refined_half, refined_quarter, de_feature_half, en_feature_original 203 | 204 | class Net_Quarter_Half_Original(nn.Module): 205 | def __init__(self): 206 | super(Net_Quarter_Half_Original, self).__init__() 207 | 208 | recursive_block = _Distinct_Source_Residual_Block 209 | self.HalfNet = Net_Quarter_Half() 210 | self.de_upsample_original = self.make_layer(_Upsample_Block) 211 | self.de_feature_original = self.make_layer(recursive_block, 40, 1) 212 | self.conv_R_original = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 213 | self.upsample_img_original = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False) 214 | # self.upsample_img_original = nn.Upsample(scale_factor=2, mode='bilinear') 215 | 216 | for m in self.modules(): 217 | if isinstance(m, nn.Conv2d): 218 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 219 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 220 | if m.bias is not None: 221 | m.bias.data.zero_() 222 | if isinstance(m, nn.ConvTranspose2d): 223 | c1, c2, h, w = m.weight.data.size() 224 | weight = get_upsample_filter(h) 225 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 226 | if m.bias is not None: 227 | m.bias.data.zero_() 228 | 229 | def make_layer(self, block, *args): 230 | layers = [] 231 | layers.append(block(*args)) 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x): 235 | input, data_quarter = x 236 | refined_half, refined_quarter, de_feature_half, en_feature_original = self.HalfNet((input, data_quarter)) 237 | de_upsample_original = self.de_upsample_original(de_feature_half) + en_feature_original 238 | de_feature_original = self.de_feature_original(de_upsample_original) 239 | refined_original = self.upsample_img_original(refined_half) + self.conv_R_original(de_feature_original) 240 | return refined_original, refined_half, refined_quarter 241 | 242 | class Net(nn.Module): 243 | def __init__(self): 244 | super(Net, self).__init__() 245 | 246 | parameters_share = False 247 | recursive_block = _Distinct_Source_Residual_Block 248 | 249 | # F1 is just downsampled, F2 passed feature embedding 250 | self.conv_input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False) 251 | self.relu = nn.LeakyReLU(0.2, inplace=True) 252 | # self.en_feature_original = self.make_layer(recursive_block) 253 | self.en_downsample_half = self.make_layer(_Downsample_Block) 254 | self.en_feature_half = self.make_layer(recursive_block) 255 | if parameters_share == False: 256 | self.en_downsample_quarter = self.make_layer(_Downsample_Block) 257 | self.en_feature_quarter = self.make_layer(recursive_block) 258 | else: 259 | self.en_downsample_quarter = self.en_downsample_half 260 | self.en_feature_quarter = self.en_feature_half 261 | 262 | self.de_feature_quarter = self.make_layer(recursive_block) # takes the features right before the image 263 | self.de_upsample_half = self.make_layer(_Upsample_Block) 264 | if parameters_share == False: 265 | self.de_feature_half = self.make_layer(recursive_block) 266 | self.de_upsample_original = self.make_layer(_Upsample_Block) 267 | self.de_feature_original = self.make_layer(recursive_block) 268 | else: 269 | self.de_feature_half = self.de_feature_quarter 270 | self.de_upsample_original = self.de_upsample_half 271 | self.de_feature_original = self.de_feature_quarter 272 | 273 | self.de_upsample_img_half = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False) 274 | if parameters_share == False: 275 | self.de_upsample_img_original = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False) 276 | # self.de_upsample_img_original = nn.Upsample(scale_factor=2, mode='nearest') 277 | else: 278 | self.de_upsample_img_original = self.de_upsample_img_half 279 | 280 | self.conv_R_quarter = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 281 | if parameters_share == False: 282 | self.conv_R_half = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 283 | self.conv_R_original = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 284 | else: 285 | self.conv_R_half = self.conv_R_quarter 286 | self.conv_R_original = self.conv_R_quarter 287 | 288 | for m in self.modules(): 289 | if isinstance(m, nn.Conv2d): 290 | n = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels 291 | m.weight.data.normal_(0, math.sqrt(2. / n)) 292 | if m.bias is not None: 293 | m.bias.data.zero_() 294 | if isinstance(m, nn.ConvTranspose2d): 295 | c1, c2, h, w = m.weight.data.size() 296 | weight = get_upsample_filter(h) 297 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 298 | if m.bias is not None: 299 | m.bias.data.zero_() 300 | 301 | def make_layer(self, block): 302 | layers = [] 303 | layers.append(block()) 304 | return nn.Sequential(*layers) 305 | 306 | def forward(self, x): 307 | input, data_quarter = x 308 | 309 | en_feature_original = self.relu(self.conv_input(input)) 310 | en_downsample_half = self.en_downsample_half(en_feature_original) 311 | en_feature_half = self.en_feature_half(en_downsample_half) 312 | en_downsample_quarter = self.en_downsample_quarter(en_feature_half) 313 | 314 | en_feature_quarter = self.en_feature_quarter(en_downsample_quarter) 315 | refined_quarter = data_quarter + self.conv_R_quarter(en_feature_quarter) 316 | 317 | de_feature_quarter = self.de_feature_quarter(en_feature_quarter) 318 | de_upsample_half = self.de_upsample_half(de_feature_quarter)# + en_feature_half 319 | de_feature_half = self.de_feature_half(de_upsample_half) 320 | refined_half = self.de_upsample_img_half(refined_quarter) + self.conv_R_half(de_feature_half) 321 | 322 | de_upsample_original = self.de_upsample_original(de_feature_half)# + en_feature_original 323 | de_feature_original = self.de_feature_original(de_upsample_original) 324 | refined = self.de_upsample_img_original(refined_half) + self.conv_R_original(de_feature_original) 325 | 326 | return refined_quarter, refined_half, refined 327 | 328 | class L1_Charbonnier_loss(nn.Module): 329 | """L1 Charbonnier loss.""" 330 | def __init__(self): 331 | super(L1_Charbonnier_loss, self).__init__() 332 | self.eps = 1e-6 333 | 334 | def forward(self, X, Y): 335 | diff = torch.add(X, -Y) 336 | error = torch.sqrt( diff * diff + self.eps ) 337 | loss = torch.mean(error) 338 | return loss 339 | 340 | class L1_Gradient_loss(nn.Module): 341 | def __init__(self): 342 | super(L1_Gradient_loss, self).__init__() 343 | self.eps = 1e-6 344 | self.crit = L1_Charbonnier_loss() 345 | 346 | def forward(self, X, Y): 347 | xgin = X[:,:,1:,:] - X[:,:,0:-1,:] 348 | ygin = X[:,:,:,1:] - X[:,:,:,0:-1] 349 | xgtarget = Y[:,:,1:,:] - Y[:,:,0:-1,:] 350 | ygtarget = Y[:,:,:,1:] - Y[:,:,:,0:-1] 351 | 352 | xl = self.crit(xgin, xgtarget) 353 | yl = self.crit(ygin, ygtarget) 354 | return (xl + yl) * 0.5 355 | 356 | class Patch_Discontinuity_loss(nn.Module):  # the discontinuity loss: matches depth gradients near GT depth edges within a max-pooled neighborhood 357 | def __init__(self, kernel_size=5): 358 | super(Patch_Discontinuity_loss, self).__init__() 359 | self.eps = 1e-6 360 | self.crit = nn.MSELoss() 361 | psize = kernel_size // 2 362 | self.pool_xgin = nn.MaxPool2d(kernel_size, 1, padding=psize) 363 | self.pool_ygin = nn.MaxPool2d(kernel_size, 1, padding=psize) 364 | self.pool_xgtarget = nn.MaxPool2d(kernel_size, 1, padding=psize) 365 | self.pool_ygtarget = nn.MaxPool2d(kernel_size, 1, padding=psize) 366 | 367 | def forward(self, X, Y): 368 | b, c, h, w = X.size() 369 | xgtarget = torch.abs(Y[:,:,1:,:] - Y[:,:,0:-1,:]) 370 | ygtarget = torch.abs(Y[:,:,:,1:] - Y[:,:,:,0:-1]) 371 | xmask = (Y[:,:,1:,:] > 0).float() * (Y[:,:,0:-1,:] > 0).float() * (xgtarget > 0.1).float()  # valid (nonzero) GT pixels whose gradient exceeds 0.1 m, i.e. depth discontinuities 372 | ymask = (Y[:,:,:,1:] > 0).float() * (Y[:,:,:,0:-1] > 0).float() * (ygtarget > 0.1).float() 373 | ygin = torch.abs(X[:,:,:,1:] - X[:,:,:,0:-1]) * ymask 374 | xgin = torch.abs(X.narrow(2, 1, h-1) - X.narrow(2, 0, h-1)) * xmask 375 | xgtarget2 = xgtarget * xmask
376 | ygtarget2 = ygtarget * ymask 377 | 378 | xl = self.crit(self.pool_xgin(xgin), self.pool_xgtarget(xgtarget2)) 379 | yl = self.crit(self.pool_ygin(ygin), self.pool_ygtarget(ygtarget2)) 380 | return (xl + yl) * 0.5 381 | 382 | 383 | # Tukey loss from "Robust Optimization for Deep Regression" (ICCV 2015) 384 | class TukeyLoss(nn.Module): 385 | def __init__(self): 386 | super(TukeyLoss, self).__init__() 387 | self.epoch = 0 388 | 389 | def setIter(self, epoch): 390 | self.epoch = epoch 391 | 392 | def mad(self, x): 393 | med = torch.median(x) 394 | return torch.median(torch.abs(x - med)) 395 | 396 | def forward(self, X, Y): 397 | res = Y-X 398 | MAD = 1.4826 * self.mad(res) 399 | 400 | if self.epoch < 20: 401 | MAD = MAD * 7 402 | 403 | resMAD = res / MAD 404 | c = 4.6851 405 | yt = (c*c/6) * (1 - (1-(resMAD/c)**2)**3) 406 | yt = torch.clamp(yt, 0, c*c/6) 407 | return torch.mean(yt) 408 | 409 | 410 | class TVLoss(nn.Module): 411 | def __init__(self, tv_loss_weight=1): 412 | super(TVLoss, self).__init__() 413 | self.tv_loss_weight = tv_loss_weight 414 | 415 | def forward(self, x): 416 | batch_size = x.size()[0] 417 | h_x = x.size()[2] 418 | w_x = x.size()[3] 419 | count_h = self.tensor_size(x[:, :, 1:, :]) 420 | count_w = self.tensor_size(x[:, :, :, 1:]) 421 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 422 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 423 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 424 | 425 | @staticmethod 426 | def tensor_size(t): 427 | return t.size()[1] * t.size()[2] * t.size()[3] 428 | 429 | class L1_Masked_Charbonnier_loss(nn.Module): 430 | """L1 masked Charbonnier loss (ignores large gaps between input and GT).""" 431 | def __init__(self, kernel_size): 432 | super(L1_Masked_Charbonnier_loss, self).__init__() 433 | self.eps = 1e-6 434 | psize = kernel_size // 2 435 | self.pool = nn.MaxPool2d(kernel_size, 1, padding=psize) 436 | 437 | def forward(self, X, Y, G): 438 | mask = 1 - ((self.pool(torch.abs(Y-G)) > 0.2) * (Y == 0)).float() 439 | diff = torch.add(X, -Y) 440 | error = torch.sqrt( diff * diff + self.eps ) * mask 441 | loss = torch.mean(error) 442 | return loss 443 | 444 | def weight_init(m): 445 | print(m) 446 | if isinstance(m, nn.Conv2d): 447 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 448 | m.weight.data.normal_(0, math.sqrt(2. / n)) 449 | if m.bias is not None: 450 | m.bias.data.zero_() 451 | if isinstance(m, nn.ConvTranspose2d): 452 | c1, c2, h, w = m.weight.data.size() 453 | weight = get_upsample_filter(h) 454 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 455 | if m.bias is not None: 456 | m.bias.data.zero_() 457 | if isinstance(m, nn.BatchNorm2d): 458 | m.weight.data.normal_(1.0, 0.02) 459 | m.bias.data.zero_() 460 | --------------------------------------------------------------------------------
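For reference, here is a minimal, hypothetical training-step sketch showing how the multi-scale outputs of Net_Quarter_Half_Original pair with the losses defined in ulapsrn.py. The ground-truth pyramid construction, the 1:1 loss weighting, and the optimizer settings are illustrative assumptions, not the paper's exact training recipe:

    import torch
    import torch.nn.functional as F
    from ulapsrn import Net_Quarter_Half_Original, L1_Charbonnier_loss, Patch_Discontinuity_loss

    model = Net_Quarter_Half_Original().cuda()
    charb = L1_Charbonnier_loss()
    disc = Patch_Discontinuity_loss()
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed optimizer/lr

    def train_step(raw, gt):
        # raw, gt: (B, 1, H, W) depth tensors in meters; the quarter-scale input mirrors DepthPredictor
        raw_quarter = F.interpolate(raw, scale_factor=0.25, mode='nearest')
        gt_half = F.interpolate(gt, scale_factor=0.5, mode='nearest')
        gt_quarter = F.interpolate(gt, scale_factor=0.25, mode='nearest')
        pred, pred_half, pred_quarter = model((raw, raw_quarter))
        # Charbonnier term at every scale, discontinuity term at full resolution (illustrative weighting)
        loss = (charb(pred, gt) + charb(pred_half, gt_half) + charb(pred_quarter, gt_quarter)
                + disc(pred, gt))
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss.item()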