├── .gitignore
├── LICENSE
├── README.md
├── img
│   ├── teaser-ar.jpg
│   └── teaser.jpg
└── src
    ├── README.md
    ├── Render
    │   ├── __init__.py
    │   ├── compile.sh
    │   ├── main.cpp
    │   └── render.py
    ├── checkpoints
    │   └── download.sh
    ├── data
    │   └── download.sh
    ├── dataset.py
    ├── demo
    │   ├── AttachTexture.py
    │   ├── README.md
    │   ├── cpp
    │   │   ├── CMakeLists.txt
    │   │   └── main.cpp
    │   ├── direction.py
    │   ├── download.sh
    │   └── visualizer.py
    ├── dorn.py
    ├── evaluate.sh
    ├── evaluate_joint.py
    ├── train.sh
    ├── train_affine_dorn.py
    └── visualize_field.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 | 
6 | # Icon must end with two \r
7 | Icon
8 | 
9 | 
10 | # Thumbnails
11 | ._*
12 | 
13 | # Files that might appear in the root of a volume
14 | .DocumentRevisions-V100
15 | .fseventsd
16 | .Spotlight-V100
17 | .TemporaryItems
18 | .Trashes
19 | .VolumeIcon.icns
20 | .com.apple.timemachine.donotpresent
21 | 
22 | # Directories potentially created on remote AFP share
23 | .AppleDB
24 | .AppleDesktop
25 | Network Trash Folder
26 | Temporary Items
27 | .apdisk
28 | # Byte-compiled / optimized / DLL files
29 | __pycache__/
30 | *.py[cod]
31 | *$py.class
32 | 
33 | # C extensions
34 | *.so
35 | 
36 | # Distribution / packaging
37 | .Python
38 | build/
39 | develop-eggs/
40 | dist/
41 | downloads/
42 | eggs/
43 | .eggs/
44 | lib/
45 | lib64/
46 | parts/
47 | sdist/
48 | var/
49 | wheels/
50 | *.egg-info/
51 | .installed.cfg
52 | *.egg
53 | MANIFEST
54 | 
55 | # PyInstaller
56 | # Usually these files are written by a python script from a template
57 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
58 | *.manifest
59 | *.spec
60 | 
61 | # Installer logs
62 | pip-log.txt
63 | pip-delete-this-directory.txt
64 | 
65 | # Unit test / coverage reports
66 | htmlcov/
67 | .tox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *.cover
74 | .hypothesis/
75 | .pytest_cache/
76 | 
77 | # Translations
78 | *.mo
79 | *.pot
80 | 
81 | # Django stuff:
82 | *.log
83 | local_settings.py
84 | db.sqlite3
85 | 
86 | # Flask stuff:
87 | instance/
88 | .webassets-cache
89 | 
90 | # Scrapy stuff:
91 | .scrapy
92 | 
93 | # Sphinx documentation
94 | docs/_build/
95 | 
96 | # PyBuilder
97 | target/
98 | 
99 | # Jupyter Notebook
100 | .ipynb_checkpoints
101 | 
102 | # pyenv
103 | .python-version
104 | 
105 | # celery beat schedule file
106 | celerybeat-schedule
107 | 
108 | # SageMath parsed files
109 | *.sage.py
110 | 
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 | 
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 | 
124 | # Rope project settings
125 | .ropeproject
126 | 
127 | # mkdocs documentation
128 | /site
129 | 
130 | # mypy
131 | .mypy_cache/
132 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Jingwei Huang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FrameNet: Learning Local Canonical Frames of 3D Surfaces from a Single RGB Image
2 | 
3 | Source code for the paper:
4 | 
5 | Jingwei Huang, Yichao Zhou, Thomas Funkhouser, and Leonidas Guibas. [**FrameNet: Learning Local Canonical Frames of 3D Surfaces from a Single RGB Image**](http://stanford.edu/~jingweih/papers/framenet.pdf), ICCV 2019.
6 | 
7 | 
8 | ![FrameNet Teaser](https://github.com/hjwdzh/framenet/raw/master/img/teaser.jpg)
9 | 
10 | ## Usage Pipeline
11 | 
12 | ### Experiments
13 | Please refer to the [**src**](https://github.com/hjwdzh/framenet/raw/master/src/) directory for details.
14 | 
15 | ### Fun AR Application
16 | Please refer to the [**src/demo**](https://github.com/hjwdzh/framenet/raw/master/src/demo/) directory for details.
17 | 
18 | ### Rendering Toolbox
19 | Please try the example from this [**repo**](https://github.com/hjwdzh/pyRender).
20 | 
21 | ## Author
22 | - [Jingwei Huang](mailto:jingweih@stanford.edu)
23 | 
24 | © 2019 Jingwei Huang. All Rights Reserved.
25 | 
26 | **IMPORTANT**: If you use this code, please cite the following in any resulting publication:
27 | ```
28 | @article{huang2019framenet,
29 |   title={FrameNet: Learning Local Canonical Frames of 3D Surfaces from a Single RGB Image},
30 |   author={Huang, Jingwei and Zhou, Yichao and Funkhouser, Thomas and Guibas, Leonidas},
31 |   journal={arXiv preprint arXiv:1903.12305},
32 |   year={2019}
33 | }
34 | ```
35 | 
--------------------------------------------------------------------------------
/img/teaser-ar.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/hjwdzh/FrameNet/fe5cc45148f210ad9a2520a576ad92f5f282ca71/img/teaser-ar.jpg
--------------------------------------------------------------------------------
/img/teaser.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/hjwdzh/FrameNet/fe5cc45148f210ad9a2520a576ad92f5f282ca71/img/teaser.jpg
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Instructions for Experiments
2 | 
3 | ### Download the Data
4 | ```
5 | cd data
6 | sh download.sh
7 | ```
8 | 
9 | ### Download the Pre-trained Model
10 | ```
11 | cd checkpoints
12 | sh download.sh
13 | ```
14 | 
15 | ### Evaluation
16 | ```
17 | sh evaluate.sh
18 | ```
19 | 
20 | ### Train
21 | ```
22 | sh train.sh
23 | ```
24 | 
25 | ### Visualize the Tangent Field
26 | ```
27 | python visualize_field.py
28 | ```
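A quick way to sanity-check the downloads before evaluating or training is to load one batch through the dataset class. This is a minimal sketch, not part of the original scripts; it assumes the default paths hard-coded in `dataset.py` (`<root>/train_test_split.pkl` and the unzipped frames under `<root>/scannet-frames`), so adjust `root` to wherever the data was unpacked:
```
from torch.utils.data import DataLoader
from dataset import AffineTestsDataset

# Default root matches dataset.py; point it at your own download location.
dataset = AffineTestsDataset(root='/orion/downloads/framenet', usage='test')
loader = DataLoader(dataset, batch_size=4, shuffle=False)
sample = next(iter(loader))
# 'image' is RGB plus two normalized pixel-coordinate channels (5 x 240 x 320).
print(sample['image'].shape, sample['label'].shape, sample['mask'].shape)
```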
--------------------------------------------------------------------------------
/src/Render/__init__.py:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/hjwdzh/FrameNet/fe5cc45148f210ad9a2520a576ad92f5f282ca71/src/Render/__init__.py
--------------------------------------------------------------------------------
/src/Render/compile.sh:
--------------------------------------------------------------------------------
1 | export PATH=$PATH:/usr/local/cuda/bin
2 | export LD_LIBRARY_PATH=/usr/local/cuda/lib64
3 | 
4 | export OPENCV_INCLUDE_DIR=/orions4-zfs/projects/jingweih/opencv/include
5 | export OPENCV_LIBRARY_DIR=/orions4-zfs/projects/jingweih/opencv/lib
6 | 
7 | export CFLAGS="-I/data/Taskonomy/opencv/include -I$OPENCV_INCLUDE_DIR"
8 | export DFLAGS="-L/data/Taskonomy/opencv/lib -L$OPENCV_LIBRARY_DIR -lopencv_core -lopencv_highgui"
9 | 
10 | g++ -std=c++11 -c main.cpp $CFLAGS -O2 -o main.o -fPIC
11 | g++ -std=c++11 main.o $CFLAGS $DFLAGS -O2 -o libRender.so -shared -fPIC
12 | 
13 | #g++ -std=c++11 main.o buffer.o loader.o render.o $CFLAGS $DFLAGS -o render -lcudart
14 | 
15 | #rm *.o
16 | 
--------------------------------------------------------------------------------
/src/Render/main.cpp:
--------------------------------------------------------------------------------
1 | #include <opencv2/opencv.hpp>
2 | #include <vector>
3 | 
4 | extern "C" {
5 | 
6 | void VisualizeDirection(const char* output_file, float* color, cv::Point2f* Qx, cv::Point2f* Qy, int height, int width) {
7 |     cv::Mat final(height, width, CV_32FC3);
8 |     memcpy(final.data, color, sizeof(float) * 3 * height * width);
9 |     final.convertTo(final, CV_8UC3, 255);
10 |     cv::cvtColor(final, final, CV_RGB2BGR);
11 | 
12 |     for (int i = 0; i < 300; ++i) {
13 |         int px = rand() % width;
14 |         int py = rand() % height;
15 | 
16 |         for (int k = 0; k < 2; ++k) {
17 |             std::vector<cv::Point2f> points;
18 |             points.push_back(cv::Point2f(px, py));
19 |             for (int j = 0; j < 300; ++j) {
20 |                 auto p = points.back();
21 |                 int px = p.x;
22 |                 int py = p.y;
23 |                 if (px < 0 || py < 0 || px >= width - 2 || py >= height - 2)
24 |                     break;
25 |                 auto& Q = (k == 0) ? 
Qx : Qy; 26 | auto p11 = Q[py * width + px]; 27 | auto p12 = Q[py * width + px + 1]; 28 | auto p21 = Q[py * width + px + width]; 29 | auto p22 = Q[py * width + px + width + 1]; 30 | if (p11.dot(p11) < 1e-6) 31 | break; 32 | if (p12.dot(p12) < 1e-6) 33 | break; 34 | if (p21.dot(p21) < 1e-6) 35 | break; 36 | if (p22.dot(p22) < 1e-6) 37 | break; 38 | 39 | float wx = p.x - px; 40 | float wy = p.y - py; 41 | cv::Point2f dp = (p11 * (1 - wx) + p12 * wx) * (1 - wy) + (p21 * (1 - wx) + p22 * wx) * wy; 42 | points.push_back(p + dp); 43 | if (k == 0) { 44 | cv::line(final, p, points.back(), cv::Scalar(255, 0, 0)); 45 | } else { 46 | cv::line(final, p, points.back(), cv::Scalar(0, 255, 0)); 47 | } 48 | } 49 | } 50 | } 51 | cv::imwrite(output_file, final); 52 | } 53 | 54 | 55 | }; -------------------------------------------------------------------------------- /src/Render/render.py: -------------------------------------------------------------------------------- 1 | from ctypes import * 2 | 3 | import numpy as np 4 | Render = cdll.LoadLibrary('./Render/libRender.so') 5 | 6 | def setup(info): 7 | Render.InitializeCamera(info['depthWidth'], info['depthHeight'], 8 | c_float(info['d_fx']), c_float(info['d_fy']), c_float(info['d_cx']), c_float(info['d_cy'])) 9 | 10 | def SetMesh(V, F): 11 | handle = Render.SetMesh(c_void_p(V.ctypes.data), c_void_p(F.ctypes.data), V.shape[0], F.shape[0]) 12 | return handle 13 | 14 | def render(handle, world2cam): 15 | Render.SetTransform(handle, c_void_p(world2cam.ctypes.data)) 16 | Render.Render(handle); 17 | 18 | def getDepth(info): 19 | depth = np.zeros((info['depthHeight'],info['depthWidth']), dtype='float32') 20 | Render.GetDepth(c_void_p(depth.ctypes.data)) 21 | 22 | return depth 23 | 24 | def getVMap(handle, info): 25 | vindices = np.zeros((info['depthHeight'],info['depthWidth'], 3), dtype='int32') 26 | vweights = np.zeros((info['depthHeight'],info['depthWidth'], 3), dtype='float32') 27 | 28 | Render.GetVMap(handle, c_void_p(vindices.ctypes.data), c_void_p(vweights.ctypes.data)) 29 | 30 | return vindices, vweights 31 | 32 | def colorize(VC, vindices, vweights, mask, cimage): 33 | Render.Colorize(c_void_p(VC.ctypes.data), c_void_p(vindices.ctypes.data), c_void_p(vweights.ctypes.data), 34 | c_void_p(mask.ctypes.data), c_void_p(cimage.ctypes.data), vindices.shape[0], vindices.shape[1]) 35 | 36 | def directionalize(Qx, Qy, ambiguity, vindices, vweights, mask, Q_cam, N_cam, rot, depth, fx, fy, cx, cy): 37 | print(vindices.shape) 38 | Render.Directionalize(c_void_p(Qx.ctypes.data), c_void_p(Qy.ctypes.data), c_void_p(ambiguity.ctypes.data), c_void_p(vindices.ctypes.data), c_void_p(vweights.ctypes.data), 39 | c_void_p(mask.ctypes.data), c_void_p(Q_cam.ctypes.data), c_void_p(N_cam.ctypes.data), c_void_p(rot.ctypes.data), c_void_p(depth.ctypes.data), vindices.shape[0], vindices.shape[1], c_float(fx), c_float(fy), c_float(cx), c_float(cy)) 40 | 41 | def Clear(): 42 | Render.ClearData() 43 | 44 | def visualizeDirection(file, gt_color, Qx, Qy): 45 | Render.VisualizeDirection(c_char_p(file.encode('utf-8')), c_void_p(gt_color.ctypes.data), c_void_p(Qx.ctypes.data), c_void_p(Qy.ctypes.data), gt_color.shape[0], gt_color.shape[1]) 46 | 47 | def Rotate(tar, src, rot): 48 | Render.Rotate(c_void_p(tar.ctypes.data), c_void_p(src.ctypes.data), src.shape[0], c_void_p(rot.ctypes.data)) 49 | -------------------------------------------------------------------------------- /src/checkpoints/download.sh: -------------------------------------------------------------------------------- 1 | wget 
http://download.cs.stanford.edu/orion/framenet/dorn.cpkt -------------------------------------------------------------------------------- /src/data/download.sh: -------------------------------------------------------------------------------- 1 | wget http://download.cs.stanford.edu/orion/framenet/scannet-frame-links.txt 2 | wget http://download.cs.stanford.edu/orion/framenet/train_test_split.pkl 3 | wget -i scannet-frame-links.txt 4 | unzip '*.zip' 5 | #rm *.zip -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torchvision import transforms 5 | from torch.utils.data.dataset import Dataset # For custom datasets 6 | import random 7 | import skimage.io as sio 8 | import skimage.transform as tr 9 | import pickle 10 | import numpy as np 11 | import scipy.misc as misc 12 | import os 13 | 14 | 15 | class AffineDataset(Dataset): 16 | def __init__(self, root='/orion/downloads/framenet', usage='test', feat=0): 17 | # Transforms 18 | self.root = root 19 | self.to_tensor = transforms.ToTensor() 20 | # Read the csv file 21 | self.data_info = pickle.load(open(self.root + '/train_test_split.pkl', 'rb'))[usage] 22 | 23 | self.idx = [i for i in range(0,len(self.data_info[0]), 1)] 24 | self.data_len = len(self.data_info[0]) 25 | 26 | self.intrinsics = [577.591,318.905,578.73,242.684] 27 | xx, yy = np.meshgrid(np.array([i for i in range(640)]), np.array([i for i in range(480)])) 28 | self.mesh_x = misc.imresize((xx - self.intrinsics[1]) / self.intrinsics[0], (240,320),'nearest',mode='F') 29 | self.mesh_y = misc.imresize((yy - self.intrinsics[3]) / self.intrinsics[2], (240,320),'nearest',mode='F') 30 | self.feat = feat 31 | self.root = root + '/scannet-frames' 32 | 33 | def __getitem__(self, index): 34 | # Get image name from the pandas df 35 | color_info = self.data_info[0][self.idx[index]] 36 | orient_info = self.data_info[1][self.idx[index]] 37 | orient_info_X = self.data_info[1][self.idx[index]][:-10] + 'orient-X.png' 38 | orient_info_Y = self.data_info[1][self.idx[index]][:-10] + 'orient-Y.png' 39 | mask_info = self.data_info[2][self.idx[index]] 40 | 41 | color_info = self.root + '/' + color_info[26:] 42 | orient_info = self.root + '/' + orient_info[27:] 43 | orient_info_X = self.root + '/' + orient_info_X[27:] 44 | orient_info_Y = self.root + '/' + orient_info_Y[27:] 45 | mask_info = self.root + '/' + mask_info[27:] 46 | orient_mask_tensor = misc.imresize(sio.imread(mask_info), (240,320), 'nearest') 47 | 48 | # Open image 49 | color_img = misc.imresize(sio.imread(color_info), (240,320,3), 'nearest') 50 | color_tensor = self.to_tensor(color_img) 51 | input_tensor = np.zeros((5, color_img.shape[0], color_img.shape[1]), dtype='float32') 52 | input_tensor[0:3,:,:] = color_tensor 53 | input_tensor[3,:,:] = self.mesh_x 54 | input_tensor[4,:,:] = self.mesh_y 55 | 56 | if self.feat == 1: 57 | depth_info = color_info[:-9] + 'renderdepth.png' 58 | depth_img = misc.imresize(sio.imread(depth_info)/1000.0,(240,320,3),'nearest',mode='F') 59 | 60 | orient_x = misc.imresize(sio.imread(orient_info_X), (240,320,3),'nearest') 61 | orient_x = (orient_x / 255.0 * 2.0 - 1.0).astype('float32') 62 | l1 = np.linalg.norm(orient_x, axis=2) 63 | for j in range(3): 64 | orient_x[:,:,j] /= (l1 + 1e-9) 65 | #X = self.to_tensor(orient_x.copy()) 66 | X = torch.from_numpy(np.transpose(orient_x, (2,0,1))) 67 | 
#print(np.max(orient_x), X.max(), orient_x.shape, X.shape) 68 | orient_x[:,:,0] = orient_x[:,:,0] - self.mesh_x * orient_x[:,:,2] 69 | orient_x[:,:,1] = orient_x[:,:,1] - self.mesh_y * orient_x[:,:,2] 70 | if self.feat == 1: 71 | orient_x[:,:,0] /= (depth_img + 1e-7) 72 | orient_x[:,:,1] /= (depth_img + 1e-7) 73 | elif self.feat == 2: 74 | l = np.sqrt(orient_x[:,:,0]*orient_x[:,:,0]+orient_x[:,:,1]*orient_x[:,:,1]) + 1e-7 75 | orient_x[:,:,0] /= l 76 | orient_x[:,:,1] /= l 77 | 78 | orient_y = misc.imresize(sio.imread(orient_info_Y), (240,320,3), 'nearest') 79 | orient_y = (orient_y / 255.0 * 2.0 - 1.0).astype('float32') 80 | l2 = np.linalg.norm(orient_y, axis=2) 81 | for j in range(3): 82 | orient_y[:,:,j] /= (l2 + 1e-9) 83 | #Y = self.to_tensor(orient_y.copy()) 84 | #print(np.max(orient_y), Y.max()) 85 | Y = torch.from_numpy(np.transpose(orient_y, (2,0,1))) 86 | orient_y[:,:,0] = orient_y[:,:,0] - self.mesh_x * orient_y[:,:,2] 87 | orient_y[:,:,1] = orient_y[:,:,1] - self.mesh_y * orient_y[:,:,2] 88 | 89 | if self.feat == 1: 90 | orient_y[:,:,0] /= (depth_img + 1e-7) 91 | orient_y[:,:,1] /= (depth_img + 1e-7) 92 | elif self.feat == 2: 93 | l = np.sqrt(orient_y[:,:,0]*orient_y[:,:,0]+orient_y[:,:,1]*orient_y[:,:,1]) + 1e-7 94 | orient_y[:,:,0] /= l 95 | orient_y[:,:,1] /= l 96 | 97 | orient_img = np.zeros((orient_x.shape[0], orient_x.shape[1], 4), dtype='float32') 98 | orient_img[:,:,0] = orient_x[:,:,0] * (l1 > 0.5) 99 | orient_img[:,:,1] = orient_x[:,:,1] * (l1 > 0.5) 100 | orient_img[:,:,2] = orient_y[:,:,0] * (l2 > 0.5) 101 | orient_img[:,:,3] = orient_y[:,:,1] * (l2 > 0.5) 102 | #orient_img = np.concatenate((orient_x[:,:,0:2], orient_y[:,:,2]), axis=2) 103 | 104 | orient_img_vertical = orient_img.copy() 105 | orient_img_vertical[:,:,0:2] = orient_img[:,:,2:4] 106 | orient_img_vertical[:,:,2:4] = -orient_img[:,:,0:2] 107 | 108 | #orient_tensor = self.to_tensor(orient_img) 109 | #print(np.max(orient_img), orient_tensor.max()) 110 | orient_tensor = torch.from_numpy(np.transpose(orient_img, (2,0,1))) 111 | #orient_vert_tensor = self.to_tensor(orient_img_vertical) 112 | orient_vert_tensor = torch.from_numpy(np.transpose(orient_img_vertical,(2,0,1))) 113 | #orient_mask_tensor = misc.imresize(orient_mask_tensor, (240,320), 'nearest') 114 | orient_mask_tensor = torch.Tensor(orient_mask_tensor / 255.0) 115 | #orient_mask = np.reshape(orient_mask, (orient_mask.shape[0], orient_mask.shape[1], 1)) 116 | #orient_mask_tensor = self.to_tensor(orient_mask) 117 | return {'image':input_tensor, 'label':orient_tensor, 'label_alt':orient_vert_tensor, 'mask':orient_mask_tensor, 'X':X, 'Y':Y} 118 | 119 | def __len__(self): 120 | return self.data_len 121 | 122 | class AffineTestsDataset(Dataset): 123 | def __init__(self, root='/orion/downloads/framenet', usage='test', feat=0): 124 | # Transforms 125 | self.root = root 126 | self.to_tensor = transforms.ToTensor() 127 | # Read the csv file 128 | self.data_info = pickle.load(open(self.root + '/train_test_split.pkl', 'rb'))[usage] 129 | #print(len(self.data_info[0])) 130 | #self.data_info = [self.data_info[i][30000:30008] for i in range(3)] 131 | self.idx = [i for i in range(0,len(self.data_info[0]),200)] 132 | #random.shuffle(self.idx) 133 | 134 | # First column contains the image paths 135 | self.data_len = len(self.idx) 136 | self.intrinsics = [577.591,318.905,578.73,242.684] 137 | xx, yy = np.meshgrid(np.array([i for i in range(640)]), np.array([i for i in range(480)])) 138 | self.mesh_x = misc.imresize((xx - self.intrinsics[1]) / self.intrinsics[0], 
(240,320),'nearest',mode='F') 139 | self.mesh_y = misc.imresize((yy - self.intrinsics[3]) / self.intrinsics[2], (240,320),'nearest',mode='F') 140 | self.feat = feat 141 | self.root = root + '/scannet-frames' 142 | 143 | def __getitem__(self, index): 144 | # Get image name from the pandas df 145 | color_info = self.data_info[0][self.idx[index]] 146 | orient_info = self.data_info[1][self.idx[index]] 147 | orient_info_X = self.data_info[1][self.idx[index]][:-10] + 'orient-X.png' 148 | orient_info_Y = self.data_info[1][self.idx[index]][:-10] + 'orient-Y.png' 149 | mask_info = self.data_info[2][self.idx[index]] 150 | 151 | color_info = self.root + '/' + color_info[26:] 152 | orient_info = self.root + '/' + orient_info[27:] 153 | orient_info_X = self.root + '/' + orient_info_X[27:] 154 | orient_info_Y = self.root + '/' + orient_info_Y[27:] 155 | mask_info = self.root + '/' + mask_info[27:] 156 | orient_mask_tensor = misc.imresize(sio.imread(mask_info), (240,320), 'nearest') 157 | 158 | # Open image 159 | color_img = misc.imresize(sio.imread(color_info), (240,320,3), 'nearest') 160 | color_tensor = self.to_tensor(color_img) 161 | 162 | input_tensor = np.zeros((5, color_img.shape[0], color_img.shape[1]), dtype='float32') 163 | input_tensor[0:3,:,:] = color_tensor 164 | input_tensor[3,:,:] = self.mesh_x 165 | input_tensor[4,:,:] = self.mesh_y 166 | 167 | #orient_img = sio.imread(orient_info) 168 | #orient_img = (orient_img / 255.0 * 2.0 - 1.0).astype('float32') 169 | 170 | if self.feat == 1: 171 | depth_info = color_info[:-9] + 'renderdepth.png' 172 | depth_img = misc.imresize(sio.imread(depth_info)/1000.0,(240,320,3),'nearest',mode='F') 173 | 174 | orient_x = misc.imresize(sio.imread(orient_info_X), (240,320,3),'nearest') 175 | orient_x = (orient_x / 255.0 * 2.0 - 1.0).astype('float32') 176 | l1 = np.linalg.norm(orient_x, axis=2) 177 | for j in range(3): 178 | orient_x[:,:,j] /= (l1 + 1e-9) 179 | #X = self.to_tensor(orient_x.copy()) 180 | X = torch.from_numpy(np.transpose(orient_x,(2,0,1))) 181 | orient_y = misc.imresize(sio.imread(orient_info_Y), (240,320,3), 'nearest') 182 | orient_y = (orient_y / 255.0 * 2.0 - 1.0).astype('float32') 183 | l2 = np.linalg.norm(orient_y, axis=2) 184 | for j in range(3): 185 | orient_y[:,:,j] /= (l2 + 1e-9) 186 | #Y = self.to_tensor(orient_y.copy()) 187 | Y = torch.from_numpy(np.transpose(orient_y,(2,0,1))) 188 | 189 | orient_x[:,:,0] = orient_x[:,:,0] - self.mesh_x * orient_x[:,:,2] 190 | orient_x[:,:,1] = orient_x[:,:,1] - self.mesh_y * orient_x[:,:,2] 191 | if self.feat == 1: 192 | orient_x[:,:,0] /= (depth_img + 1e-7) 193 | orient_x[:,:,1] /= (depth_img + 1e-7) 194 | elif self.feat == 2: 195 | l = np.sqrt(orient_x[:,:,0]*orient_x[:,:,0]+orient_x[:,:,1]*orient_x[:,:,1]) + 1e-7 196 | orient_x[:,:,0] /= l 197 | orient_x[:,:,1] /= l 198 | 199 | orient_y[:,:,0] = orient_y[:,:,0] - self.mesh_x * orient_y[:,:,2] 200 | orient_y[:,:,1] = orient_y[:,:,1] - self.mesh_y * orient_y[:,:,2] 201 | 202 | if self.feat == 1: 203 | orient_y[:,:,0] /= (depth_img + 1e-7) 204 | orient_y[:,:,1] /= (depth_img + 1e-7) 205 | elif self.feat == 2: 206 | l = np.sqrt(orient_y[:,:,0]*orient_y[:,:,0]+orient_y[:,:,1]*orient_y[:,:,1]) + 1e-7 207 | orient_y[:,:,0] /= l 208 | orient_y[:,:,1] /= l 209 | 210 | orient_img = np.zeros((orient_x.shape[0], orient_x.shape[1], 4), dtype='float32') 211 | orient_img[:,:,0] = orient_x[:,:,0] * (l1 > 0.5) 212 | orient_img[:,:,1] = orient_x[:,:,1] * (l1 > 0.5) 213 | orient_img[:,:,2] = orient_y[:,:,0] * (l2 > 0.5) 214 | orient_img[:,:,3] = orient_y[:,:,1] * (l2 
> 0.5)
215 |         #orient_img = np.concatenate((orient_x[:,:,0:2], orient_y[:,:,2]), axis=2)
216 | 
217 |         orient_img_vertical = orient_img.copy()
218 |         orient_img_vertical[:,:,0:2] = orient_img[:,:,2:4]
219 |         orient_img_vertical[:,:,2:4] = -orient_img[:,:,0:2]
220 | 
221 |         #orient_tensor = self.to_tensor(orient_img)
222 |         #orient_vert_tensor = self.to_tensor(orient_img_vertical)
223 |         orient_tensor = torch.from_numpy(np.transpose(orient_img,(2,0,1)))
224 |         orient_vert_tensor = torch.from_numpy(np.transpose(orient_img_vertical,(2,0,1)))
225 |         #orient_mask_tensor = misc.imresize(orient_mask_tensor, (240,320), 'nearest')
226 |         orient_mask_tensor = torch.Tensor(orient_mask_tensor / 255.0)
227 |         #orient_mask = np.reshape(orient_mask, (orient_mask.shape[0], orient_mask.shape[1], 1))
228 |         #orient_mask_tensor = self.to_tensor(orient_mask)
229 |         return {'image':input_tensor, 'label':orient_tensor, 'label_alt':orient_vert_tensor, 'mask':orient_mask_tensor, 'X':X, 'Y':Y}
230 | 
231 |     def __len__(self):
232 |         return self.data_len
233 | 
--------------------------------------------------------------------------------
/src/demo/AttachTexture.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import skimage.io as sio
3 | import sys
4 | import numpy as np
5 | from visualizer import app
6 | from direction import *
7 | import scipy.misc as misc
8 | import argparse
9 | 
10 | parser = argparse.ArgumentParser(description='Attach a texture or a 3D object to a scene.')
11 | parser.add_argument('--input', type=str, default='selected/0000')
12 | parser.add_argument('--resource', type=str, default='resources/im4.png')
13 | args = parser.parse_args()
14 | # set up file names
15 | start_name = args.input
16 | color_name = start_name + '-color.png'
17 | orient_x_name = start_name + '-orient-X_pred.png'
18 | orient_y_name = start_name + '-orient-Y_pred.png'
19 | 
20 | intrinsics = np.array([577.591,318.905,578.73,242.684]).astype('float32')
21 | 
22 | if args.resource[-3:] == 'obj':
23 |     mesh_info = ProcessOBJ(args.resource, args.resource[:-3] + 'jpg')
24 | else:
25 |     mesh_info = ProcessOBJ('resources/objects/toymonkey/toymonkey.obj', 'resources/objects/toymonkey/toymonkey.jpg')
26 | 
27 | color_image = cv2.imread(color_name)
28 | dirX_3d = Color2Vec(misc.imresize(sio.imread(orient_x_name), (480,640), 'nearest'))
29 | dirY_3d = Color2Vec(misc.imresize(sio.imread(orient_y_name), (480,640), 'nearest'))
30 | 
31 | if args.resource[-3:] != 'obj':
32 |     resource_image = cv2.imread(args.resource)
33 | else:
34 |     resource_image = cv2.imread('resources/im4.png')
35 | 
36 | if (resource_image.shape[2] == 4):
37 |     mask = resource_image[:,:,3] > 0
38 |     resource_image[:,:,0] *= mask
39 |     resource_image[:,:,1] *= mask
40 |     resource_image[:,:,2] *= mask
41 |     resource_image = resource_image[:,:,0:3]
42 | min_v = np.min([resource_image.shape[0], resource_image.shape[1]])
43 | start_x = (resource_image.shape[1] - min_v)//2
44 | start_y = (resource_image.shape[0] - min_v)//2
45 | attached_patch = np.ascontiguousarray(resource_image[start_y:start_y+min_v, start_x:start_x+min_v,:])
46 | attached_patch = misc.imresize(attached_patch, (401,401), 'nearest')
47 | 
48 | app(color_image, dirX_3d, dirY_3d, attached_patch, intrinsics, mesh_info)
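A note on the `-orient-X_pred.png` / `-orient-Y_pred.png` files read above: each pixel packs a unit 3-vector with components mapped from [-1, 1] into [0, 255], and `Color2Vec` (defined in `direction.py`) inverts that mapping and renormalizes. Below is a minimal round-trip sketch; the `encode` helper is hypothetical and only makes the convention explicit:
```
import numpy as np

def encode(vec_img):
    # Hypothetical inverse of Color2Vec: map components from [-1, 1] to [0, 255].
    return ((vec_img + 1.0) / 2.0 * 255.0).astype('uint8')

def decode(color_img):
    # Mirrors Color2Vec in direction.py: undo the mapping, then renormalize.
    vec = color_img / 255.0 * 2.0 - 1.0
    norm = np.linalg.norm(vec, axis=2, keepdims=True) + 1e-7
    return vec / norm

v = np.zeros((4, 4, 3), dtype='float32')
v[..., 0] = 1.0  # every pixel points along +X
assert np.allclose(decode(encode(v)), v, atol=0.01)  # exact up to 8-bit quantization
```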
--------------------------------------------------------------------------------
/src/demo/README.md:
--------------------------------------------------------------------------------
1 | # Instructions for AR
2 | 
3 | ![FrameNet AR Teaser](https://github.com/hjwdzh/framenet/raw/master/img/teaser-ar.jpg)
4 | 
5 | ### Build the cpp library
6 | ```
7 | cd cpp
8 | mkdir build
9 | cd build
10 | cmake .. -DCMAKE_BUILD_TYPE=Release
11 | make
12 | ```
13 | 
14 | ### Download the data
15 | ```
16 | sh download.sh
17 | ```
18 | 
19 | ### Run the App
20 | The app takes a scene and an object as input; see the default options in the script for an example.
21 | ```
22 | python AttachTexture.py [--input scene] [--resource object]
23 | ```
24 | 
25 | ### Instructions for usage
26 | * Move the mouse to choose where the center of the object will be placed.
27 | * Left-click to place the object.
28 | * Press 'g' to quit the app.
29 | * Press 'd' and 'f' to make the object larger or smaller.
30 | * Press 'a' to switch among three different modes:
31 |     * Attach the pattern in rigid mode.
32 |     * Attach the pattern in deformable mode.
33 |     * Attach a 3D object.
34 | * Press 'r' to rotate the object among 4 possible orientations.
35 | * Press 's' to save the current composite to result.png.
36 | 
--------------------------------------------------------------------------------
/src/demo/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.1)
2 | project(Direction)
3 | 
4 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake")
5 | 
6 | set(CMAKE_INCLUDE_CURRENT_DIR ON)
7 | set(CMAKE_CXX_STANDARD 14)
8 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
9 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-sign-compare")
10 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}")
11 | set(CMAKE_CXX_FLAGS_RELEASE "-O3") # enable assert
12 | set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g") # enable assert
13 | set(CMAKE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS}")
14 | set(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG}")
15 | 
16 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
17 |     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp -Wno-int-in-bool-context")
18 |     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
19 |     set(CMAKE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS}")
20 |     set(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fsanitize=address")
21 | endif()
22 | 
23 | set(
24 |     Direction_SRC
25 |     main.cpp
26 | )
27 | 
28 | add_library(
29 |     Direction SHARED
30 |     ${Direction_SRC}
31 | )
32 | 
33 | target_link_libraries(
34 |     Direction
35 |     opencv_world
36 | )
37 | 
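CMake builds these sources into a shared library that the Python side loads with `ctypes` (`direction.py` does `cdll.LoadLibrary('./cpp/build/libDirection.so')`). The following is a minimal smoke-test sketch for the build, assuming it is run from `src/demo` and that the library suffix matches the platform (`.dylib` on macOS); it uses the same calling convention as the `VisualizeDirection` wrapper in `direction.py`:
```
import numpy as np
from ctypes import cdll, c_char_p, c_void_p

dirlib = cdll.LoadLibrary('./cpp/build/libDirection.so')

h, w = 480, 640
color = np.zeros((h, w, 3), dtype='float32')  # RGB in [0, 1]
qx = np.zeros((h, w, 2), dtype='float32')     # projected 2D tangent directions
qy = np.zeros((h, w, 2), dtype='float32')

# Zero directions terminate every traced streamline immediately, so this
# just exercises loading, the ABI, and image output.
dirlib.VisualizeDirection(c_char_p(b'smoke_test.png'), c_void_p(color.ctypes.data),
                          c_void_p(qx.ctypes.data), c_void_p(qy.ctypes.data), h, w)
```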
--------------------------------------------------------------------------------
/src/demo/cpp/main.cpp:
--------------------------------------------------------------------------------
1 | #include <opencv2/opencv.hpp>
2 | #include <glm/glm.hpp>
3 | #define MIN(a,b) (((a)<(b))?(a):(b))
4 | #define MAX(a,b) (((a)>(b))?(a):(b))
5 | 
6 | extern "C" {
7 | 
8 | void VisualizeDirection(const char* output_file, float* color, cv::Point2f* Qx, cv::Point2f* Qy, int height, int width) {
9 |     cv::Mat final(height, width, CV_32FC3);
10 |     memcpy(final.data, color, sizeof(float) * 3 * height * width);
11 |     final.convertTo(final, CV_8UC3, 255);
12 | 
13 |     for (int i = 0; i < 300; ++i) {
14 |         int px = rand() % width;
15 |         int py = rand() % height;
16 | 
17 |         for (int k = 0; k < 2; ++k) {
18 |             std::vector<cv::Point2f> points;
19 |             points.push_back(cv::Point2f(px, py));
20 |             for (int j = 0; j < 300; ++j) {
21 |                 auto p = points.back();
22 |                 int px = p.x;
23 |                 int py = p.y;
24 |                 if (px < 0 || py < 0 || px >= width - 2 || py >= height - 2)
25 |                     break;
26 |                 auto& Q = (k == 0) ? Qx : Qy;
27 |                 auto p11 = Q[py * width + px];
28 |                 auto p12 = Q[py * width + px + 1];
29 |                 auto p21 = Q[py * width + px + width];
30 |                 auto p22 = Q[py * width + px + width + 1];
31 |                 if (p11.dot(p11) < 1e-6)
32 |                     break;
33 |                 if (p12.dot(p12) < 1e-6)
34 |                     break;
35 |                 if (p21.dot(p21) < 1e-6)
36 |                     break;
37 |                 if (p22.dot(p22) < 1e-6)
38 |                     break;
39 | 
40 |                 float wx = p.x - px;
41 |                 float wy = p.y - py;
42 |                 cv::Point2f dp = (p11 * (1 - wx) + p12 * wx) * (1 - wy) + (p21 * (1 - wx) + p22 * wx) * wy;
43 |                 points.push_back(p + dp);
44 |                 if (k == 0) {
45 |                     cv::line(final, p, points.back(), cv::Scalar(255, 0, 0));
46 |                 } else {
47 |                     cv::line(final, p, points.back(), cv::Scalar(0, 255, 0));
48 |                 }
49 |             }
50 |         }
51 |     }
52 |     cv::imwrite(output_file, final);
53 | }
54 | 
55 | void ComputeWarping(int* integer_params, float delta, float* intrinsics, cv::Vec3f* Qx, cv::Vec3f* Qy, float* output_coords){//, float* frameX, float* frameY) {
56 |     int height = integer_params[0];
57 |     int width = integer_params[1];
58 |     int px = integer_params[2];
59 |     int py = integer_params[3];
60 |     int patch_w = integer_params[4];
61 |     cv::Vec3f pt((px - intrinsics[1]) / intrinsics[0], (py - intrinsics[3]) / intrinsics[2], 1);
62 |     int patch_w2 = patch_w * 2 + 1;
63 |     std::vector<cv::Vec2i> coords;
64 |     coords.reserve(patch_w2 * patch_w2);
65 |     coords.push_back(cv::Vec2i(patch_w, patch_w));
66 |     std::vector<int> hash(patch_w2 * patch_w2);
67 |     hash[patch_w * patch_w2 + patch_w] = 1;
68 |     int f = 0;
69 |     while (f < coords.size()) {
70 |         cv::Vec2i& p = coords[f];
71 |         cv::Vec2i dir[4] = {cv::Vec2i(0,1), cv::Vec2i(1,0), cv::Vec2i(0,-1), cv::Vec2i(-1,0)};
72 |         for (int i = 0; i < 4; ++i) {
73 |             cv::Vec2i np = p + dir[i];
74 |             if (np.val[0] >= 0 && np.val[0] < patch_w2 && np.val[1] >= 0 && np.val[1] < patch_w2) {
75 |                 if (hash[np.val[1] * patch_w2 + np.val[0]] == 0) {
76 |                     hash[np.val[1] * patch_w2 + np.val[0]] = 1;
77 |                     coords.push_back(np);
78 |                 }
79 |             }
80 |         }
81 |         f += 1;
82 |     }
83 |     for (auto& h : hash) {
84 |         h = 0;
85 |     }
86 | 
87 |     std::vector<cv::Vec3f> positions(patch_w2 * patch_w2);
88 |     std::vector<cv::Vec2f> positions2d(patch_w2 * patch_w2);
89 |     std::vector<std::pair<cv::Vec3f, cv::Vec3f> > frames(patch_w2 * patch_w2);
90 | 
91 |     positions[patch_w * patch_w2 + patch_w] = cv::Vec3f((px - intrinsics[1]) / intrinsics[0], (py - intrinsics[3]) / intrinsics[2], 1);
92 |     positions2d[patch_w * patch_w2 + patch_w] = cv::Vec2f(px, py);
93 |     frames[patch_w * patch_w2 + patch_w] = std::make_pair(Qx[py * width + px], Qy[py * width + px]);
94 |     hash[patch_w * patch_w2 + patch_w] = 1;
95 |     auto p1 = Qx[py * width + px];
96 |     auto p2 = Qy[py * width + px];
97 |     for (int i = 0; i < coords.size(); ++i) {
98 |         int index = coords[i].val[1] * patch_w2 + coords[i].val[0];
99 |         if (hash[index])
100 |             continue;
101 |         int top = 0;
102 |         cv::Vec3f slots[8];
103 |         std::pair<cv::Vec3f, cv::Vec3f> new_frame, frame;
104 |         int dx[] = {-1,-1,-1,0,0,1,1,1};
105 |         int dy[] = {-1,0,1,-1,1,-1,0,1};
106 |         for (int j = 0; j < 8; ++j) {
107 |             int tx = coords[i].val[0] + dx[j];
108 |             int ty = coords[i].val[1] + dy[j];
109 |             if (tx < 0 || ty < 0 || tx >= patch_w2 || ty >= patch_w2)
110 |                 continue;
111 |             if (hash[ty * patch_w2 + tx] == 0)
112 |                 continue;
113 |             frame = frames[ty * patch_w2 + tx];
114 |             slots[top++] = positions[ty * patch_w2 + tx] - delta * frame.first * dx[j] - delta * frame.second * dy[j];
115 |         }
116 |         cv::Vec3f p(0,0,0);
117 |         for (int j = 0; j < top; ++j) {
118 |             p += slots[j];
119 |         }
120 |         p /= (float)top;
121 |         positions[index] = p;
122 |         hash[index] = 1;
123 |         float current_px = p.val[0] / p.val[2] * intrinsics[0] + intrinsics[1];
124 |         float current_py = p.val[1] / p.val[2] * intrinsics[2] + intrinsics[3];
125 |         positions2d[index] = cv::Vec2f(current_px, current_py);
126 |         if (current_px < 0 || current_px + 1 >= width || current_py < 0 || current_py + 1 >= height)
127 |             continue;
128 |         new_frame = std::make_pair(cv::Vec3f(0,0,0), cv::Vec3f(0,0,0));
129 |         for (int dpy = 0; dpy <= 1; ++dpy) {
130 |             float weight_y = (dpy == 0) ? 1 - (current_py - (int)current_py) : (current_py - (int)current_py);
131 |             for (int dpx = 0; dpx <= 1; ++dpx) {
132 |                 float weight_x = (dpx == 0) ? 1 - (current_px - (int)current_px) : (current_px - (int)current_px);
133 |                 int index = (int(current_py) + dpy) * width + (int(current_px) + dpx);
134 |                 cv::Vec3f dir_x = Qx[index];
135 |                 cv::Vec3f dir_y = Qy[index];
136 |                 float max_dot = -1e30;
137 |                 cv::Vec3f dir_p1, dir_p2;
138 |                 for (int k = 0; k < 4; ++k) {
139 |                     float dot = std::max(dir_x.dot(frame.first), dir_y.dot(frame.second));
140 |                     if (dot > max_dot) {
141 |                         max_dot = dot;
142 |                         dir_p1 = dir_x;
143 |                         dir_p2 = dir_y;
144 |                     }
145 |                     auto temp = dir_x;
146 |                     dir_x = -dir_y;
147 |                     dir_y = temp;
148 |                 }
149 |                 new_frame.first += weight_x * weight_y * dir_p1;
150 |                 new_frame.second += weight_x * weight_y * dir_p2;
151 |             }
152 |         }
153 | 
154 |         double l = new_frame.first.dot(new_frame.first);
155 |         if (l > 0)
156 |             new_frame.first /= sqrt(l);
157 |         l = new_frame.second.dot(new_frame.second);
158 |         if (l > 0)
159 |             new_frame.second /= sqrt(l);
160 |         frames[index] = new_frame;
161 |     }
162 |     memcpy(output_coords, positions2d.data(), sizeof(cv::Vec2f) * positions2d.size());
163 | }
164 | 
165 | float orient2d(const cv::Point2f& a, const cv::Point2f& b, const cv::Point2f& c)
166 | {
167 |     return (b.x-a.x)*(c.y-a.y) - (b.y-a.y)*(c.x-a.x);
168 | }
169 | 
170 | int min3(float v1, float v2, float v3) {
171 |     return std::min(v1, std::min(v2, v3));
172 | }
173 | int max3(float v1, float v2, float v3) {
174 |     return 0.99999 + std::max(v1, std::max(v2, v3));
175 | }
176 | 
177 | void rasterize(unsigned char* rgb1, unsigned char* rgb2, unsigned char* rgb3, cv::Point2f& v0, cv::Point2f& v1, cv::Point2f& v2, unsigned char* image, int height, int width, int solid) {
178 |     // Compute triangle bounding box
179 |     int minX = min3(v0.x, v1.x, v2.x);
180 |     int minY = min3(v0.y, v1.y, v2.y);
181 |     int maxX = max3(v0.x, v1.x, v2.x);
182 |     int maxY = max3(v0.y, v1.y, v2.y);
183 | 
184 |     // Clip against screen bounds
185 |     minX = std::max(minX, 0);
186 |     minY = std::max(minY, 0);
187 |     maxX = std::min(maxX, width - 1);
188 |     maxY = std::min(maxY, height - 1);
189 | 
190 |     // Rasterize
191 |     for (int py = minY; py <= maxY; py++) {
192 |         for (int px = minX; px <= maxX; px++) {
193 |             // Determine barycentric coordinates
194 |             cv::Point2f p(px, py);
195 |             float w0 = orient2d(v1, v2, p);
196 |             float w1 = orient2d(v2, v0, p);
197 |             float w2 = orient2d(v0, v1, p);
198 | 
199 |             // If p is on or inside all edges, render pixel. 
200 | if (w0 >= 0 && w1 >= 0 && w2 >= 0) { 201 | float w = 1.0 / (w0 + w1 + w2); 202 | w0 *= w; 203 | w1 *= w; 204 | w2 *= w; 205 | unsigned char* target = image + (py * width + px) * 3; 206 | unsigned char t[3]; 207 | for (int j = 0; j < 3; ++j) { 208 | float c = rgb1[j] * w0 + rgb2[j] * w1 + rgb3[j] * w2; 209 | if (c > 255) 210 | c = 255; 211 | if (c < 0) 212 | c = 0; 213 | t[j] = c; 214 | } 215 | if (t[0] != 0 || t[1] != 0 || t[2] != 0) { 216 | if (solid == 1) 217 | for (int j = 0; j < 3; ++j) { 218 | target[j] = t[j]; 219 | } 220 | else 221 | for (int j = 0; j < 3; ++j) { 222 | target[j] = target[j] * 0.5 + t[j] * 0.5; 223 | } 224 | } 225 | } 226 | } 227 | } 228 | } 229 | void Rasterize(unsigned char* patch, cv::Point2f* coords, int patch_w, unsigned char* image, int height, int width, int solid) { 230 | for (int i = 0; i < patch_w - 1; ++i) { 231 | for (int j = 0; j < patch_w - 1; ++j) { 232 | { 233 | unsigned char* rgb1 = patch + (i * patch_w + j) * 3; 234 | unsigned char* rgb2 = patch + (i * patch_w + j + 1) * 3; 235 | unsigned char* rgb3 = patch + (i * patch_w + j + patch_w) * 3; 236 | cv::Point2f& coord1 = coords[i * patch_w + j]; 237 | cv::Point2f& coord2 = coords[i * patch_w + j + 1]; 238 | cv::Point2f& coord3 = coords[i * patch_w + j + patch_w]; 239 | rasterize(rgb1, rgb2, rgb3, coord1, coord2, coord3, image, height, width, solid); 240 | } 241 | { 242 | unsigned char* rgb1 = patch + (i * patch_w + j + 1) * 3; 243 | unsigned char* rgb2 = patch + (i * patch_w + j + 1 + patch_w) * 3; 244 | unsigned char* rgb3 = patch + (i * patch_w + j + patch_w) * 3; 245 | cv::Point2f& coord1 = coords[(i * patch_w + j + 1)]; 246 | cv::Point2f& coord2 = coords[i * patch_w + j + 1 + patch_w]; 247 | cv::Point2f& coord3 = coords[i * patch_w + j + patch_w]; 248 | rasterize(rgb1, rgb2, rgb3, coord1, coord2, coord3, image, height, width, solid); 249 | } 250 | } 251 | } 252 | } 253 | 254 | 255 | float calculateSignedArea2(const glm::vec3& a, const glm::vec3& b, const glm::vec3& c) { 256 | return ((c.x - a.x) * (b.y - a.y) - (b.x - a.x) * (c.y - a.y)); 257 | } 258 | 259 | glm::vec3 calculateBarycentricCoordinate(const glm::vec3& a, const glm::vec3& b, const glm::vec3& c, const glm::vec3& p) { 260 | float beta_tri = calculateSignedArea2(a, p, c); 261 | float gamma_tri = calculateSignedArea2(a, b, p); 262 | float tri_inv = 1.0f / calculateSignedArea2(a, b, c); 263 | float beta = beta_tri * tri_inv; 264 | float gamma = gamma_tri * tri_inv; 265 | float alpha = 1.0 - beta - gamma; 266 | return glm::vec3(alpha, beta, gamma); 267 | } 268 | 269 | bool isBarycentricCoordInBounds(const glm::vec3 barycentricCoord) { 270 | return barycentricCoord.x >= 0.0 && barycentricCoord.x <= 1.0 && 271 | barycentricCoord.y >= 0.0 && barycentricCoord.y <= 1.0 && 272 | barycentricCoord.z >= 0.0 && barycentricCoord.z <= 1.0; 273 | } 274 | 275 | float getZAtCoordinate(const glm::vec3 barycentricCoord, const glm::vec3& a, const glm::vec3& b, const glm::vec3& c) { 276 | return (barycentricCoord.x * a.z 277 | + barycentricCoord.y * b.z 278 | + barycentricCoord.z * c.z); 279 | } 280 | 281 | 282 | void DrawTriangle(glm::vec3* v1, glm::vec3* v2, glm::vec3* v3, 283 | glm::vec2* t1, glm::vec2* t2, glm::vec2* t3, 284 | glm::vec3* n1, glm::vec3* n2, glm::vec3* n3, 285 | unsigned char* tex_image, unsigned char* color_image, float* zbuffer, 286 | float* intrinsics, int tex_width, int tex_height, int width, int height) { 287 | 288 | glm::vec3 p1 = *v1, p2 = *v2, p3 = *v3; 289 | if (p1.z < 0.01 || p2.z < 0.01 || p3.z < 0.01) 290 | return; 291 | 
292 | p1.z = 1.0f / p1.z; 293 | p2.z = 1.0f / p2.z; 294 | p3.z = 1.0f / p3.z; 295 | 296 | p1.x = p1.x * p1.z; 297 | p1.y = p1.y * p1.z; 298 | p2.x = p2.x * p2.z; 299 | p2.y = p2.y * p2.z; 300 | p3.x = p3.x * p3.z; 301 | p3.y = p3.y * p3.z; 302 | 303 | float fx = intrinsics[0]; 304 | float cx = intrinsics[1]; 305 | float fy = intrinsics[2]; 306 | float cy = intrinsics[3]; 307 | int minX = (MIN(p1.x, MIN(p2.x, p3.x)) * fx + cx); 308 | int minY = (MIN(p1.y, MIN(p2.y, p3.y)) * fy + cy); 309 | int maxX = (MAX(p1.x, MAX(p2.x, p3.x)) * fx + cx) + 0.999999f; 310 | int maxY = (MAX(p1.y, MAX(p2.y, p3.y)) * fy + cy) + 0.999999f; 311 | 312 | minX = MAX(0, minX); 313 | minY = MAX(0, minY); 314 | maxX = MIN(width, maxX); 315 | maxY = MIN(height, maxY); 316 | 317 | 318 | for (int py = minY; py <= maxY; ++py) { 319 | for (int px = minX; px <= maxX; ++px) { 320 | if (px < 0 || px >= width || py < 0 || py >= height) 321 | continue; 322 | 323 | float x = (px - cx) / fx; 324 | float y = (py - cy) / fy; 325 | glm::vec3 baryCentricCoordinate = calculateBarycentricCoordinate(p1, p2, p3, glm::vec3(x, y, 0)); 326 | 327 | if (isBarycentricCoordInBounds(baryCentricCoordinate)) { 328 | int pixel = py * width + px; 329 | 330 | float z = getZAtCoordinate(baryCentricCoordinate, p1, p2, p3); 331 | int z_quantize = z * 100000; 332 | 333 | int original_z = zbuffer[pixel]; 334 | 335 | if (original_z < z_quantize) { 336 | glm::vec2 tex = *t1 * baryCentricCoordinate.x + *t2 * baryCentricCoordinate.y + *t3 * baryCentricCoordinate.z; 337 | glm::vec3 normal = *n1 * baryCentricCoordinate.x + *n2 * baryCentricCoordinate.y + *n3 * baryCentricCoordinate.z; 338 | 339 | glm::vec3 light_dir((px - cx) / fx, (py - cy) / fy, 1); 340 | light_dir = glm::normalize(light_dir); 341 | normal = glm::normalize(normal); 342 | while (tex.x > 1) 343 | tex.x -= 1; 344 | while (tex.x < 0) 345 | tex.x += 1; 346 | while (tex.y > 1) 347 | tex.y -= 1; 348 | while (tex.y < 0) 349 | tex.y += 1; 350 | float tex_x = tex.x * tex_width; 351 | float tex_y = (1 - tex.y) * tex_height; 352 | int ppx = tex_x, ppy = tex_y; 353 | float wx = tex_x - ppx, wy = tex_y - ppy; 354 | if (ppx >= tex_width - 1) 355 | ppx -= 1; 356 | if (ppy >= tex_height - 1) 357 | ppy -= 1; 358 | unsigned char* rgb1 = tex_image + (ppy * tex_width + ppx) * 3; 359 | unsigned char* rgb2 = tex_image + (ppy * tex_width + ppx + 1) * 3; 360 | unsigned char* rgb3 = tex_image + (ppy * tex_width + ppx + tex_width) * 3; 361 | unsigned char* rgb4 = tex_image + (ppy * tex_width + ppx + tex_width + 1) * 3; 362 | unsigned char* output_rgb = color_image + pixel * 3; 363 | float intensity = 1;//0.3 + 0.7 * std::abs(glm::dot(light_dir, normal)); 364 | for (int t = 0; t < 3; ++t) { 365 | output_rgb[t] = ((rgb1[t] * (1 - wx) + rgb2[t] * wx) * (1 - wy) + (rgb3[t] * (1 - wx) + rgb4[t] * wx) * wy) * intensity; 366 | //printf("%f ", (rgb1[t] * (1 - wx) + rgb2[t] * wx) * (1 - wy) + (rgb3[t] * (1 - wx) + rgb4[t] * wx) * wy); 367 | } 368 | //printf("%d\n", pixel); 369 | zbuffer[pixel] = z_quantize; 370 | } 371 | } 372 | } 373 | }} 374 | }; -------------------------------------------------------------------------------- /src/demo/direction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ctypes import * 3 | import cv2 4 | import scipy.misc as misc 5 | #dirlib = cdll.LoadLibrary('./cpp/build/libDirection.dylib') 6 | dirlib = cdll.LoadLibrary('./cpp/build/libDirection.so') 7 | 8 | def Color2Vec(image): 9 | image = image / 255.0 * 2.0 - 1.0 10 | norm = 
np.linalg.norm(image, axis=2) + 1e-7 11 | for j in range(3): 12 | image[:,:,j] /= norm 13 | return image 14 | 15 | def ImageCoord(height, width, intrinsics): 16 | xx, yy = np.meshgrid(np.array([i for i in range(width)]), np.array([i for i in range(height)])) 17 | mesh_x = (xx - intrinsics[1]) / intrinsics[0] 18 | mesh_y = (yy - intrinsics[3]) / intrinsics[2] 19 | return mesh_x, mesh_y 20 | 21 | def ProjectDir(dir3D, pixel_x, pixel_y): 22 | dir2D = np.zeros((dir3D.shape[0], dir3D.shape[1], 2)) 23 | dir2D[:,:,0] = dir3D[:,:,0] - pixel_x * dir3D[:,:,2] 24 | dir2D[:,:,1] = dir3D[:,:,1] - pixel_y * dir3D[:,:,2] 25 | return dir2D 26 | 27 | def VisualizeDirection(file, gt_color, Qx, Qy): 28 | if gt_color.dtype == 'uint8': 29 | color = gt_color.astype('float32') / 255.0 30 | else: 31 | color = gt_color 32 | Qx_float = np.ascontiguousarray(Qx.astype('float32')) 33 | Qy_float = np.ascontiguousarray(Qy.astype('float32')) 34 | color = np.ascontiguousarray(color.astype('float32')) 35 | dirlib.VisualizeDirection(c_char_p(file.encode('utf-8')), c_void_p(color.ctypes.data), c_void_p(Qx_float.ctypes.data), c_void_p(Qy_float.ctypes.data), gt_color.shape[0], gt_color.shape[1]) 36 | 37 | def ComputeWarping(Qx, Qy, intrinsics, px, py, pixel_w, patch_w): 38 | output = np.zeros((patch_w * 2 + 1, patch_w * 2 + 1, 2), dtype='float32') 39 | Qx_float = np.ascontiguousarray(Qx.astype('float32')) 40 | Qy_float = np.ascontiguousarray(Qy.astype('float32')) 41 | intrinsics_float = np.ascontiguousarray(intrinsics.astype('float32')) 42 | 43 | integer_params = np.array([Qx.shape[0], Qx.shape[1], px, py, patch_w], dtype='int32') 44 | 45 | #framesX = np.zeros((257,257,3), dtype='float32') 46 | #framesY = np.zeros((257,257,3), dtype='float32') 47 | dirlib.ComputeWarping( 48 | c_void_p(integer_params.ctypes.data), c_float(pixel_w),\ 49 | c_void_p(intrinsics_float.ctypes.data),\ 50 | c_void_p(Qx_float.ctypes.data), c_void_p(Qy_float.ctypes.data),\ 51 | c_void_p(output.ctypes.data)) 52 | #c_void_p(framesX.ctypes.data), c_void_p(framesY.ctypes.data)) 53 | 54 | 55 | #cv2.imwrite('frame1.png', ((framesX + 1) / 2.0 * 255).astype('uint8')) 56 | #cv2.imwrite('frame2.png', ((framesY + 1) / 2.0 * 255).astype('uint8')) 57 | 58 | return output 59 | 60 | def Render(patch, coord, image, solid): 61 | dirlib.Rasterize(c_void_p(patch.ctypes.data), c_void_p(coord.ctypes.data), c_int(coord.shape[0]), c_void_p(image.ctypes.data), c_int(image.shape[0]), c_int(image.shape[1]), c_int(solid)) 62 | return image 63 | 64 | def DrawTriangle(v1,v2,v3,t1,t2,t3,n1,n2,n3,tex, color_image, z_image, intrinsics): 65 | dirlib.DrawTriangle( 66 | c_void_p(v1.ctypes.data),c_void_p(v2.ctypes.data),c_void_p(v3.ctypes.data),\ 67 | c_void_p(t1.ctypes.data),c_void_p(t2.ctypes.data),c_void_p(t3.ctypes.data),\ 68 | c_void_p(n1.ctypes.data),c_void_p(n2.ctypes.data),c_void_p(n3.ctypes.data),\ 69 | c_void_p(tex.ctypes.data),c_void_p(color_image.ctypes.data),c_void_p(z_image.ctypes.data),\ 70 | c_void_p(intrinsics.ctypes.data),\ 71 | c_int(tex.shape[1]), c_int(tex.shape[0]), c_int(color_image.shape[1]), c_int(color_image.shape[0])) 72 | 73 | def CanonicalPixel(intrinsics, pixel): 74 | #fx, cx, fy, cy 75 | return np.array([(pixel[0] - intrinsics[1]) / intrinsics[0], (pixel[1] - intrinsics[3]) / intrinsics[2], 1]) 76 | 77 | def IntrinsicMatrix(intrinsics): 78 | m = np.zeros((3,3)) 79 | m[0,0] = intrinsics[0] 80 | m[0,2] = intrinsics[1] 81 | m[1,1] = intrinsics[2] 82 | m[1,2] = intrinsics[3] 83 | m[2,2] = 1 84 | return m 85 | 86 | def BuildHomography(intrinsics, pixel, dirX, 
dirY, pixel_w, image_c): 87 | p_star = CanonicalPixel(intrinsics, pixel) 88 | homography = np.zeros((3,3)) 89 | homography[:,0] = dirX * pixel_w 90 | homography[:,1] = dirY * pixel_w 91 | homography[:,2] = p_star - image_c * (dirX + dirY) * pixel_w 92 | intrinsic_matrix = IntrinsicMatrix(intrinsics) 93 | homography = np.dot(intrinsic_matrix, homography) 94 | patch_coord = np.array([1,0,1]) 95 | #print(np.dot(homography, patch_coord)) 96 | ''' 97 | scale = float(patch_w) / image_w 98 | homography[0,:] *= scale 99 | homography[1,:] *= scale 100 | ''' 101 | return homography 102 | 103 | def ProcessOBJ(obj_file,texture_file=''): 104 | lines = [l.strip() for l in open(obj_file) if l.strip() != ''] 105 | vertices = [] 106 | normals = [] 107 | texs = [] 108 | face_mats = [] 109 | face_Tinds = [] 110 | face_Ninds = [] 111 | face_Vinds = [] 112 | mat_type = -1 113 | Imgs = [] 114 | for l in lines: 115 | words = [w for w in l.split(' ') if w != ''] 116 | if words[0] == 'v': 117 | vertices.append([float(words[1]), float(words[2]), float(words[3])]) 118 | if words[0] == 'vt': 119 | texs.append([float(words[1]), float(words[2])]) 120 | if words[0] == 'vn': 121 | normals.append([float(words[1]), float(words[2]), float(words[3])]) 122 | if words[0] == 'usemtl': 123 | if (texture_file == ''): 124 | if words[1] == 'blinn1SG': 125 | mat_type = len(Imgs) 126 | img = np.zeros((2,2,3),dtype='uint8') 127 | img[:,:,0] = 0.59 * 255 128 | img[:,:,1] = 0.63 * 255 129 | img[:,:,2] = 0.66 * 255 130 | Imgs.append(img.copy()) 131 | elif words[1] == 'lambert2SG': 132 | img = cv2.imread('resources/Converse_obj/converse.jpg') 133 | img = misc.imresize(img, (256,256)) 134 | mat_type = len(Imgs) 135 | Imgs.append(img.copy()) 136 | elif words[1] == 'lambert3SG': 137 | img = cv2.imread('resources/Converse_obj/laces.jpg') 138 | img = misc.imresize(img, (256,256)) 139 | mat_type = len(Imgs) 140 | Imgs.append(img.copy()) 141 | else: 142 | print('wrong!') 143 | exit(0) 144 | else: 145 | img = cv2.imread(texture_file) 146 | mat_type = len(Imgs) 147 | Imgs.append(img.copy()) 148 | if words[0] == 'f': 149 | if mat_type == -1: 150 | print('wrong') 151 | exit(0) 152 | vinds = [] 153 | tinds = [] 154 | ninds = [] 155 | for j in range(3): 156 | ws = words[j + 1].split('/') 157 | vinds.append(int(ws[0])) 158 | tinds.append(int(ws[1])) 159 | ninds.append(int(ws[2])) 160 | face_Vinds.append(vinds) 161 | face_Tinds.append(tinds) 162 | face_Ninds.append(ninds) 163 | face_mats.append(mat_type) 164 | if len(words) == 5: 165 | vinds = [] 166 | tinds = [] 167 | for j in range(3): 168 | p = j + 2 169 | if j == 0: 170 | p = 1 171 | ws = words[p].split('/') 172 | vinds.append(int(ws[0])) 173 | tinds.append(int(ws[1])) 174 | ninds.append(int(ws[2])) 175 | face_Vinds.append(vinds) 176 | face_Tinds.append(tinds) 177 | face_Ninds.append(ninds) 178 | face_mats.append(mat_type) 179 | 180 | 181 | vertices = np.array(vertices, dtype='float32') 182 | texs = np.array(texs, dtype='float32') 183 | face_Vinds = np.array(face_Vinds, dtype='int32') - 1 184 | face_Tinds = np.array(face_Tinds, dtype='int32') - 1 185 | face_Ninds = np.array(face_Ninds, dtype='int32') - 1 186 | face_mats = np.array(face_mats, dtype='int32') 187 | 188 | vertices[:,1] = -vertices[:,1] 189 | vertices[:,2] = -vertices[:,2] 190 | min_v = np.array([np.min(vertices[:,i]) for i in range(3)]) 191 | max_v = np.array([np.max(vertices[:,i]) for i in range(3)]) 192 | max_len = np.max(max_v - min_v) 193 | vertices /= max_len 194 | vertices = np.ascontiguousarray(vertices.astype('float32')) 195 | 
196 | return {'v':vertices, 't':texs, 'n':normals, 'fv':face_Vinds, 'ft':face_Tinds, 'fn':face_Ninds, 'm':face_mats, 'tex':Imgs} 197 | 198 | 199 | def Render3D(mesh_info, intrinsics, rotation, translation): 200 | vertices= mesh_info['v'] 201 | vertices = np.dot(vertices, np.transpose(rotation)) 202 | for j in range(3): 203 | vertices[:,j] += translation[j] 204 | vertices = np.ascontiguousarray(vertices.astype('float32')) 205 | 206 | normals = mesh_info['n'] 207 | normals = np.dot(normals, rotation) 208 | normals = np.ascontiguousarray(normals.astype('float32')) 209 | texs = mesh_info['t'] 210 | face_Vinds = mesh_info['fv'] 211 | face_Tinds = mesh_info['ft'] 212 | face_Ninds = mesh_info['fn'] 213 | face_mats = mesh_info['m'] 214 | textures = mesh_info['tex'] 215 | color_image = np.zeros((480,640,3),dtype='uint8') 216 | z_image = np.zeros((480,640),dtype='float32') 217 | for i in range(0,face_Vinds.shape[0]): 218 | v1 = vertices[face_Vinds[i][0]] 219 | v2 = vertices[face_Vinds[i][1]] 220 | v3 = vertices[face_Vinds[i][2]] 221 | t1 = texs[face_Tinds[i][0]] 222 | t2 = texs[face_Tinds[i][1]] 223 | t3 = texs[face_Tinds[i][2]] 224 | n1 = normals[face_Ninds[i][0]] 225 | n2 = normals[face_Ninds[i][1]] 226 | n3 = normals[face_Ninds[i][2]] 227 | mat_id = face_mats[i] 228 | texture = textures[mat_id] 229 | DrawTriangle(v1,v2,v3,t1,t2,t3,n1,n2,n3,texture, color_image, z_image, intrinsics) 230 | return color_image, z_image -------------------------------------------------------------------------------- /src/demo/download.sh: -------------------------------------------------------------------------------- 1 | wget http://download.cs.stanford.edu/orion/framenet/ar.zip 2 | unzip ar.zip 3 | rm ar.zip -------------------------------------------------------------------------------- /src/demo/visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from direction import * 4 | 5 | color_image = None 6 | original_image = None 7 | update = False 8 | mouse_x = 0 9 | mouse_y = 0 10 | patch = None 11 | dirX = None 12 | dirY = None 13 | intrinsics = None 14 | patch_w = 0.1 15 | patch_c = 0 16 | deformable = 0 17 | mesh_info = None 18 | rot = 0 19 | scale = 5 20 | def UpdateImage(): 21 | global draw_image, mouse_x, mouse_y, intrinsics, dirX, dirY, patch_w, patch_c, color_image, patch, deformable 22 | 23 | draw_image = color_image.copy() 24 | if deformable == 1: 25 | outputs = ComputeWarping(dirX, dirY, intrinsics, mouse_x,mouse_y,patch_w, patch_c).astype('float32') 26 | Render(patch, outputs, draw_image, 0) 27 | cv2.imshow("TextureAttach", draw_image) 28 | else: 29 | 30 | H = BuildHomography(intrinsics, np.array([mouse_x, mouse_y]), dirX[mouse_y,mouse_x], dirY[mouse_y,mouse_x], patch_w, patch_c) 31 | warp = cv2.warpPerspective(patch, H, (color_image.shape[1], color_image.shape[0])) 32 | 33 | mask = (warp[:,:,0] < 10) * (warp[:,:,1]<10) * (warp[:,:,2] < 10) 34 | mask = np.reshape(mask, (mask.shape[0],mask.shape[1],1)) 35 | mask = np.tile(mask, (1,1,3)) 36 | draw_image = (draw_image * mask + draw_image * (1 - mask) * 0.5 + warp*(1-mask) * 0.5).astype('uint8') 37 | cv2.imshow("TextureAttach", draw_image) 38 | 39 | def ModifyImage(): 40 | global mouse_x, mouse_y, intrinsics, dirX, dirY, patch_w, patch_c, color_image, deformable, mesh_info, rot, scale 41 | 42 | if deformable == 2: 43 | tangent1 = dirX[mouse_y, mouse_x] 44 | tangent2 = dirY[mouse_y, mouse_x] 45 | 46 | for j in range(rot): 47 | temp = tangent1.copy() 48 | tangent1 = 
-tangent2.copy() 49 | tangent2 = temp 50 | normal = np.cross(tangent1, tangent2) 51 | x = tangent1 52 | y = normal 53 | z = np.cross(x, y) 54 | rotation = np.array([x, y, z]) 55 | rotation = np.transpose(rotation) 56 | translation = np.array([(mouse_x - intrinsics[1]) / intrinsics[0], (mouse_y - intrinsics[3]) / intrinsics[2], 1]) * scale 57 | color,zbuffer = Render3D(mesh_info, intrinsics, rotation, translation) 58 | mask = np.tile(np.reshape(zbuffer > 0,(480,640,1)),(1,1,3)) 59 | 60 | color_image = (color_image * (1 - mask) + color) 61 | 62 | elif deformable == 1: 63 | outputs = ComputeWarping(dirX, dirY, intrinsics, mouse_x,mouse_y,patch_w, patch_c).astype('float32') 64 | Render(patch, outputs, color_image, 1) 65 | else: 66 | H = BuildHomography(intrinsics, np.array([mouse_x, mouse_y]), dirX[mouse_y,mouse_x], dirY[mouse_y,mouse_x], patch_w, patch_c) 67 | warp = cv2.warpPerspective(patch, H, (color_image.shape[1], color_image.shape[0])) 68 | mask = (patch[:,:,0]>= 1) + (patch[:,:,1]>=1) + (patch[:,:,2] >= 1) 69 | mask = np.reshape(mask > 0, (mask.shape[0],mask.shape[1],1)) 70 | mask = np.tile(mask, (1,1,3)).astype('uint8') * 255 71 | warp_mask = cv2.warpPerspective(mask, H, (color_image.shape[1], color_image.shape[0])) 72 | 73 | mask = warp_mask == 255 74 | #mask = np.reshape(mask, (mask.shape[0],mask.shape[1],1)) 75 | #mask = np.tile(mask, (1,1,3)) 76 | color_image = (color_image * (1-mask) + warp * mask).astype('uint8')#color_image * (warp > 0) * 0.5 + warp * 0.5).astype('uint8') 77 | 78 | def click_and_crop(event, x, y, flags, param): 79 | global mouse_x, mouse_y, original_image, color_image 80 | mouse_x = x 81 | mouse_y = y 82 | if event == cv2.EVENT_LBUTTONDOWN: 83 | ModifyImage() 84 | 85 | elif event == cv2.EVENT_MBUTTONDOWN: 86 | color_image = original_image.copy() 87 | 88 | UpdateImage() 89 | 90 | def app(cimage, dirX_3d, dirY_3d, attached_patch, intrinsic, mesh_infos): 91 | global color_image, original_image, patch_w, update, mouse_y, mouse_x, patch, intrinsics, dirX, dirY, patch_c, deformable, mesh_info, rot, scale 92 | mesh_info = mesh_infos 93 | patch_w = 0.2 / attached_patch.shape[0] 94 | color_image = cimage.copy() 95 | patch = attached_patch.copy() 96 | intrinsics = intrinsic.copy() 97 | dirX = dirX_3d.copy() 98 | dirY = dirY_3d.copy() 99 | 100 | original_image = color_image.copy() 101 | update = False 102 | mouse_x = 0 103 | mouse_y = 0 104 | patch_c = patch.shape[0]//2 105 | cv2.namedWindow("TextureAttach") 106 | cv2.setMouseCallback("TextureAttach", click_and_crop) 107 | draw_image = color_image.copy() 108 | cv2.imshow("TextureAttach", draw_image) 109 | 110 | while True: 111 | # display the image and wait for a keypress 112 | key = cv2.waitKey(1) & 0xFF 113 | 114 | # if the 'c' key is pressed, break from the loop 115 | if key == ord("g"): 116 | break 117 | 118 | if key == ord('d'): 119 | if deformable < 2: 120 | patch_w *= 1.1 121 | else: 122 | scale *= 1.1 123 | UpdateImage() 124 | 125 | if key == ord('f'): 126 | if deformable < 2: 127 | patch_w /= 1.1 128 | else: 129 | scale /= 1.1 130 | UpdateImage() 131 | 132 | if key == ord('a'): 133 | deformable = deformable + 1 134 | if deformable == 3: 135 | deformable = 0 136 | 137 | if key == ord('s'): 138 | cv2.imwrite('result.png', color_image) 139 | 140 | if key == ord('r'): 141 | rot = rot + 1 142 | if rot == 4: 143 | rot = 0 144 | M = cv2.getRotationMatrix2D((patch.shape[1]/2, patch.shape[0]/2), 90, 1) 145 | patch = cv2.warpAffine(patch, M, (patch.shape[0], patch.shape[1])) 146 | UpdateImage() 
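For the rigid mode above, `BuildHomography` in `direction.py` assembles H = K * [pixel_w*dirX, pixel_w*dirY, p_star - image_c*pixel_w*(dirX + dirY)], which `UpdateImage` passes to `cv2.warpPerspective`. A small sketch of pushing the patch corners through the same 3x3 map, which can help when debugging a placement; the `patch_corners` helper is hypothetical and not part of the demo:
```
import numpy as np

def patch_corners(H, patch_size):
    # Hypothetical helper: apply the homography to the four patch corners.
    s = patch_size - 1
    corners = np.array([[0, 0, 1], [s, 0, 1], [s, s, 1], [0, s, 1]], dtype='float64')
    proj = corners @ H.T
    return proj[:, :2] / proj[:, 2:3]  # divide out the homogeneous coordinate

H = np.eye(3)  # identity homography: corners stay at the patch coordinates
print(patch_corners(H, 401))
```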
-------------------------------------------------------------------------------- /src/dorn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models 6 | import collections 7 | import math 8 | 9 | 10 | def weights_init(modules, type='xavier'): 11 | m = modules 12 | if isinstance(m, nn.Conv2d): 13 | if type == 'xavier': 14 | torch.nn.init.xavier_normal_(m.weight) 15 | elif type == 'kaiming': # msra 16 | torch.nn.init.kaiming_normal_(m.weight) 17 | else: 18 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 19 | m.weight.data.normal_(0, math.sqrt(2. / n)) 20 | 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif isinstance(m, nn.ConvTranspose2d): 24 | if type == 'xavier': 25 | torch.nn.init.xavier_normal_(m.weight) 26 | elif type == 'kaiming': # msra 27 | torch.nn.init.kaiming_normal_(m.weight) 28 | else: 29 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 30 | m.weight.data.normal_(0, math.sqrt(2. / n)) 31 | 32 | if m.bias is not None: 33 | m.bias.data.zero_() 34 | elif isinstance(m, nn.BatchNorm2d): 35 | m.weight.data.fill_(1.0) 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.Linear): 38 | if type == 'xavier': 39 | torch.nn.init.xavier_normal_(m.weight) 40 | elif type == 'kaiming': # msra 41 | torch.nn.init.kaiming_normal_(m.weight) 42 | else: 43 | m.weight.data.fill_(1.0) 44 | 45 | if m.bias is not None: 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.Module): 48 | for m in modules: 49 | if isinstance(m, nn.Conv2d): 50 | if type == 'xavier': 51 | torch.nn.init.xavier_normal_(m.weight) 52 | elif type == 'kaiming': # msra 53 | torch.nn.init.kaiming_normal_(m.weight) 54 | else: 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. / n)) 57 | 58 | if m.bias is not None: 59 | m.bias.data.zero_() 60 | elif isinstance(m, nn.ConvTranspose2d): 61 | if type == 'xavier': 62 | torch.nn.init.xavier_normal_(m.weight) 63 | elif type == 'kaiming': # msra 64 | torch.nn.init.kaiming_normal_(m.weight) 65 | else: 66 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 67 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
68 | 
69 |                 if m.bias is not None:
70 |                     m.bias.data.zero_()
71 |             elif isinstance(m, nn.BatchNorm2d):
72 |                 m.weight.data.fill_(1.0)
73 |                 m.bias.data.zero_()
74 |             elif isinstance(m, nn.Linear):
75 |                 if type == 'xavier':
76 |                     torch.nn.init.xavier_normal_(m.weight)
77 |                 elif type == 'kaiming':  # msra
78 |                     torch.nn.init.kaiming_normal_(m.weight)
79 |                 else:
80 |                     m.weight.data.fill_(1.0)
81 | 
82 |                 if m.bias is not None:
83 |                     m.bias.data.zero_()
84 | 
85 | 
86 | class FullImageEncoder(nn.Module):
87 |     def __init__(self, dataset='kitti'):
88 |         super(FullImageEncoder, self).__init__()
89 |         self.global_pooling = nn.AvgPool2d(8, stride=8, padding=(1, 0))  # KITTI 16 16
90 |         self.dropout = nn.Dropout2d(p=0.5)
91 |         self.global_fc = nn.Linear(2048 * 4 * 5, 512)
92 |         self.relu = nn.ReLU(inplace=True)
93 |         self.conv1 = nn.Conv2d(512, 512, 1)  # 1x1 convolution
94 |         self.upsample = nn.UpsamplingBilinear2d(size=(30, 40))  # KITTI 49X65 NYU 33X45
95 |         self.dataset = dataset
96 |         weights_init(self.modules(), 'xavier')
97 | 
98 |     def forward(self, x):
99 |         x1 = self.global_pooling(x)
100 | 
101 |         # print('# x1 size:', x1.size())
102 |         x2 = self.dropout(x1)
103 |         x3 = x2.view(-1, 2048 * 4 * 5)
104 |         x4 = self.relu(self.global_fc(x3))
105 |         # print('# x4 size:', x4.size())
106 |         x4 = x4.view(-1, 512, 1, 1)
107 |         # print('# x4 size:', x4.size())
108 |         x5 = self.conv1(x4)
109 |         out = self.upsample(x5)
110 |         return out
111 | 
112 | 
113 | class SceneUnderstandingModule(nn.Module):
114 |     def __init__(self, output_channel=136, dataset='kitti'):
115 |         super(SceneUnderstandingModule, self).__init__()
116 |         self.encoder = FullImageEncoder(dataset=dataset)
117 |         self.aspp1 = nn.Sequential(
118 |             nn.Conv2d(2048, 512, 1),
119 |             nn.ReLU(inplace=True),
120 |             nn.Conv2d(512, 512, 1),
121 |             nn.ReLU(inplace=True)
122 |         )
123 |         self.aspp2 = nn.Sequential(
124 |             nn.Conv2d(2048, 512, 3, padding=6, dilation=6),
125 |             nn.ReLU(inplace=True),
126 |             nn.Conv2d(512, 512, 1),
127 |             nn.ReLU(inplace=True)
128 |         )
129 |         self.aspp3 = nn.Sequential(
130 |             nn.Conv2d(2048, 512, 3, padding=12, dilation=12),
131 |             nn.ReLU(inplace=True),
132 |             nn.Conv2d(512, 512, 1),
133 |             nn.ReLU(inplace=True)
134 |         )
135 |         self.aspp4 = nn.Sequential(
136 |             nn.Conv2d(2048, 512, 3, padding=18, dilation=18),
137 |             nn.ReLU(inplace=True),
138 |             nn.Conv2d(512, 512, 1),
139 |             nn.ReLU(inplace=True)
140 |         )
141 |         self.concat_process = nn.Sequential(
142 |             nn.Dropout2d(p=0.5),
143 |             nn.Conv2d(512 * 5, 2048, 1),
144 |             nn.ReLU(inplace=True),
145 |             nn.Dropout2d(p=0.5),
146 |             nn.Conv2d(2048, output_channel, 1),  # KITTI: 142, NYU: 136; the paper finds K = 80 best, so 160 (2K) channels work well
147 | # nn.UpsamplingBilinear2d(scale_factor=8) 148 | nn.UpsamplingBilinear2d(size=(240, 320)) 149 | ) 150 | 151 | weights_init(self.modules(), type='xavier') 152 | 153 | def forward(self, x): 154 | x1 = self.encoder(x) 155 | 156 | x2 = self.aspp1(x) 157 | x3 = self.aspp2(x) 158 | x4 = self.aspp3(x) 159 | x5 = self.aspp4(x) 160 | 161 | x6 = torch.cat((x1, x2, x3, x4, x5), dim=1) 162 | # print('cat x6 size:', x6.size()) 163 | out = self.concat_process(x6) 164 | return out 165 | 166 | 167 | class OrdinalRegressionLayer(nn.Module): 168 | def __init__(self): 169 | super(OrdinalRegressionLayer, self).__init__() 170 | #self.logsoftmax = nn.Logsoftmax(dim=1) 171 | 172 | def forward(self, x): 173 | N, C, H, W = x.size() 174 | 175 | ord_num = C // 2 176 | 177 | A = x[:, ::2, :, :].clone() 178 | B = x[:, 1::2, :, :].clone() 179 | 180 | A = A.view(N, 1, ord_num * H * W) 181 | B = B.view(N, 1, ord_num * H * W) 182 | 183 | C = torch.cat((A, B), dim=1) 184 | #C = torch.clamp(C, min=1e-7, max=1e7) # prevent nans 185 | 186 | ord_c = nn.functional.softmax(C, dim=1) 187 | 188 | ord_c1 = ord_c[:, 1, :].clone() 189 | ord_c2 = nn.LogSoftmax(dim=1)(C) 190 | 191 | ord_c1 = ord_c1.view(-1, ord_num, H, W) 192 | ord_c2 = ord_c2.view(-1, ord_num * 2, H, W) 193 | decode_c = torch.sum((ord_c1 >= 0.5), dim=1).view(-1, 1, H, W).float() 194 | 195 | return decode_c, ord_c2 196 | 197 | class ResNet(nn.Module): 198 | def __init__(self, in_channels=3, pretrained=True, freeze=True): 199 | super(ResNet, self).__init__() 200 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(101)](pretrained=pretrained) 201 | 202 | self.channel = in_channels 203 | 204 | self.conv1 = nn.Sequential(collections.OrderedDict([ 205 | ('conv1_1', nn.Conv2d(self.channel, 64, kernel_size=3, stride=2, padding=1, bias=False)), 206 | ('bn1_1', nn.BatchNorm2d(64)), 207 | ('relu1_1', nn.ReLU(inplace=True)), 208 | ('conv1_2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)), 209 | ('bn_2', nn.BatchNorm2d(64)), 210 | ('relu1_2', nn.ReLU(inplace=True)), 211 | ('conv1_3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)), 212 | ('bn1_3', nn.BatchNorm2d(128)), 213 | ('relu1_3', nn.ReLU(inplace=True)) 214 | ])) 215 | 216 | self.bn1 = nn.BatchNorm2d(128) 217 | 218 | # print(pretrained_model._modules['layer1'][0].conv1) 219 | 220 | self.relu = pretrained_model._modules['relu'] 221 | self.maxpool = pretrained_model._modules['maxpool'] 222 | self.layer1 = pretrained_model._modules['layer1'] 223 | self.layer1[0].conv1 = nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) 224 | self.layer1[0].downsample[0] = nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) 225 | 226 | self.layer2 = pretrained_model._modules['layer2'] 227 | 228 | self.layer3 = pretrained_model._modules['layer3'] 229 | self.layer3[0].conv2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 230 | self.layer3[0].downsample[0] = nn.Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) 231 | 232 | self.layer4 = pretrained_model._modules['layer4'] 233 | self.layer4[0].conv2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 234 | self.layer4[0].downsample[0] = nn.Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) 235 | 236 | # clear memory 237 | del pretrained_model 238 | 239 | if pretrained: 240 | weights_init(self.conv1, type='kaiming') 241 | weights_init(self.layer1[0].conv1, type='kaiming') 242 | 
weights_init(self.layer1[0].downsample[0], type='kaiming')
243 |             weights_init(self.layer3[0].conv2, type='kaiming')
244 |             weights_init(self.layer3[0].downsample[0], type='kaiming')
245 |             weights_init(self.layer4[0].conv2, 'kaiming')
246 |             weights_init(self.layer4[0].downsample[0], 'kaiming')
247 |         else:
248 |             weights_init(self.modules(), type='kaiming')
249 | 
250 |         if freeze:
251 |             self.freeze()
252 | 
253 |     def forward(self, x):
254 |         # print(pretrained_model._modules)
255 | 
256 |         x = self.conv1(x)
257 |         x = self.bn1(x)
258 |         x = self.relu(x)
259 | 
260 |         # print('conv1:', x.size())
261 | 
262 |         x = self.maxpool(x)
263 | 
264 |         # print('pool:', x.size())
265 | 
266 |         x1 = self.layer1(x)
267 |         # print('layer1 size:', x1.size())
268 |         x2 = self.layer2(x1)
269 |         # print('layer2 size:', x2.size())
270 |         x3 = self.layer3(x2)
271 |         # print('layer3 size:', x3.size())
272 |         x4 = self.layer4(x3)
273 |         # print('layer4 size:', x4.size())
274 |         return x4
275 | 
276 |     def freeze(self):
277 |         for m in self.modules():
278 |             if isinstance(m, nn.BatchNorm2d):
279 |                 m.eval()
280 | 
281 | 
282 | class DORN(nn.Module):
283 |     def __init__(self, output_size=(240, 320), losstype=1, channel=3, pretrained=True, freeze=True, output_channel=3, dataset='kitti'):
284 |         super(DORN, self).__init__()
285 | 
286 |         self.output_size = output_size
287 |         self.channel = channel
288 |         self.feature_extractor = ResNet(in_channels=channel, pretrained=pretrained, freeze=freeze)
289 |         self.aspp_module = SceneUnderstandingModule(output_channel=output_channel, dataset=dataset)
290 |         self.orl = OrdinalRegressionLayer()
291 |         self.losstype = losstype
292 | 
293 |     def forward(self, x):
294 |         x1 = self.feature_extractor(x)
295 |         x2 = self.aspp_module(x1)
296 |         return x2
297 | 
298 |     def get_1x_lr_params(self):
299 |         b = [self.feature_extractor]
300 |         for i in range(len(b)):
301 |             for k in b[i].parameters():
302 |                 if k.requires_grad:
303 |                     yield k
304 | 
305 |     def get_10x_lr_params(self):
306 |         b = [self.aspp_module, self.orl]
307 |         for j in range(len(b)):
308 |             for k in b[j].parameters():
309 |                 if k.requires_grad:
310 |                     yield k
311 | 
312 | 
313 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # by default, GPU 0 is used
314 | 
315 | if __name__ == "__main__":
316 |     model = DORN()
317 |     model = model.cuda()
318 |     model.eval()
319 |     image = torch.randn(1, 3, 257, 353)
320 |     image = image.cuda()
321 |     with torch.no_grad():
322 |         out = model(image)  # forward returns a single tensor (the SceneUnderstandingModule output)
323 |     print('out size:', out.size())
324 | 
325 | 
326 |     print(out)
327 | 
--------------------------------------------------------------------------------
/src/evaluate.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python evaluate_joint.py --resume checkpoints/dorn.cpkt --evaluate normal --root ./data
2 | 
--------------------------------------------------------------------------------
/src/evaluate_joint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | os.environ['TORCH_HOME'] = './'
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | from torchvision import transforms
7 | from torch.utils.data.dataset import Dataset  # For custom datasets
8 | from torch.utils.data import DataLoader
9 | 
10 | from dorn import DORN
11 | from dataset import AffineTestsDataset
12 | import numpy as np
13 | from tensorboardX import SummaryWriter
14 | import argparse
15 | import skimage.io as sio
16 | import scipy.misc as misc
17 | num_epochs = 10000
18 | 
batch_size = 1
19 | 
20 | parser = argparse.ArgumentParser(description='Process some integers.')
21 | parser.add_argument('--resume', type=str, default='')
22 | parser.add_argument('--scale', type=int, default=2)
23 | parser.add_argument('--root', type=str, default='./data')
24 | parser.add_argument('--evaluate', type=str, default='normal')
25 | args = parser.parse_args()
26 | 
27 | cnn = DORN(channel=5, output_channel=13)
28 | 
29 | criterion = nn.MSELoss(reduction='sum')  # size_average=False is deprecated
30 | optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
31 | 
32 | test_dataset = AffineTestsDataset(root=args.root, feat=0)
33 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
34 |                              shuffle=False, num_workers=0)
35 | 
36 | cnn = cnn.cuda()
37 | 
38 | if (args.resume != ''):
39 |     cnn.load_state_dict(torch.load(args.resume))
40 | 
41 | def log(str):
42 |     # this script defines no --save flag and opens no log file,
43 |     # so messages go to stdout only
44 |     print(str)
45 | 
46 | 
47 | 
48 | n_iter = 0
49 | m_iter = 0
50 | 
51 | test_iter = iter(test_dataloader)
52 | 
53 | errors = None
54 | errors0 = None
55 | errors1 = None
56 | errors2 = None
57 | 
58 | def ConvertToAngle(Q):
59 |     angle1 = torch.atan2(Q[:,1:2,:,:],Q[:,0:1,:,:]) / np.pi * 180
60 |     angle2 = torch.atan2(Q[:,3:4,:,:],Q[:,2:3,:,:]) / np.pi * 180
61 |     angles = torch.cat([angle1, angle2], dim=1)
62 | 
63 |     q1 = torch.cos(angle1 / 180.0 * np.pi)
64 |     q2 = torch.sin(angle1 / 180.0 * np.pi)
65 | 
66 |     return angles
67 | 
68 | def ConvertToDirection(Q):
69 |     x1 = torch.cos(Q[:,0:1,:,:] / 180.0 * np.pi)
70 |     y1 = torch.sin(Q[:,0:1,:,:] / 180.0 * np.pi)
71 |     x2 = torch.cos(Q[:,1:2,:,:] / 180.0 * np.pi)
72 |     y2 = torch.sin(Q[:,1:2,:,:] / 180.0 * np.pi)
73 | 
74 |     return torch.cat([x1,y1,x2,y2], dim=1)
75 | 
76 | def Rotate90(Q):
77 |     Q0 = Q.clone()
78 |     Q0[:,0,:,:] = Q[:,2,:,:]
79 |     Q0[:,1,:,:] = Q[:,3,:,:]
80 |     Q0[:,2,:,:] = -Q[:,0,:,:]
81 |     Q0[:,3,:,:] = -Q[:,1,:,:]
82 |     return Q0
83 | 
84 | def Normalize(dir_x):
85 |     dir_x_l = torch.sqrt(torch.sum(dir_x ** 2,dim=1) + 1e-7).view(dir_x.shape[0],1,dir_x.shape[2],dir_x.shape[3])
86 |     #dir_x_l = (torch.norm(dir_x, p=2, dim=1) + 1e-7).view(dir_x.shape[0],1,dir_x.shape[2],dir_x.shape[3])
87 |     if dir_x.shape[1] == 3:
88 |         dir_x_l = torch.cat([dir_x_l, dir_x_l, dir_x_l], dim=1)
89 |     elif dir_x.shape[1] == 2:
90 |         dir_x_l = torch.cat([dir_x_l, dir_x_l], dim=1)
91 |     return dir_x / dir_x_l
92 | 
93 | def train_one_iter(i, sample_batched, evaluate=0):
94 |     global errors, errors0, errors1, errors2
95 |     cnn.eval()
96 |     images = sample_batched['image']
97 |     labels = sample_batched['label']
98 |     masks_tensor = sample_batched['mask'] > 0
99 |     X = sample_batched['X'].cuda()
100 |     Y = sample_batched['Y'].cuda()
101 | 
102 |     images_tensor = Variable(images.float())
103 |     labels_tensor = Variable(labels.float())
104 | 
105 |     images_tensor, labels_tensor = images_tensor.cuda(), labels_tensor.cuda()
106 | 
107 |     masks_tensor = masks_tensor.cuda()
108 | 
109 |     masks_tensor = masks_tensor.float()
110 |     elems = torch.sum(masks_tensor).item()
111 |     if elems == 0:
112 |         return
113 |     # Forward + Backward + Optimize
114 |     optimizer.zero_grad()
115 | 
116 |     if args.evaluate == 'normal':
117 |         outputs = cnn(images_tensor)[:,10:13,:,:]
118 | 
119 |         norm1 = Normalize(outputs)
120 |         norm2 = Normalize(torch.cross(X,Y))
121 |         mask = masks_tensor
122 | 
123 |         mask = mask[0].data.cpu().numpy()
124 |         dot_product = torch.sum(norm1 * norm2, dim=1)
125 |         dot_product = torch.clamp(dot_product,min=-1.0,max=1.0)
126 |         angles = torch.acos(dot_product) * masks_tensor / np.pi * 180
127 | 
128 | 
norm1_copy = norm1.clone() 129 | norm2_copy = norm2.clone() 130 | 131 | for j in range(3): 132 | norm1 = norm1_copy.clone() 133 | norm1[:,j,:,:] = 0 134 | norm2 = norm2_copy.clone() 135 | norm2[:,j,:,:] = 0 136 | norm1 = Normalize(norm1) 137 | norm2 = Normalize(norm2) 138 | dot_product = torch.sum(norm1 * norm2, dim=1) 139 | dot_product = torch.clamp(dot_product,min=-1.0,max=1.0) 140 | if j == 0: 141 | angles0 = torch.acos(dot_product) * masks_tensor / np.pi * 180 142 | elif j == 1: 143 | angles1 = torch.acos(dot_product) * masks_tensor / np.pi * 180 144 | else: 145 | angles2 = torch.acos(dot_product) * masks_tensor / np.pi * 180 146 | 147 | 148 | 149 | elif args.evaluate == 'projection': 150 | outputs = cnn(images_tensor)[:,0:4,:,:] 151 | preds = ConvertToAngle(outputs) 152 | mask = masks_tensor 153 | l0 = labels_tensor 154 | a1 = ConvertToAngle(l0) 155 | l1 = Rotate90(l0) 156 | a2 = ConvertToAngle(l1) 157 | l2 = Rotate90(l1) 158 | a3 = ConvertToAngle(l2) 159 | l3 = Rotate90(l2) 160 | a4 = ConvertToAngle(l3) 161 | d0 = preds - a1 162 | d0 = torch.min(torch.abs(d0), torch.min(torch.abs(d0 + 360), torch.abs(d0 - 360))) 163 | d0 = torch.sum(d0, dim=1) 164 | d1 = preds - a2 165 | d1 = torch.min(torch.abs(d1), torch.min(torch.abs(d1 + 360), torch.abs(d1 - 360))) 166 | d1 = torch.sum(d1, dim=1) 167 | d2 = preds - a3 168 | d2 = torch.min(torch.abs(d2), torch.min(torch.abs(d2 + 360), torch.abs(d2 - 360))) 169 | d2 = torch.sum(d2, dim=1) 170 | d3 = preds - a4 171 | d3 = torch.min(torch.abs(d3), torch.min(torch.abs(d3 + 360), torch.abs(d3 - 360))) 172 | d3 = torch.sum(d3, dim=1) 173 | d = torch.min(d0, torch.min(d1, torch.min(d2, d3))) 174 | d = d * mask 175 | angles = d / 2 176 | 177 | elif args.evaluate == 'principal': 178 | outputs = cnn(images_tensor)[:,4:10,:,:] 179 | dir_x = Normalize(outputs[:,0:3,:,:]) 180 | dir_y = Normalize(outputs[:,3:6,:,:]) 181 | X = Normalize(X) 182 | Y = Normalize(Y) 183 | 184 | angles0 = torch.acos(torch.clamp(torch.sum(dir_x*X,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180\ 185 | + torch.acos(torch.clamp(torch.sum(dir_y*Y,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180 186 | angles1 = torch.acos(torch.clamp(torch.sum(-dir_x*Y,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180\ 187 | + torch.acos(torch.clamp(torch.sum(dir_y*X,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180 188 | angles2 = torch.acos(torch.clamp(torch.sum(-dir_x*X,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180\ 189 | + torch.acos(torch.clamp(torch.sum(-dir_y*Y,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180 190 | angles3 = torch.acos(torch.clamp(torch.sum(dir_x*Y,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180\ 191 | + torch.acos(torch.clamp(torch.sum(-dir_y*X,dim=1), min=-1.0,max=1.0)) * masks_tensor / np.pi * 180 192 | 193 | angles = torch.min(angles0, torch.min(angles1, torch.min(angles2,angles3))) * 0.5 194 | mask1 = (angles == angles0).float() 195 | mask2 = (angles == angles1).float() 196 | mask3 = (angles == angles2).float() 197 | mask4 = (angles == angles3).float() 198 | 199 | selected_X = mask1 * X - mask2 * Y - mask3 * X + mask4 * Y 200 | selected_Y = mask1 * Y + mask2 * X - mask3 * Y - mask4 * X 201 | 202 | for j in range(3): 203 | dir_x_copy = dir_x.clone() 204 | dir_y_copy = dir_y.clone() 205 | X_copy = selected_X.clone() 206 | Y_copy = selected_Y.clone() 207 | X_copy[:,j,:,:] = 0 208 | Y_copy[:,j,:,:] = 0 209 | dir_x_copy[:,j,:,:] = 0 210 | dir_y_copy[:,j,:,:] = 0 211 | X_copy = Normalize(X_copy) 212 | Y_copy = Normalize(Y_copy) 213 
| dir_x_copy = Normalize(dir_x_copy)
214 |             dir_y_copy = Normalize(dir_y_copy)
215 | 
216 |             a = torch.acos(torch.clamp(torch.sum(dir_x_copy*X_copy,dim=1),min=-1.0,max=1.0)) * masks_tensor / np.pi * 180\
217 |                 + torch.acos(torch.clamp(torch.sum(dir_y_copy*Y_copy,dim=1),min=-1.0,max=1.0)) * masks_tensor / np.pi * 180
218 | 
219 |             if j == 0:
220 |                 angles0 = a
221 |             elif j == 1:
222 |                 angles1 = a
223 |             else:
224 |                 angles2 = a
225 | 
226 | 
227 |     masks_np = masks_tensor.data.cpu().numpy() > 0
228 | 
229 |     angles_np = angles.data.cpu().numpy()
230 |     angles_np = angles_np[masks_np]
231 |     if errors is None:
232 |         errors = angles_np.copy()
233 |     else:
234 |         errors = np.concatenate((errors, angles_np))
235 | 
236 |     if i % 10 == 0 or i > 320:
237 |         print('Item %d of %d: Mean %f, Median %f, Rmse %f, delta1 %f, delta2 %f delta3 %f'%(i,len(test_dataset),np.average(errors), np.median(errors), np.sqrt(np.sum(errors * errors)/errors.shape[0]),\
238 |             np.sum(errors < 11.25) / errors.shape[0],np.sum(errors < 22.5) / errors.shape[0],np.sum(errors < 30) / errors.shape[0]))
239 |     del images_tensor, labels_tensor, outputs, masks_tensor
240 | 
241 | for i, sample_batched in enumerate(test_dataloader):
242 |     train_one_iter(i, sample_batched, 2)
243 | 
--------------------------------------------------------------------------------
/src/train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_affine_dorn.py --train 1 --logtype final_run --save final_model --root ./data
2 | 
--------------------------------------------------------------------------------
/src/train_affine_dorn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | os.environ['TORCH_HOME'] = './'
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | from torchvision import transforms
7 | from torch.utils.data.dataset import Dataset  # For custom datasets
8 | from torch.utils.data import DataLoader
9 | 
10 | from dorn import DORN
11 | from dataset import AffineDataset
12 | import numpy as np
13 | from tensorboardX import SummaryWriter
14 | import argparse
15 | import skimage.io as sio
16 | import sys
17 | #import CudaRender.render as render
18 | import math
19 | from time import time
20 | 
21 | 
22 | num_epochs = 10000
23 | 
24 | parser = argparse.ArgumentParser(description='Process some integers.')
25 | parser.add_argument('--resume', type=str, default='')
26 | parser.add_argument('--train', type=int, default=1)
27 | parser.add_argument('--logtype', type=str, default='')
28 | parser.add_argument('--save', type=str, default='')
29 | parser.add_argument('--scale', type=int, default=2)
30 | parser.add_argument('--eval', type=int, default=1)
31 | parser.add_argument('--use_min', type=int, default=0)
32 | parser.add_argument('--horizontal', type=int, default=0)
33 | parser.add_argument('--batch_size', type=int, default=8)
34 | parser.add_argument('--joint', type=int, default=1)
35 | parser.add_argument('--lr', type=float, default=1e-3)
36 | parser.add_argument('--root', type=str)
37 | args = parser.parse_args()
38 | 
39 | batch_size = args.batch_size
40 | train_dataset = AffineDataset(usage='train', root=args.root)
41 | dataloader = DataLoader(train_dataset, batch_size=batch_size,
42 |                         shuffle=True, num_workers=0)
43 | 
44 | test_dataset = AffineDataset(usage='test', root=args.root)
45 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
46 |                              shuffle=False, num_workers=0)
47 | 
48 | 
49 | 
val_dataset = AffineDataset(usage='test', root=args.root)
50 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size,
51 |                             shuffle=False, num_workers=0)
52 | 
53 | if args.save != '':
54 |     if not os.path.exists(args.save):
55 |         os.mkdir(args.save)
56 | #instance of the Conv Net
57 | cnn = DORN(channel=5,output_channel=13)
58 | if args.logtype != '':
59 |     writer = SummaryWriter(logdir='./dorn-resume')
60 | #loss function and optimizer
61 | criterion = nn.MSELoss(reduction='sum')  # size_average=False is deprecated
62 | optimizer = torch.optim.Adam(cnn.parameters(), lr=args.lr)
63 | 
64 | cnn = cnn.cuda()
65 | 
66 | if (args.resume != ''):
67 |     state = cnn.state_dict()
68 |     state.update(torch.load(args.resume))
69 |     cnn.load_state_dict(state)
70 | 
71 | if args.save != '':
72 |     fp = open(args.save + '/logs.txt', 'w')
73 | def log(str):
74 |     if args.save != '':
75 |         fp.write('%s\n'%(str))
76 |         fp.flush()
77 |         #os.fsync(fp)
78 |     print(str)
79 | 
80 | s = 'python train_affine_dorn.py'
81 | for j in sys.argv:
82 |     s += ' ' + j
83 | log(s)
84 | 
85 | def ConvertToAngle(Q):
86 |     angle1 = torch.atan2(Q[:,1:2,:,:],Q[:,0:1,:,:]) / np.pi * 180
87 |     angle2 = torch.atan2(Q[:,3:4,:,:],Q[:,2:3,:,:]) / np.pi * 180
88 |     angles = torch.cat([angle1, angle2], dim=1)
89 | 
90 |     q1 = torch.cos(angle1 / 180.0 * np.pi)
91 |     q2 = torch.sin(angle1 / 180.0 * np.pi)
92 | 
93 |     return angles
94 | 
95 | def ConvertToDirection(Q):
96 |     x1 = torch.cos(Q[:,0:1,:,:] / 180.0 * np.pi)
97 |     y1 = torch.sin(Q[:,0:1,:,:] / 180.0 * np.pi)
98 |     x2 = torch.cos(Q[:,1:2,:,:] / 180.0 * np.pi)
99 |     y2 = torch.sin(Q[:,1:2,:,:] / 180.0 * np.pi)
100 | 
101 |     return torch.cat([x1,y1,x2,y2], dim=1)
102 | 
103 | def Rotate90(Q):
104 |     Q0 = Q.clone()
105 |     Q0[:,0,:,:] = Q[:,2,:,:]
106 |     Q0[:,1,:,:] = Q[:,3,:,:]
107 |     Q0[:,2,:,:] = -Q[:,0,:,:]
108 |     Q0[:,3,:,:] = -Q[:,1,:,:]
109 |     return Q0
110 | 
111 | def RemoveAngleAmbiguity(Q0):
112 |     Q = Q0.clone()
113 |     mask = Q[:,0:1,:,:] > 90
114 |     mask = torch.cat([mask, mask], dim=1)
115 |     while torch.sum(mask).item() > 0:
116 |         Q -= 90 * mask.float()
117 |         mask = Q[:,0:1,:,:] > 90
118 |         mask = torch.cat([mask, mask], dim=1)
119 |     mask = Q[:,0:1,:,:] < 0
120 |     mask = torch.cat([mask, mask], dim=1)
121 |     while torch.sum(mask).item() > 0:
122 |         Q += 90 * mask.float()
123 |         mask = Q[:,0:1,:,:] < 0
124 |         mask = torch.cat([mask, mask], dim=1)
125 | 
126 |     mask = (Q > 45).float()
127 |     return mask * (90 - Q) + (1 - mask) * Q, Q
128 | 
129 | n_iter = 0
130 | m_iter = 0
131 | 
132 | test_iter = iter(test_dataloader)
133 | val_iter = iter(val_dataloader)
134 | 
135 | def Normalize(dir_x):
136 |     dir_x_l = torch.sqrt(torch.sum(dir_x ** 2,dim=1) + 1e-7).view(dir_x.shape[0],1,dir_x.shape[2],dir_x.shape[3])
137 |     dir_x_l = torch.cat([dir_x_l, dir_x_l, dir_x_l], dim=1)
138 |     return dir_x / dir_x_l
139 | 
140 | def train_one_iter(i, sample_batched, evaluate=0):
141 |     global n_iter, m_iter
142 |     cnn.train()
143 |     if evaluate > 0 and args.eval == 1:
144 |         cnn.eval()
145 |     images = sample_batched['image']
146 |     labels = sample_batched['label']
147 |     labels_alt = sample_batched['label_alt']
148 |     mask_alt = sample_batched['mask']
149 |     tmask = mask_alt.clone()
150 |     X = sample_batched['X']
151 |     Y = sample_batched['Y']
152 | 
153 |     images_tensor = Variable(images.float())
154 |     labels_tensor = Variable(labels)
155 |     labels_alt_tensor = Variable(labels_alt)
156 |     mask_alt_tensor = Variable(mask_alt)
157 | 
158 |     images_tensor, labels_tensor, labels_alt_tensor, mask_alt_tensor = images_tensor.cuda(), labels_tensor.cuda(), labels_alt_tensor.cuda(), 
mask_alt_tensor.cuda() 159 | 160 | mask = (labels_tensor[:,0:1,:,:] * labels_tensor[:,0:1,:,:] + labels_tensor[:,1:2,:,:] * labels_tensor[:,1:2,:,:]) > 0.2 161 | if args.horizontal == 0: 162 | mask_alt_tensor = ((mask_alt_tensor < 0.9) & (mask_alt_tensor > 0.1)).view(mask_alt_tensor.shape[0],1,mask_alt_tensor.shape[1],mask_alt_tensor.shape[2]) 163 | mask = mask & mask_alt_tensor 164 | elif args.horizontal == 1: 165 | mask_alt_tensor = ((mask_alt_tensor > 0.9)).view(mask_alt_tensor.shape[0],1,mask_alt_tensor.shape[1],mask_alt_tensor.shape[2]) 166 | mask = mask & mask_alt_tensor 167 | 168 | mask = mask.float() 169 | X = sample_batched['X'].cuda() 170 | Y = sample_batched['Y'].cuda() 171 | 172 | elems = torch.sum(mask).item() * 2 173 | if elems == 0: 174 | return 175 | # Forward + Backward + Optimize 176 | optimizer.zero_grad() 177 | 178 | outputs_temp = cnn(images_tensor) 179 | outputs = outputs_temp[:,0:4,:,:] 180 | outputs2 = outputs_temp[:,4:10,:,:] 181 | norm1 = outputs_temp[:,10:13,:,:] 182 | dir_x = Normalize(outputs2[:,0:3,:,:]) 183 | dir_y = Normalize(outputs2[:,3:6,:,:]) 184 | 185 | preds = ConvertToAngle(outputs) 186 | 187 | l0 = labels_tensor 188 | a1 = ConvertToAngle(l0) 189 | l1 = Rotate90(l0) 190 | a2 = ConvertToAngle(l1) 191 | l2 = Rotate90(l1) 192 | a3 = ConvertToAngle(l2) 193 | l3 = Rotate90(l2) 194 | a4 = ConvertToAngle(l3) 195 | 196 | if args.use_min == 0: 197 | d0 = preds - a1 198 | d0 = torch.min(torch.abs(d0), torch.min(torch.abs(d0 + 360), torch.abs(d0 - 360))) 199 | d0 = torch.sum(d0, dim=1) 200 | d0 = d0.view(d0.shape[0], 1, d0.shape[1], d0.shape[2]) 201 | d = d0 * mask 202 | loss = torch.sum(d) 203 | diff = (outputs - l0) ** 2 204 | diff = torch.sum(diff, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 205 | diff = diff * mask 206 | mse_loss = torch.sum(diff) 207 | 208 | diff_2a = torch.sum((dir_x - X) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 209 | diff_2b = torch.sum((dir_y - Y) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 210 | mse_loss_2 = torch.sum((diff_2a + diff_2b) * mask) 211 | 212 | else: 213 | d0 = preds - a1 214 | d0 = torch.min(torch.abs(d0), torch.min(torch.abs(d0 + 360), torch.abs(d0 - 360))) 215 | d0 = torch.sum(d0, dim=1) 216 | d1 = preds - a2 217 | d1 = torch.min(torch.abs(d1), torch.min(torch.abs(d1 + 360), torch.abs(d1 - 360))) 218 | d1 = torch.sum(d1, dim=1) 219 | d2 = preds - a3 220 | d2 = torch.min(torch.abs(d2), torch.min(torch.abs(d2 + 360), torch.abs(d2 - 360))) 221 | d2 = torch.sum(d2, dim=1) 222 | d3 = preds - a4 223 | d3 = torch.min(torch.abs(d3), torch.min(torch.abs(d3 + 360), torch.abs(d3 - 360))) 224 | d3 = torch.sum(d3, dim=1) 225 | d = torch.min(d0, torch.min(d1, torch.min(d2, d3))) 226 | d = d.view(d.shape[0], 1, d.shape[1], d.shape[2]) 227 | d = d * mask 228 | loss = torch.sum(d) 229 | 230 | diff1 = torch.sum((outputs - l0) ** 2, dim=1) 231 | diff2 = torch.sum((outputs - l1) ** 2, dim=1) 232 | diff3 = torch.sum((outputs - l2) ** 2, dim=1) 233 | diff4 = torch.sum((outputs - l3) ** 2, dim=1) 234 | diff = torch.min(diff1, torch.min(diff2, torch.min(diff3, diff4))) 235 | diff = diff.view(diff.shape[0], 1, diff.shape[1], diff.shape[2]) 236 | mse_loss = torch.sum(diff * mask) 237 | 238 | diff_2a = torch.sum((dir_x - X) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 239 | diff_2b = torch.sum((dir_y - Y) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 240 | diff_2_x = diff_2a + diff_2b 241 | 242 | 
diff_2a = torch.sum((dir_x - Y) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 243 | diff_2b = torch.sum((dir_y + X) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 244 | diff_2_y = diff_2a + diff_2b 245 | 246 | diff_2a = torch.sum((dir_x + X) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 247 | diff_2b = torch.sum((dir_y + Y) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 248 | diff_2_z = diff_2a + diff_2b 249 | 250 | diff_2a = torch.sum((dir_x + Y) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 251 | diff_2b = torch.sum((dir_y - X) ** 2, dim=1).view(outputs.shape[0], 1, outputs.shape[2], outputs.shape[3]) 252 | diff_2_w = diff_2a + diff_2b 253 | 254 | diff_2 = torch.min(diff_2_x, torch.min(diff_2_y, torch.min(diff_2_z, diff_2_w))) 255 | mse_loss_2 = torch.sum(diff_2 * mask) 256 | 257 | c_1 = dir_x[:,0,:,:] - images_tensor[:,3,:,:] * dir_x[:,2,:,:] - outputs[:,0,:,:] 258 | c_2 = dir_x[:,1,:,:] - images_tensor[:,4,:,:] * dir_x[:,2,:,:] - outputs[:,1,:,:] 259 | c_3 = dir_y[:,0,:,:] - images_tensor[:,3,:,:] * dir_y[:,2,:,:] - outputs[:,2,:,:] 260 | c_4 = dir_y[:,1,:,:] - images_tensor[:,4,:,:] * dir_y[:,2,:,:] - outputs[:,3,:,:] 261 | mse_loss_proj = torch.sum((c_1 ** 2 + c_2 ** 2 + c_3 ** 2 + c_4 ** 2).view(mask.shape[0],1,mask.shape[2],mask.shape[3]) * mask) 262 | 263 | norm0 = Normalize(torch.cross(dir_x, dir_y, dim=1)) 264 | norm1 = Normalize(norm1) 265 | norm2 = Normalize(torch.cross(X,Y,dim=1)) 266 | 267 | angle = torch.acos(torch.clamp(torch.sum(norm1 * norm2, dim=1), -1, 1)) / np.pi * 180 268 | angle = angle.view(mask.shape[0],1,mask.shape[2],mask.shape[3]) * mask 269 | angle = torch.sum(angle) 270 | 271 | mse_loss_norm = torch.sum(torch.sum((norm1 - norm0)**2,dim=1).view(mask.shape[0],1,mask.shape[2],mask.shape[3]) * mask) 272 | angle_loss = torch.sum(torch.sum((norm1 - norm2)**2, dim=1).view(mask.shape[0],1,mask.shape[2],mask.shape[3]) * mask) 273 | if args.train == 0: 274 | preds = ConvertToDirection(preds) 275 | #labels_tensor = ConvertToDirection(labels_tensor) 276 | mask = torch.cat([mask, mask, mask, mask], dim=1).float() 277 | #preds *= mask 278 | labels_tensor *= mask 279 | for j in range(preds.shape[0]): 280 | im = images_tensor[j].data.cpu().numpy() 281 | im = np.ascontiguousarray(np.swapaxes(np.swapaxes(im, 0, 1), 1, 2)) 282 | pred = preds[j].data.cpu().numpy() 283 | pred = np.ascontiguousarray(np.swapaxes(np.swapaxes(pred, 0, 1), 1, 2)) 284 | label = labels_tensor[j].data.cpu().numpy() 285 | label = np.ascontiguousarray(np.swapaxes(np.swapaxes(label, 0, 1), 1, 2)) 286 | m = tmask[j].numpy() 287 | m = (m * 255).astype('uint8') 288 | label = label / np.max(np.abs(label)) 289 | pred = pred / np.max(np.abs(pred)) 290 | sio.imsave('preds/pred-%06d-color.png'%(m_iter*preds.shape[0]+j), im[:,:,0:3]) 291 | 292 | 293 | normal = norm2[j].data.cpu().numpy() 294 | normal = np.ascontiguousarray(np.swapaxes(np.swapaxes(normal, 0, 1), 1, 2)) 295 | sio.imsave('preds/pred-%06d-normal-gt.png'%(m_iter*preds.shape[0]+j), normal * 0.5 + 0.5) 296 | normal_pred = norm1[j].data.cpu().numpy() 297 | normal_pred = np.ascontiguousarray(np.swapaxes(np.swapaxes(normal_pred, 0, 1), 1, 2)) 298 | sio.imsave('preds/pred-%06d-normal-pred.png'%(m_iter*preds.shape[0]+j), normal_pred * 0.5 + 0.5) 299 | 300 | diff = normal_pred - normal 301 | diff = (np.sqrt(np.sum(diff * diff, axis=2)) * 512).astype('uint8') * (m > 0) 302 | 
sio.imsave('preds/pred-%06d-normal-diff.png'%(m_iter*preds.shape[0]+j), diff)
303 |             #sio.imsave('preds/pred-%06d-mask.png'%(m_iter*preds.shape[0]+j), (mask[j][0].data.cpu().numpy() * 255).astype('uint8'))
304 | 
305 |             color = np.ascontiguousarray(im[:,:,0:3]).astype('float32')
306 |             #try:
307 |             Qx = np.ascontiguousarray(label[:,:,0:2].astype('float32'))
308 |             Qy = np.ascontiguousarray(label[:,:,2:4].astype('float32'))
309 |             #render.visualizeDirection('preds/pred-%06d-vis-gt.png'%(m_iter*preds.shape[0]+j), color, Qx, Qy)
310 | 
311 | 
312 |             diff1 = torch.sum((outputs - l0) ** 2, dim=1)
313 |             diff2 = torch.sum((outputs - l1) ** 2, dim=1)
314 |             diff3 = torch.sum((outputs - l2) ** 2, dim=1)
315 |             diff4 = torch.sum((outputs - l3) ** 2, dim=1)
316 |             diff = torch.min(diff1, torch.min(diff2, torch.min(diff3, diff4)))
317 |             mask1 = (diff == diff1).data.cpu().numpy()[j]
318 |             mask2 = (diff == diff2).data.cpu().numpy()[j]
319 |             mask3 = (diff == diff3).data.cpu().numpy()[j]
320 |             mask4 = (diff == diff4).data.cpu().numpy()[j]
321 | 
322 |             mask1 = np.tile(np.reshape(mask1, (mask1.shape[0],mask1.shape[1],1)), (1,1,2))
323 |             mask2 = np.tile(np.reshape(mask2, (mask2.shape[0],mask2.shape[1],1)), (1,1,2))
324 |             mask3 = np.tile(np.reshape(mask3, (mask3.shape[0],mask3.shape[1],1)), (1,1,2))
325 |             mask4 = np.tile(np.reshape(mask4, (mask4.shape[0],mask4.shape[1],1)), (1,1,2))
326 | 
327 |             Qx1 = np.ascontiguousarray(pred[:,:,0:2].astype('float32'))
328 |             Qy1 = np.ascontiguousarray(pred[:,:,2:4].astype('float32'))
329 | 
330 |             Qx2 = np.ascontiguousarray(pred[:,:,2:4].astype('float32'))
331 |             Qy2 = np.ascontiguousarray(-pred[:,:,0:2].astype('float32'))
332 | 
333 |             Qx3 = np.ascontiguousarray(-pred[:,:,0:2].astype('float32'))
334 |             Qy3 = np.ascontiguousarray(-pred[:,:,2:4].astype('float32'))
335 | 
336 |             Qx4 = np.ascontiguousarray(-pred[:,:,2:4].astype('float32'))
337 |             Qy4 = np.ascontiguousarray(pred[:,:,0:2].astype('float32'))
338 | 
339 |             Qx = Qx1 * mask1 + Qx2 * mask2 + Qx3 * mask3 + Qx4 * mask4
340 |             Qy = Qy1 * mask1 + Qy2 * mask2 + Qy3 * mask3 + Qy4 * mask4
341 | 
342 |         m_iter += 1
343 |     else:
344 |         if evaluate == 0:
345 |             if args.joint == 1:
346 |                 losses = mse_loss + mse_loss_2 + angle_loss + mse_loss_proj * 5 + mse_loss_norm * 5  # + angle / 200.0
347 |             elif args.joint == 0:
348 |                 losses = angle_loss
349 |             elif args.joint == 2:
350 |                 losses = angle_loss + mse_loss_2
351 |             elif args.joint == 3:
352 |                 losses = angle_loss + mse_loss_2 + mse_loss
353 |             elif args.joint == 4:
354 |                 losses = angle_loss + mse_loss_2 + mse_loss + mse_loss_norm * 5
355 |             elif args.joint == 5:
356 |                 losses = angle_loss + mse_loss_2 + mse_loss + mse_loss_proj * 5
357 |             losses.backward()
358 |             optimizer.step()
359 | 
360 |         if args.logtype != '':
361 |             if evaluate == 0:
362 |                 writer.add_scalar(args.logtype + '/project_loss', mse_loss.item() / elems, n_iter)
363 |                 writer.add_scalar(args.logtype + '/3D_loss', mse_loss_2.item() / elems, n_iter)
364 |                 writer.add_scalar(args.logtype + '/Consistency_loss', mse_loss_proj.item() / elems, n_iter)
365 |                 writer.add_scalar(args.logtype + '/Normal_Consistency_loss', mse_loss_norm.item() / elems, n_iter)
366 |                 writer.add_scalar(args.logtype + '/projection_err', loss.item() / elems, n_iter)
367 |                 writer.add_scalar(args.logtype + '/normal_err', angle.item() / elems * 2, n_iter)
368 |                 n_iter += 1
369 |             elif evaluate == 1:
370 |                 writer.add_scalar(args.logtype + '/val_loss', mse_loss.item() / elems, m_iter)
371 |                 writer.add_scalar(args.logtype + '/val_loss2', mse_loss_2.item() / elems, m_iter)
372 | 
writer.add_scalar(args.logtype + '/val_loss3', mse_loss_proj.item() / elems, m_iter)
373 |                 writer.add_scalar(args.logtype + '/val_err', loss.item() / elems, m_iter)
374 |                 m_iter += 1
375 |             else:
376 |                 writer.add_scalar(args.logtype + '/test_project_loss', mse_loss.item() / elems, m_iter)
377 |                 writer.add_scalar(args.logtype + '/test_3D_loss', mse_loss_2.item() / elems, m_iter)
378 |                 writer.add_scalar(args.logtype + '/test_Consistency_loss', mse_loss_proj.item() / elems, m_iter)
379 |                 writer.add_scalar(args.logtype + '/test_Normal_Consistency_loss', mse_loss_norm.item() / elems, m_iter)
380 |                 writer.add_scalar(args.logtype + '/test_projection_err', loss.item() / elems, m_iter)
381 |                 writer.add_scalar(args.logtype + '/test_normal_err', angle.item() / elems * 2, m_iter)
382 |                 m_iter += 1
383 | 
384 |     if evaluate == 0:
385 |         log ('Epoch : %d/%d, Iter : %d/%d, Loss: <%.4f, %.4f, %.4f>, Err: <%.4f %.4f>'
386 |             %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, mse_loss.item()/elems, mse_loss_2.item()/elems, mse_loss_proj.item()/elems, loss.item() / elems, angle.item() / elems * 2))
387 |     else:
388 |         log ('(Test) Epoch : %d/%d, Iter : %d/%d, Loss: <%.4f, %.4f, %.4f>, Err: <%.4f %.4f>'
389 |             %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, mse_loss.item()/elems, mse_loss_2.item()/elems, mse_loss_proj.item()/elems, loss.item() / elems, angle.item() / elems * 2))
390 |     del images_tensor, labels_tensor, loss, outputs, mask
391 | 
392 | for epoch in range(num_epochs):
393 |     if args.train == 1:
394 |         for i, sample_batched in enumerate(dataloader):
395 |             #print('start train')
396 |             if i % 8 == 0:
397 |                 m_iter += 1
398 |                 try:
399 |                     sample_batched_t = next(test_iter)
400 |                 except StopIteration:
401 |                     test_iter = iter(test_dataloader)
402 |                     sample_batched_t = next(test_iter)
403 |                 train_one_iter(i, sample_batched_t, 2)
404 | 
405 |             train_one_iter(i, sample_batched, 0)
406 | 
407 |             if i % 1000 == 0 and args.save != '':
408 |                 path = args.save + '/model-epoch-%05d-iter-%05d.cpkt'%(epoch, i)
409 |                 torch.save(cnn.state_dict(), path)
410 | 
411 |     if args.train == 0:
412 |         correct_num = 0
413 |         total_num = 0
414 |         for i, sample_batched in enumerate(test_dataloader):
415 |             train_one_iter(i, sample_batched, 2)
416 |             m_iter += 1
417 |         break
418 | 
419 |     if epoch == 2:
420 |         args.use_min = 1
421 |         args.horizontal = 2
422 | if args.logtype != '':
423 |     writer.close()
424 | 
np.ascontiguousarray(labels[:,:,2:4].astype('float32')) 29 | 30 | sio.imsave('color.png', color) 31 | render.visualizeDirection('vis.png', color, Qx, Qy) 32 | --------------------------------------------------------------------------------
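
For reference, the statistics that evaluate_joint.py prints during its loop (mean, median, RMSE, and the 11.25/22.5/30-degree delta thresholds common in surface-normal evaluation) can be reproduced standalone. The sketch below is an assumption-light restatement of that print statement, taking `errors` to be a 1-D NumPy array of per-pixel angular errors in degrees, which is how the script accumulates them; the function name is hypothetical.

```python
import numpy as np

def summarize_errors(errors):
    # errors: 1-D array of per-pixel angular errors in degrees
    n = errors.shape[0]
    return {
        'mean': np.average(errors),
        'median': np.median(errors),
        'rmse': np.sqrt(np.sum(errors * errors) / n),
        'delta1': np.sum(errors < 11.25) / n,  # fraction of pixels under 11.25 degrees
        'delta2': np.sum(errors < 22.5) / n,
        'delta3': np.sum(errors < 30) / n,
    }

if __name__ == '__main__':
    print(summarize_errors(np.array([5.0, 12.0, 25.0, 40.0])))
```
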