├── .gitignore ├── LICENSE ├── README.md ├── data ├── celeba_crop.py ├── celeba_crop_bbox.txt ├── download_cat.sh ├── download_syncar.sh └── download_synface.sh ├── demo ├── demo.py ├── images │ ├── cat_face │ │ ├── 001_cat.png │ │ ├── 002_cat.png │ │ ├── 003_cat.png │ │ ├── 004_cat.png │ │ ├── 005_cat.png │ │ ├── 006_cat.png │ │ ├── 007_cat.png │ │ ├── 008_cat.png │ │ ├── 009_cat.png │ │ ├── 010_cat.png │ │ ├── 011_cat.png │ │ ├── 012_cat.png │ │ ├── 013_cat.png │ │ ├── 014_cat.png │ │ ├── 015_cat.png │ │ ├── 016_cat.png │ │ ├── 017_cat.png │ │ ├── 018_cat.png │ │ ├── 019_cat.png │ │ ├── 020_cat.png │ │ ├── 021_cat.png │ │ ├── 022_cat.png │ │ ├── 023_cat.png │ │ ├── 024_abstract.png │ │ ├── 025_abstract.png │ │ ├── 026_abstract.png │ │ ├── 027_abstract.png │ │ ├── 028_abstract.png │ │ ├── 029_abstract.png │ │ ├── 030_abstract.png │ │ ├── 031_abstract.png │ │ ├── 032_abstract.png │ │ ├── 033_abstract.png │ │ ├── 034_abstract.png │ │ ├── 035_abstract.png │ │ ├── 036_abstract.png │ │ ├── 037_abstract.png │ │ ├── 038_abstract.png │ │ ├── 039_abstract.png │ │ ├── 040_abstract.png │ │ ├── 041_abstract.png │ │ ├── 042_abstract.png │ │ ├── 043_abstract.png │ │ ├── 044_abstract.png │ │ └── 045_abstract.png │ └── human_face │ │ ├── 001_face.png │ │ ├── 002_face.png │ │ ├── 003_face.png │ │ ├── 004_face.png │ │ ├── 005_face.png │ │ ├── 006_face.png │ │ ├── 007_face.png │ │ ├── 008_face.png │ │ ├── 009_face.png │ │ ├── 010_face.png │ │ ├── 011_face.png │ │ ├── 012_face.png │ │ ├── 013_face.png │ │ ├── 014_face.png │ │ ├── 015_face.png │ │ ├── 016_paint.png │ │ ├── 017_paint.png │ │ ├── 018_paint.png │ │ ├── 019_paint.png │ │ ├── 020_paint.png │ │ ├── 021_paint.png │ │ ├── 022_paint.png │ │ ├── 023_paint.png │ │ ├── 024_paint.png │ │ ├── 025_paint.png │ │ ├── 026_paint.png │ │ ├── 027_paint.png │ │ ├── 028_paint.png │ │ ├── 029_paint.png │ │ ├── 030_paint.png │ │ ├── 031_abstract.png │ │ ├── 032_abstract.png │ │ ├── 033_abstract.png │ │ ├── 034_abstract.png │ │ ├── 035_abstract.png │ │ ├── 036_abstract.png │ │ ├── 037_abstract.png │ │ ├── 038_abstract.png │ │ ├── 039_abstract.png │ │ ├── 040_abstract.png │ │ ├── 041_abstract.png │ │ ├── 042_abstract.png │ │ ├── 043_abstract.png │ │ ├── 044_abstract.png │ │ └── 045_abstract.png └── utils.py ├── environment.yml ├── experiments ├── test_cat.yml ├── test_celeba.yml ├── test_syncar.yml ├── test_synface.yml ├── train_cat.yml ├── train_celeba.yml ├── train_syncar.yml └── train_synface.yml ├── img └── teaser.jpg ├── pretrained ├── download_pretrained_cat.sh ├── download_pretrained_celeba.sh ├── download_pretrained_syncar.sh └── download_pretrained_synface.sh ├── run.py └── unsup3d ├── __init__.py ├── dataloaders.py ├── meters.py ├── model.py ├── networks.py ├── renderer ├── __init__.py ├── renderer.py └── utils.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data/*/ 3 | pretrained/*/ 4 | results 5 | neural_renderer 6 | *.zip 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 elliottwu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Learning of Probably Symmetric Deformable 3D Objects from Images in the Wild 2 | #### [Demo](http://www.robots.ox.ac.uk/~vgg/blog/unsupervised-learning-of-probably-symmetric-deformable-3d-objects-from-images-in-the-wild.html) | [Project Page](https://elliottwu.com/projects/unsup3d/) | [Video](https://www.youtube.com/watch?v=5rPJyrU-WE4) | [Paper](https://arxiv.org/abs/1911.11130) 3 | [Shangzhe Wu](https://elliottwu.com/), [Christian Rupprecht](https://chrirupp.github.io/), [Andrea Vedaldi](http://www.robots.ox.ac.uk/~vedaldi/), Visual Geometry Group, University of Oxford. In CVPR 2020 (Best Paper Award). 4 | 5 | 6 | 7 | We propose a method to learn weakly symmetric deformable 3D object categories from raw single-view images, without ground-truth 3D, multiple views, 2D/3D keypoints, prior shape models or any other supervision. 8 | 9 | 10 | ## Setup (with [Anaconda](https://www.anaconda.com/)) 11 | 12 | ### 1. Install dependencies: 13 | ``` 14 | conda env create -f environment.yml 15 | ``` 16 | OR manually: 17 | ``` 18 | conda install -c conda-forge scikit-image matplotlib opencv moviepy pyyaml tensorboardX 19 | ``` 20 | 21 | 22 | ### 2. Install [PyTorch](https://pytorch.org/): 23 | ``` 24 | conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=9.2 -c pytorch 25 | ``` 26 | *Note*: The code is tested with PyTorch 1.2.0 and CUDA 9.2 on CentOS 7. A GPU version is required for training and testing, since the [neural_renderer](https://github.com/daniilidis-group/neural_renderer) package only has GPU implementation. You are still able to run the demo without GPU. 27 | 28 | 29 | ### 3. Install [neural_renderer](https://github.com/daniilidis-group/neural_renderer): 30 | This package is required for training and testing, and optional for the demo. It requires a GPU device and GPU-enabled PyTorch. 31 | ``` 32 | pip install neural_renderer_pytorch 33 | ``` 34 | 35 | *Note*: It may fail if you have a GCC version below 5. If you do not want to upgrade your GCC, one alternative solution is to use conda's GCC and compile the package from source. For example: 36 | ``` 37 | conda install gxx_linux-64=7.3 38 | git clone https://github.com/daniilidis-group/neural_renderer.git 39 | cd neural_renderer 40 | python setup.py install 41 | ``` 42 | 43 | 44 | ### 4. (For demo only) Install [facenet-pytorch](https://github.com/timesler/facenet-pytorch): 45 | This package is optional for the demo. It allows automatic human face detection. 46 | ``` 47 | pip install facenet-pytorch 48 | ``` 49 | 50 | 51 | ## Datasets 52 | 1. 
[CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) face dataset. Please download the original images (`img_celeba.7z`) from their [website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and run `celeba_crop.py` in `data/` to crop the images. 53 | 2. Synthetic face dataset generated using [Basel Face Model](https://faces.dmi.unibas.ch/bfm/). This can be downloaded using the script `download_synface.sh` provided in `data/`. 54 | 3. Cat face dataset composed of [Cat Head Dataset](http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd) and [Oxford-IIIT Pet Dataset](http://www.robots.ox.ac.uk/~vgg/data/pets/) ([license](https://creativecommons.org/licenses/by-sa/4.0/)). This can be downloaded using the script `download_cat.sh` provided in `data/`. 55 | 4. Synthetic car dataset generated from [ShapeNet](https://shapenet.org/) cars. The images are rendered with random viewpoints from the top, where the cars are primarily oriented vertically. This can be downloaded using the script `download_syncar.sh` provided in `data/`. 56 | 57 | Please remember to cite the corresponding papers if you use these datasets. 58 | 59 | 60 | ## Pretrained Models 61 | Download pretrained models using the scripts provided in `pretrained/`, eg: 62 | ``` 63 | cd pretrained && sh download_pretrained_celeba.sh 64 | ``` 65 | 66 | 67 | ## Demo 68 | ``` 69 | python -m demo.demo --input demo/images/human_face --result demo/results/human_face --checkpoint pretrained/pretrained_celeba/checkpoint030.pth 70 | ``` 71 | 72 | *Options*: 73 | - `--gpu`: enable GPU 74 | - `--detect_human_face`: enable automatic human face detection and cropping using [MTCNN](https://arxiv.org/abs/1604.02878) provided in [facenet-pytorch](https://github.com/timesler/facenet-pytorch). This only works on human face images. You will need to manually crop the images for other objects. 
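For other objects, a rough center crop around the subject is usually all that is needed before running the demo. The sketch below is one way to do it; it is not part of this repository, and the directory names, `margin` ratio and output size are placeholders.
```
import os
from PIL import Image

def center_crop(in_path, out_path, margin=0.9, out_size=128):
    # crop the largest centered square (scaled by `margin`), then resize
    im = Image.open(in_path).convert('RGB')
    w, h = im.size
    s = int(min(w, h) * margin)
    left, top = (w - s) // 2, (h - s) // 2
    im.crop((left, top, left + s, top + s)).resize((out_size, out_size)).save(out_path)

in_dir, out_dir = 'my_images', 'my_images_cropped'  # placeholder paths
os.makedirs(out_dir, exist_ok=True)
for f in os.listdir(in_dir):
    if f.lower().endswith(('.png', '.jpg', '.jpeg')):
        center_crop(os.path.join(in_dir, f), os.path.join(out_dir, f))
```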
75 | - `--render_video`: render 3D animations using [neural_renderer](https://github.com/daniilidis-group/neural_renderer) (GPU is required) 76 | 77 | 78 | ## Training and Testing 79 | Check the configuration files in `experiments/` and run experiments, eg: 80 | ``` 81 | python run.py --config experiments/train_celeba.yml --gpu 0 --num_workers 4 82 | ``` 83 | 84 | 85 | ## Citation 86 | ``` 87 | @InProceedings{Wu_2020_CVPR, 88 | author = {Shangzhe Wu and Christian Rupprecht and Andrea Vedaldi}, 89 | title = {Unsupervised Learning of Probably Symmetric Deformable 3D Objects from Images in the Wild}, 90 | booktitle = {CVPR}, 91 | year = {2020} 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /data/celeba_crop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | im_dir = './img_celeba' 7 | out_dir = './celeba_cropped' 8 | bbox_fpath = './celeba_crop_bbox.txt' 9 | out_im_size = 128 10 | 11 | im_list = np.loadtxt(bbox_fpath, dtype='str') 12 | total_num = len(im_list) 13 | split_dict = {'0': 'train', 14 | '1': 'val', 15 | '2': 'test'} 16 | 17 | for i, row in enumerate(im_list): 18 | if i%1000 == 0: 19 | print(f'{i}/{total_num}') 20 | 21 | fname = row[0] 22 | split = row[1] 23 | x0, y0, w, h = row[2:].astype(int) 24 | im = cv2.imread(os.path.join(im_dir, fname)) 25 | im_pad = cv2.copyMakeBorder(im, h, h, w, w, cv2.BORDER_REPLICATE) # allow cropping outside by replicating borders 26 | im_crop = im_pad[y0+h:y0+h*2, x0+w:x0+w*2] 27 | im_crop = cv2.resize(im_crop, (out_im_size,out_im_size)) 28 | 29 | out_folder = os.path.join(out_dir, split_dict[split]) 30 | os.makedirs(out_folder, exist_ok=True) 31 | cv2.imwrite(os.path.join(out_folder, fname), im_crop) 32 | -------------------------------------------------------------------------------- /data/download_cat.sh: -------------------------------------------------------------------------------- 1 | echo "----------------------- downloading cat dataset -----------------------" 2 | curl -o cat_combined.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/cat_combined.zip" && unzip cat_combined.zip 3 | -------------------------------------------------------------------------------- /data/download_syncar.sh: -------------------------------------------------------------------------------- 1 | echo "----------------------- downloading synthetic car dataset -----------------------" 2 | curl -o syncar.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/syncar.zip" && unzip syncar.zip 3 | -------------------------------------------------------------------------------- /data/download_synface.sh: -------------------------------------------------------------------------------- 1 | echo "----------------------- downloading synthetic face dataset -----------------------" 2 | curl -o synface.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/synface.zip" && unzip synface.zip 3 | -------------------------------------------------------------------------------- /demo/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from PIL import Image 4 | import torch 5 | import torch.nn as nn 6 | from .utils import * 7 | 8 | 9 | EPS = 1e-7 10 | 11 | 12 | class Demo(): 13 | def __init__(self, args): 14 | ## configs 15 | self.device = 'cuda:0' if args.gpu else 'cpu' 16 | self.checkpoint_path = args.checkpoint 17 | 
self.detect_human_face = args.detect_human_face 18 | self.render_video = args.render_video 19 | self.output_size = args.output_size 20 | self.image_size = 64 21 | self.min_depth = 0.9 22 | self.max_depth = 1.1 23 | self.border_depth = 1.05 24 | self.xyz_rotation_range = 60 25 | self.xy_translation_range = 0.1 26 | self.z_translation_range = 0 27 | self.fov = 10 # in degrees 28 | 29 | self.depth_rescaler = lambda d : (1+d)/2 *self.max_depth + (1-d)/2 *self.min_depth # (-1,1) => (min_depth,max_depth) 30 | self.depth_inv_rescaler = lambda d : (d-self.min_depth) / (self.max_depth-self.min_depth) # (min_depth,max_depth) => (0,1) 31 | 32 | fx = (self.image_size-1)/2/(np.tan(self.fov/2 *np.pi/180)) 33 | fy = (self.image_size-1)/2/(np.tan(self.fov/2 *np.pi/180)) 34 | cx = (self.image_size-1)/2 35 | cy = (self.image_size-1)/2 36 | K = [[fx, 0., cx], 37 | [0., fy, cy], 38 | [0., 0., 1.]] 39 | K = torch.FloatTensor(K).to(self.device) 40 | self.inv_K = torch.inverse(K).unsqueeze(0) 41 | self.K = K.unsqueeze(0) 42 | 43 | ## NN models 44 | self.netD = EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None) 45 | self.netA = EDDeconv(cin=3, cout=3, nf=64, zdim=256) 46 | self.netL = Encoder(cin=3, cout=4, nf=32) 47 | self.netV = Encoder(cin=3, cout=6, nf=32) 48 | 49 | self.netD = self.netD.to(self.device) 50 | self.netA = self.netA.to(self.device) 51 | self.netL = self.netL.to(self.device) 52 | self.netV = self.netV.to(self.device) 53 | self.load_checkpoint() 54 | 55 | self.netD.eval() 56 | self.netA.eval() 57 | self.netL.eval() 58 | self.netV.eval() 59 | 60 | ## face detecter 61 | if self.detect_human_face: 62 | from facenet_pytorch import MTCNN 63 | self.face_detector = MTCNN(select_largest=True, device=self.device) 64 | 65 | ## renderer 66 | if self.render_video: 67 | from unsup3d.renderer import Renderer 68 | assert 'cuda' in self.device, 'A GPU device is required for rendering because the neural_renderer only has GPU implementation.' 
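# Note: the renderer below is configured with the same depth range and fov as the model above,
# at the requested output resolution, so the predicted canonical depth can be re-rendered consistently.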
69 | cfgs = { 70 | 'device': self.device, 71 | 'image_size': self.output_size, 72 | 'min_depth': self.min_depth, 73 | 'max_depth': self.max_depth, 74 | 'fov': self.fov, 75 | } 76 | self.renderer = Renderer(cfgs) 77 | 78 | def load_checkpoint(self): 79 | print(f"Loading checkpoint from {self.checkpoint_path}") 80 | cp = torch.load(self.checkpoint_path, map_location=self.device) 81 | self.netD.load_state_dict(cp['netD']) 82 | self.netA.load_state_dict(cp['netA']) 83 | self.netL.load_state_dict(cp['netL']) 84 | self.netV.load_state_dict(cp['netV']) 85 | 86 | def depth_to_3d_grid(self, depth, inv_K=None): 87 | if inv_K is None: 88 | inv_K = self.inv_K 89 | b, h, w = depth.shape 90 | grid_2d = get_grid(b, h, w, normalize=False).to(depth.device) # Nxhxwx2 91 | depth = depth.unsqueeze(-1) 92 | grid_3d = torch.cat((grid_2d, torch.ones_like(depth)), dim=3) 93 | grid_3d = grid_3d.matmul(inv_K.transpose(2,1)) * depth 94 | return grid_3d 95 | 96 | def get_normal_from_depth(self, depth): 97 | b, h, w = depth.shape 98 | grid_3d = self.depth_to_3d_grid(depth) 99 | 100 | tu = grid_3d[:,1:-1,2:] - grid_3d[:,1:-1,:-2] 101 | tv = grid_3d[:,2:,1:-1] - grid_3d[:,:-2,1:-1] 102 | normal = tu.cross(tv, dim=3) 103 | 104 | zero = normal.new_tensor([0,0,1]) 105 | normal = torch.cat([zero.repeat(b,h-2,1,1), normal, zero.repeat(b,h-2,1,1)], 2) 106 | normal = torch.cat([zero.repeat(b,1,w,1), normal, zero.repeat(b,1,w,1)], 1) 107 | normal = normal / (((normal**2).sum(3, keepdim=True))**0.5 + EPS) 108 | return normal 109 | 110 | def detect_face(self, im): 111 | print("Detecting face using MTCNN face detector") 112 | try: 113 | bboxes, prob = self.face_detector.detect(im) 114 | w0, h0, w1, h1 = bboxes[0] 115 | except: 116 | print("Could not detect faces in the image") 117 | return None 118 | 119 | hc, wc = (h0+h1)/2, (w0+w1)/2 120 | crop = int(((h1-h0) + (w1-w0)) /2/2 *1.1) 121 | im = np.pad(im, ((crop,crop),(crop,crop),(0,0)), mode='edge') # allow cropping outside by replicating borders 122 | h0 = int(hc-crop+crop + crop*0.15) 123 | w0 = int(wc-crop+crop) 124 | return im[h0:h0+crop*2, w0:w0+crop*2] 125 | 126 | def run(self, pil_im): 127 | im = np.uint8(pil_im) 128 | 129 | ## face detection 130 | if self.detect_human_face: 131 | im = self.detect_face(im) 132 | if im is None: 133 | return -1 134 | 135 | h, w, _ = im.shape 136 | im = torch.FloatTensor(im /255.).permute(2,0,1).unsqueeze(0) 137 | # resize to 128 first if too large, to avoid bilinear downsampling artifacts 138 | if h > self.image_size*4 and w > self.image_size*4: 139 | im = nn.functional.interpolate(im, (self.image_size*2, self.image_size*2), mode='bilinear', align_corners=False) 140 | im = nn.functional.interpolate(im, (self.image_size, self.image_size), mode='bilinear', align_corners=False) 141 | 142 | with torch.no_grad(): 143 | self.input_im = im.to(self.device) *2.-1. 
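# The rest of run() performs the photo-geometric decomposition: netD predicts a canonical
# depth map, netA a canonical albedo, netL the lighting (ambient/diffuse terms and a direction),
# and netV the viewpoint. Shading is ambient + diffuse * max(0, <normal, light_dir>), the
# canonical image is albedo * shading, the mesh is exported as an OBJ string, and all maps are
# finally resized to the requested output resolution.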
144 | b, c, h, w = self.input_im.shape 145 | 146 | ## predict canonical depth 147 | self.canon_depth_raw = self.netD(self.input_im).squeeze(1) # BxHxW 148 | self.canon_depth = self.canon_depth_raw - self.canon_depth_raw.view(b,-1).mean(1).view(b,1,1) 149 | self.canon_depth = self.canon_depth.tanh() 150 | self.canon_depth = self.depth_rescaler(self.canon_depth) 151 | 152 | ## clamp border depth 153 | depth_border = torch.zeros(1,h,w-4).to(self.input_im.device) 154 | depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1) 155 | self.canon_depth = self.canon_depth*(1-depth_border) + depth_border *self.border_depth 156 | 157 | ## predict canonical albedo 158 | self.canon_albedo = self.netA(self.input_im) # Bx3xHxW 159 | 160 | ## predict lighting 161 | canon_light = self.netL(self.input_im) # Bx4 162 | self.canon_light_a = canon_light[:,:1] /2+0.5 # ambience term 163 | self.canon_light_b = canon_light[:,1:2] /2+0.5 # diffuse term 164 | canon_light_dxy = canon_light[:,2:] 165 | self.canon_light_d = torch.cat([canon_light_dxy, torch.ones(b,1).to(self.input_im.device)], 1) 166 | self.canon_light_d = self.canon_light_d / ((self.canon_light_d**2).sum(1, keepdim=True))**0.5 # diffuse light direction 167 | 168 | ## shading 169 | self.canon_normal = self.get_normal_from_depth(self.canon_depth) 170 | self.canon_diffuse_shading = (self.canon_normal * self.canon_light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1) 171 | canon_shading = self.canon_light_a.view(-1,1,1,1) + self.canon_light_b.view(-1,1,1,1)*self.canon_diffuse_shading 172 | self.canon_im = (self.canon_albedo/2+0.5) * canon_shading *2-1 173 | 174 | ## predict viewpoint transformation 175 | self.view = self.netV(self.input_im) 176 | self.view = torch.cat([ 177 | self.view[:,:3] *np.pi/180 *self.xyz_rotation_range, 178 | self.view[:,3:5] *self.xy_translation_range, 179 | self.view[:,5:] *self.z_translation_range], 1) 180 | 181 | ## export to obj strings 182 | vertices = self.depth_to_3d_grid(self.canon_depth) # BxHxWx3 183 | self.objs, self.mtls = export_to_obj_string(vertices, self.canon_normal) 184 | 185 | ## resize to output size 186 | self.canon_depth = nn.functional.interpolate(self.canon_depth.unsqueeze(1), (self.output_size, self.output_size), mode='bilinear', align_corners=False).squeeze(1) 187 | self.canon_normal = nn.functional.interpolate(self.canon_normal.permute(0,3,1,2), (self.output_size, self.output_size), mode='bilinear', align_corners=False).permute(0,2,3,1) 188 | self.canon_normal = self.canon_normal / (self.canon_normal**2).sum(3, keepdim=True)**0.5 189 | self.canon_diffuse_shading = nn.functional.interpolate(self.canon_diffuse_shading, (self.output_size, self.output_size), mode='bilinear', align_corners=False) 190 | self.canon_albedo = nn.functional.interpolate(self.canon_albedo, (self.output_size, self.output_size), mode='bilinear', align_corners=False) 191 | self.canon_im = nn.functional.interpolate(self.canon_im, (self.output_size, self.output_size), mode='bilinear', align_corners=False) 192 | 193 | if self.render_video: 194 | self.render_animation() 195 | 196 | def render_animation(self): 197 | print(f"Rendering video animations") 198 | b, h, w = self.canon_depth.shape 199 | 200 | ## morph from target view to canonical 201 | morph_frames = 15 202 | view_zero = torch.FloatTensor([0.15*np.pi/180*60, 0,0,0,0,0]).to(self.canon_depth.device) 203 | morph_s = torch.linspace(0, 1, morph_frames).to(self.canon_depth.device) 204 | view_morph = morph_s.view(-1,1,1) * view_zero.view(1,1,-1) + 
(1-morph_s.view(-1,1,1)) * self.view.unsqueeze(0) # TxBx6 205 | 206 | ## yaw from canonical to both sides 207 | yaw_frames = 80 208 | yaw_rotations = np.linspace(-np.pi/2, np.pi/2, yaw_frames) 209 | # yaw_rotations = np.concatenate([yaw_rotations[40:], yaw_rotations[::-1], yaw_rotations[:40]], 0) 210 | 211 | ## whole rotation sequence 212 | view_after = torch.cat([view_morph, view_zero.repeat(yaw_frames, b, 1)], 0) 213 | yaw_rotations = np.concatenate([np.zeros(morph_frames), yaw_rotations], 0) 214 | 215 | def rearrange_frames(frames): 216 | morph_seq = frames[:, :morph_frames] 217 | yaw_seq = frames[:, morph_frames:] 218 | out_seq = torch.cat([ 219 | morph_seq[:,:1].repeat(1,5,1,1,1), 220 | morph_seq, 221 | morph_seq[:,-1:].repeat(1,5,1,1,1), 222 | yaw_seq[:, yaw_frames//2:], 223 | yaw_seq.flip(1), 224 | yaw_seq[:, :yaw_frames//2], 225 | morph_seq[:,-1:].repeat(1,5,1,1,1), 226 | morph_seq.flip(1), 227 | morph_seq[:,:1].repeat(1,5,1,1,1), 228 | ], 1) 229 | return out_seq 230 | 231 | ## textureless shape 232 | front_light = torch.FloatTensor([0,0,1]).to(self.canon_depth.device) 233 | canon_shape_im = (self.canon_normal * front_light.view(1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1) 234 | canon_shape_im = canon_shape_im.repeat(1,3,1,1) *0.7 235 | shape_animation = self.renderer.render_yaw(canon_shape_im, self.canon_depth, v_after=view_after, rotations=yaw_rotations) # BxTxCxHxW 236 | self.shape_animation = rearrange_frames(shape_animation) 237 | 238 | ## normal map 239 | canon_normal_im = self.canon_normal.permute(0,3,1,2) /2+0.5 240 | normal_animation = self.renderer.render_yaw(canon_normal_im, self.canon_depth, v_after=view_after, rotations=yaw_rotations) # BxTxCxHxW 241 | self.normal_animation = rearrange_frames(normal_animation) 242 | 243 | ## textured 244 | texture_animation = self.renderer.render_yaw(self.canon_im /2+0.5, self.canon_depth, v_after=view_after, rotations=yaw_rotations) # BxTxCxHxW 245 | self.texture_animation = rearrange_frames(texture_animation) 246 | 247 | def save_results(self, save_dir): 248 | print(f"Saving results to {save_dir}") 249 | save_image(save_dir, self.input_im[0]/2+0.5, 'input_image') 250 | save_image(save_dir, self.depth_inv_rescaler(self.canon_depth)[0].repeat(3,1,1), 'canonical_depth') 251 | save_image(save_dir, self.canon_normal[0].permute(2,0,1)/2+0.5, 'canonical_normal') 252 | save_image(save_dir, self.canon_diffuse_shading[0].repeat(3,1,1), 'canonical_diffuse_shading') 253 | save_image(save_dir, self.canon_albedo[0]/2+0.5, 'canonical_albedo') 254 | save_image(save_dir, self.canon_im[0].clamp(-1,1)/2+0.5, 'canonical_image') 255 | 256 | with open(os.path.join(save_dir, 'result.mtl'), "w") as f: 257 | f.write(self.mtls[0].replace('$TXTFILE', './canonical_image.png')) 258 | with open(os.path.join(save_dir, 'result.obj'), "w") as f: 259 | f.write(self.objs[0].replace('$MTLFILE', './result.mtl')) 260 | 261 | if self.render_video: 262 | save_video(save_dir, self.shape_animation[0], 'shape_animation') 263 | save_video(save_dir, self.normal_animation[0], 'normal_animation') 264 | save_video(save_dir, self.texture_animation[0], 'texture_animation') 265 | 266 | 267 | if __name__ == "__main__": 268 | parser = argparse.ArgumentParser(description='Demo configurations.') 269 | parser.add_argument('--input', default='./demo/images/human_face', type=str, help='Path to the directory containing input images') 270 | parser.add_argument('--result', default='./demo/results/human_face', type=str, help='Path to the directory for saving results') 271 | 
parser.add_argument('--checkpoint', default='./pretrained/pretrained_celeba/checkpoint030.pth', type=str, help='Path to the checkpoint file') 272 | parser.add_argument('--output_size', default=128, type=int, help='Output image size') 273 | parser.add_argument('--gpu', default=False, action='store_true', help='Enable GPU') 274 | parser.add_argument('--detect_human_face', default=False, action='store_true', help='Enable automatic human face detection. This does not detect cat faces.') 275 | parser.add_argument('--render_video', default=False, action='store_true', help='Render 3D animations to video') 276 | args = parser.parse_args() 277 | 278 | input_dir = args.input 279 | result_dir = args.result 280 | model = Demo(args) 281 | im_list = [os.path.join(input_dir, f) for f in sorted(os.listdir(input_dir)) if is_image_file(f)] 282 | 283 | for im_path in im_list: 284 | print(f"Processing {im_path}") 285 | pil_im = Image.open(im_path).convert('RGB') 286 | result_code = model.run(pil_im) 287 | if result_code == -1: 288 | print(f"Failed! Skipping {im_path}") 289 | continue 290 | 291 | save_dir = os.path.join(result_dir, os.path.splitext(os.path.basename(im_path))[0]) 292 | model.save_results(save_dir) 293 | -------------------------------------------------------------------------------- /demo/images/cat_face/001_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/001_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/002_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/002_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/003_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/003_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/004_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/004_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/005_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/005_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/006_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/006_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/007_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/007_cat.png 
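For reference, the `Demo` class defined in `demo/demo.py` above can also be driven without the command-line wrapper. The following is a minimal sketch, assuming the repository root is on `PYTHONPATH`, the CelebA checkpoint has been downloaded as described in the README, and CPU-only execution; the input image and result directory are placeholders.
```
from argparse import Namespace
from PIL import Image
from demo.demo import Demo

# mirror the argparse defaults from demo/demo.py (CPU, no face detection, no video)
args = Namespace(gpu=False, detect_human_face=False, render_video=False, output_size=128,
                 checkpoint='./pretrained/pretrained_celeba/checkpoint030.pth')
model = Demo(args)

pil_im = Image.open('./demo/images/human_face/001_face.png').convert('RGB')
if model.run(pil_im) != -1:  # run() returns -1 if face detection was enabled and failed
    model.save_results('./demo/results/001_face')
```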
-------------------------------------------------------------------------------- /demo/images/cat_face/008_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/008_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/009_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/009_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/010_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/010_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/011_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/011_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/012_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/012_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/013_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/013_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/014_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/014_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/015_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/015_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/016_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/016_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/017_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/017_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/018_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/018_cat.png 
-------------------------------------------------------------------------------- /demo/images/cat_face/019_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/019_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/020_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/020_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/021_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/021_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/022_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/022_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/023_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/023_cat.png -------------------------------------------------------------------------------- /demo/images/cat_face/024_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/024_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/025_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/025_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/026_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/026_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/027_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/027_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/028_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/028_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/029_abstract.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/029_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/030_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/030_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/031_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/031_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/032_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/032_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/033_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/033_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/034_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/034_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/035_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/035_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/036_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/036_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/037_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/037_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/038_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/038_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/039_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/039_abstract.png -------------------------------------------------------------------------------- 
/demo/images/cat_face/040_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/040_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/041_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/041_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/042_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/042_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/043_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/043_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/044_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/044_abstract.png -------------------------------------------------------------------------------- /demo/images/cat_face/045_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/045_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/001_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/001_face.png -------------------------------------------------------------------------------- /demo/images/human_face/002_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/002_face.png -------------------------------------------------------------------------------- /demo/images/human_face/003_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/003_face.png -------------------------------------------------------------------------------- /demo/images/human_face/004_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/004_face.png -------------------------------------------------------------------------------- /demo/images/human_face/005_face.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/005_face.png -------------------------------------------------------------------------------- /demo/images/human_face/006_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/006_face.png -------------------------------------------------------------------------------- /demo/images/human_face/007_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/007_face.png -------------------------------------------------------------------------------- /demo/images/human_face/008_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/008_face.png -------------------------------------------------------------------------------- /demo/images/human_face/009_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/009_face.png -------------------------------------------------------------------------------- /demo/images/human_face/010_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/010_face.png -------------------------------------------------------------------------------- /demo/images/human_face/011_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/011_face.png -------------------------------------------------------------------------------- /demo/images/human_face/012_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/012_face.png -------------------------------------------------------------------------------- /demo/images/human_face/013_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/013_face.png -------------------------------------------------------------------------------- /demo/images/human_face/014_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/014_face.png -------------------------------------------------------------------------------- /demo/images/human_face/015_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/015_face.png -------------------------------------------------------------------------------- /demo/images/human_face/016_paint.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/016_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/017_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/017_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/018_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/018_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/019_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/019_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/020_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/020_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/021_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/021_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/022_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/022_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/023_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/023_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/024_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/024_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/025_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/025_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/026_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/026_paint.png 
-------------------------------------------------------------------------------- /demo/images/human_face/027_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/027_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/028_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/028_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/029_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/029_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/030_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/030_paint.png -------------------------------------------------------------------------------- /demo/images/human_face/031_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/031_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/032_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/032_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/033_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/033_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/034_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/034_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/035_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/035_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/036_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/036_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/037_abstract.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/037_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/038_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/038_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/039_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/039_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/040_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/040_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/041_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/041_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/042_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/042_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/043_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/043_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/044_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/044_abstract.png -------------------------------------------------------------------------------- /demo/images/human_face/045_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/045_abstract.png -------------------------------------------------------------------------------- /demo/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, cin, cout, nf=64, activation=nn.Tanh): 10 | super(Encoder, self).__init__() 11 | network = [ 12 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, 
padding=1, bias=False), # 16x16 -> 8x8 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(nf*8, cout, kernel_size=1, stride=1, padding=0, bias=False)] 23 | if activation is not None: 24 | network += [activation()] 25 | self.network = nn.Sequential(*network) 26 | 27 | def forward(self, input): 28 | return self.network(input).reshape(input.size(0),-1) 29 | 30 | 31 | class EDDeconv(nn.Module): 32 | def __init__(self, cin, cout, zdim=128, nf=64, activation=nn.Tanh): 33 | super(EDDeconv, self).__init__() 34 | ## downsampling 35 | network = [ 36 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 37 | nn.GroupNorm(16, nf), 38 | nn.LeakyReLU(0.2, inplace=True), 39 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 40 | nn.GroupNorm(16*2, nf*2), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 43 | nn.GroupNorm(16*4, nf*4), 44 | nn.LeakyReLU(0.2, inplace=True), 45 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 46 | nn.LeakyReLU(0.2, inplace=True), 47 | nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 48 | nn.ReLU(inplace=True)] 49 | ## upsampling 50 | network += [ 51 | nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 1x1 -> 4x4 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=False), 54 | nn.ReLU(inplace=True), 55 | nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8 56 | nn.GroupNorm(16*4, nf*4), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=False), 59 | nn.GroupNorm(16*4, nf*4), 60 | nn.ReLU(inplace=True), 61 | nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16 62 | nn.GroupNorm(16*2, nf*2), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=False), 65 | nn.GroupNorm(16*2, nf*2), 66 | nn.ReLU(inplace=True), 67 | nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32 68 | nn.GroupNorm(16, nf), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False), 71 | nn.GroupNorm(16, nf), 72 | nn.ReLU(inplace=True), 73 | nn.Upsample(scale_factor=2, mode='nearest'), # 32x32 -> 64x64 74 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False), 75 | nn.GroupNorm(16, nf), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=False), 78 | nn.GroupNorm(16, nf), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(nf, cout, kernel_size=5, stride=1, padding=2, bias=False)] 81 | if activation is not None: 82 | network += [activation()] 83 | self.network = nn.Sequential(*network) 84 | 85 | def forward(self, input): 86 | return self.network(input) 87 | 88 | 89 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp') 90 | def is_image_file(filename): 91 | return filename.lower().endswith(IMG_EXTENSIONS) 92 | 93 | 94 | def save_video(out_fold, frames, fname='image', ext='.mp4', cycle=False): 95 | os.makedirs(out_fold, exist_ok=True) 96 | frames = 
frames.detach().cpu().numpy().transpose(0,2,3,1) # TxCxHxW -> TxHxWxC 97 | if cycle: 98 | frames = np.concatenate([frames, frames[::-1]], 0) 99 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 100 | # fourcc = cv2.VideoWriter_fourcc(*'avc1') 101 | vid = cv2.VideoWriter(os.path.join(out_fold, fname+ext), fourcc, 25, (frames.shape[2], frames.shape[1])) 102 | [vid.write(np.uint8(f[...,::-1]*255.)) for f in frames] 103 | vid.release() 104 | 105 | 106 | def save_image(out_fold, img, fname='image', ext='.png'): 107 | os.makedirs(out_fold, exist_ok=True) 108 | img = img.detach().cpu().numpy().transpose(1,2,0) 109 | if 'depth' in fname: 110 | im_out = np.uint16(img*65535.) 111 | else: 112 | im_out = np.uint8(img*255.) 113 | cv2.imwrite(os.path.join(out_fold, fname+ext), im_out[:,:,::-1]) 114 | 115 | 116 | def get_grid(b, H, W, normalize=True): 117 | if normalize: 118 | h_range = torch.linspace(-1,1,H) 119 | w_range = torch.linspace(-1,1,W) 120 | else: 121 | h_range = torch.arange(0,H) 122 | w_range = torch.arange(0,W) 123 | grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b,1,1,1).flip(3).float() # flip h,w to x,y 124 | return grid 125 | 126 | 127 | def export_to_obj_string(vertices, normal): 128 | b, h, w, _ = vertices.shape 129 | vertices[:,:,:,1:2] = -1*vertices[:,:,:,1:2] # flip y 130 | vertices[:,:,:,2:3] = 1-vertices[:,:,:,2:3] # flip and shift z 131 | vertices *= 100 132 | vertices_center = nn.functional.avg_pool2d(vertices.permute(0,3,1,2), 2, stride=1).permute(0,2,3,1) 133 | vertices = torch.cat([vertices.view(b,h*w,3), vertices_center.view(b,(h-1)*(w-1),3)], 1) 134 | 135 | vertice_textures = get_grid(b, h, w, normalize=True) # BxHxWx2 136 | vertice_textures[:,:,:,1:2] = -1*vertice_textures[:,:,:,1:2] # flip y 137 | vertice_textures_center = nn.functional.avg_pool2d(vertice_textures.permute(0,3,1,2), 2, stride=1).permute(0,2,3,1) 138 | vertice_textures = torch.cat([vertice_textures.view(b,h*w,2), vertice_textures_center.view(b,(h-1)*(w-1),2)], 1) /2+0.5 # Bx(H*W)x2, [0,1] 139 | 140 | vertice_normals = normal.clone() 141 | vertice_normals[:,:,:,0:1] = -1*vertice_normals[:,:,:,0:1] 142 | vertice_normals_center = nn.functional.avg_pool2d(vertice_normals.permute(0,3,1,2), 2, stride=1).permute(0,2,3,1) 143 | vertice_normals_center = vertice_normals_center / (vertice_normals_center**2).sum(3, keepdim=True)**0.5 144 | vertice_normals = torch.cat([vertice_normals.view(b,h*w,3), vertice_normals_center.view(b,(h-1)*(w-1),3)], 1) # Bx(H*W)x2, [0,1] 145 | 146 | idx_map = torch.arange(h*w).reshape(h,w) 147 | idx_map_center = torch.arange((h-1)*(w-1)).reshape(h-1,w-1) 148 | faces1 = torch.stack([idx_map[:h-1,:w-1], idx_map[1:,:w-1], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x4 149 | faces2 = torch.stack([idx_map[1:,:w-1], idx_map[1:,1:], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x4 150 | faces3 = torch.stack([idx_map[1:,1:], idx_map[:h-1,1:], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x4 151 | faces4 = torch.stack([idx_map[:h-1,1:], idx_map[:h-1,:w-1], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x4 152 | faces = torch.cat([faces1, faces2, faces3, faces4], 1) 153 | 154 | objs = [] 155 | mtls = [] 156 | for bi in range(b): 157 | obj = "# OBJ File:" 158 | obj += "\n\nmtllib $MTLFILE" 159 | obj += "\n\n# vertices:" 160 | for v in vertices[bi]: 161 | obj += "\nv " + " ".join(["%.4f"%x for x in v]) 162 | obj += "\n\n# vertice textures:" 163 | for vt in 
vertice_textures[bi]: 164 | obj += "\nvt " + " ".join(["%.4f"%x for x in vt]) 165 | obj += "\n\n# vertice normals:" 166 | for vn in vertice_normals[bi]: 167 | obj += "\nvn " + " ".join(["%.4f"%x for x in vn]) 168 | obj += "\n\n# faces:" 169 | obj += "\n\nusemtl tex" 170 | for f in faces[bi]: 171 | obj += "\nf " + " ".join(["%d/%d/%d"%(x+1,x+1,x+1) for x in f]) 172 | objs += [obj] 173 | 174 | mtl = "newmtl tex" 175 | mtl += "\nKa 1.0000 1.0000 1.0000" 176 | mtl += "\nKd 1.0000 1.0000 1.0000" 177 | mtl += "\nKs 0.0000 0.0000 0.0000" 178 | mtl += "\nd 1.0" 179 | mtl += "\nillum 0" 180 | mtl += "\nmap_Kd $TXTFILE" 181 | mtls += [mtl] 182 | return objs, mtls 183 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: unsup3d 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - bzip2=1.0.8=h516909a_2 8 | - ca-certificates=2020.4.5.1=hecc5488_0 9 | - cairo=1.16.0=h18b612c_1001 10 | - certifi=2020.4.5.1=py37hc8dfbb8_0 11 | - cffi=1.14.0=py37hd463f26_0 12 | - chardet=3.0.4=py37hc8dfbb8_1006 13 | - cloudpickle=1.3.0=py_0 14 | - cryptography=2.8=py37hb09aad4_2 15 | - cycler=0.10.0=py_2 16 | - cytoolz=0.10.1=py37h516909a_0 17 | - dask-core=2.14.0=py_0 18 | - dbus=1.13.6=he372182_0 19 | - decorator=4.4.2=py_0 20 | - expat=2.2.9=he1b5a44_2 21 | - ffmpeg=4.1.3=h167e202_0 22 | - fontconfig=2.13.1=he4413a7_1000 23 | - freetype=2.10.1=he06d7ca_0 24 | - gettext=0.19.8.1=hc5be6a0_1002 25 | - giflib=5.2.1=h516909a_2 26 | - glib=2.58.3=py37he00f558_1003 27 | - gmp=6.2.0=he1b5a44_2 28 | - gnutls=3.6.5=hd3a4fd2_1002 29 | - graphite2=1.3.13=he1b5a44_1001 30 | - gst-plugins-base=1.14.5=h0935bb2_2 31 | - gstreamer=1.14.5=h36ae1b5_2 32 | - harfbuzz=2.4.0=h37c48d4_1 33 | - hdf5=1.10.5=nompi_h3c11f04_1104 34 | - icu=58.2=hf484d3e_1000 35 | - idna=2.9=py_1 36 | - imageio=2.8.0=py_0 37 | - imageio-ffmpeg=0.4.1=py_0 38 | - jasper=1.900.1=h07fcdf6_1006 39 | - jpeg=9c=h14c3975_1001 40 | - kiwisolver=1.2.0=py37h99015e2_0 41 | - lame=3.100=h14c3975_1001 42 | - ld_impl_linux-64=2.33.1=h53a641e_7 43 | - libblas=3.8.0=14_openblas 44 | - libcblas=3.8.0=14_openblas 45 | - libedit=3.1.20181209=hc058e9b_0 46 | - libffi=3.2.1=hd88cf55_4 47 | - libgcc-ng=9.1.0=hdf63c60_0 48 | - libgfortran-ng=7.3.0=hdf63c60_5 49 | - libiconv=1.15=h516909a_1006 50 | - liblapack=3.8.0=14_openblas 51 | - liblapacke=3.8.0=14_openblas 52 | - libopenblas=0.3.7=h5ec1e0e_6 53 | - libpng=1.6.37=hed695b0_1 54 | - libprotobuf=3.11.4=h8b12597_0 55 | - libstdcxx-ng=9.1.0=hdf63c60_0 56 | - libtiff=4.1.0=hc3755c2_3 57 | - libuuid=2.32.1=h14c3975_1000 58 | - libwebp=1.0.2=h56121f0_5 59 | - libxcb=1.13=h14c3975_1002 60 | - libxml2=2.9.9=h13577e0_2 61 | - lz4-c=1.8.3=he1b5a44_1001 62 | - matplotlib=3.1.3=py37_0 63 | - matplotlib-base=3.1.3=py37hef1b27d_0 64 | - moviepy=1.0.1=py_0 65 | - ncurses=6.2=he6710b0_0 66 | - nettle=3.4.1=h1bed415_1002 67 | - networkx=2.4=py_1 68 | - numpy=1.18.1=py37h8960a57_1 69 | - olefile=0.46=py_0 70 | - opencv=4.1.1=py37ha799480_1 71 | - openh264=1.8.0=hdbcaa40_1000 72 | - openssl=1.1.1f=h516909a_0 73 | - pcre=8.44=he1b5a44_0 74 | - pillow=5.3.0=py37h00a061d_1000 75 | - pip=20.0.2=py37_1 76 | - pixman=0.38.0=h516909a_1003 77 | - proglog=0.1.9=py_0 78 | - protobuf=3.11.4=py37h3340039_1 79 | - pthread-stubs=0.4=h14c3975_1001 80 | - pycparser=2.20=py_0 81 | - pyopenssl=19.1.0=py_1 82 | - pyparsing=2.4.7=pyh9f0ad1d_0 83 | - pyqt=5.9.2=py37hcca6a23_4 84 | - 
pysocks=1.7.1=py37hc8dfbb8_1 85 | - python=3.7.7=hcf32534_0_cpython 86 | - python-dateutil=2.8.1=py_0 87 | - python_abi=3.7=1_cp37m 88 | - pywavelets=1.1.1=py37hc1659b7_0 89 | - pyyaml=5.3.1=py37h8f50634_0 90 | - qt=5.9.7=h52cfd70_2 91 | - readline=8.0=h7b6447c_0 92 | - requests=2.23.0=pyh8c360ce_2 93 | - scikit-image=0.16.2=py37hb3f55d8_0 94 | - scipy=1.4.1=py37ha3d9a3c_2 95 | - setuptools=46.1.3=py37_0 96 | - sip=4.19.8=py37hf484d3e_1000 97 | - six=1.14.0=py_1 98 | - sqlite=3.31.1=h7b6447c_0 99 | - tensorboardx=2.0=py_0 100 | - tk=8.6.8=hbc83047_0 101 | - toolz=0.10.0=py_0 102 | - tornado=6.0.4=py37h8f50634_1 103 | - tqdm=4.45.0=pyh9f0ad1d_0 104 | - urllib3=1.25.7=py37hc8dfbb8_1 105 | - wheel=0.34.2=py37_0 106 | - x264=1!152.20180806=h14c3975_0 107 | - xorg-kbproto=1.0.7=h14c3975_1002 108 | - xorg-libice=1.0.10=h516909a_0 109 | - xorg-libsm=1.2.3=h84519dc_1000 110 | - xorg-libx11=1.6.9=h516909a_0 111 | - xorg-libxau=1.0.9=h14c3975_0 112 | - xorg-libxdmcp=1.1.3=h516909a_0 113 | - xorg-libxext=1.3.4=h516909a_0 114 | - xorg-libxrender=0.9.10=h516909a_1002 115 | - xorg-renderproto=0.11.1=h14c3975_1002 116 | - xorg-xextproto=7.3.0=h14c3975_1002 117 | - xorg-xproto=7.0.31=h14c3975_1007 118 | - xz=5.2.4=h14c3975_4 119 | - yaml=0.2.2=h516909a_1 120 | - zlib=1.2.11=h7b6447c_3 121 | - zstd=1.4.4=h3b9ef0a_2 122 | -------------------------------------------------------------------------------- /experiments/test_cat.yml: -------------------------------------------------------------------------------- 1 | ## test cat 2 | ## trainer 3 | run_test: true 4 | batch_size: 64 5 | checkpoint_dir: results/cat 6 | checkpoint_name: checkpoint100.pth 7 | test_result_dir: results/cat/test_results_checkpoint100 8 | 9 | ## dataloader 10 | num_workers: 4 11 | image_size: 64 12 | crop: 170 13 | load_gt_depth: false 14 | test_data_dir: data/cat_combined/test 15 | 16 | ## model 17 | model_name: unsup3d_cat 18 | min_depth: 0.9 19 | max_depth: 1.1 20 | xyz_rotation_range: 60 # (-r,r) in degrees 21 | xy_translation_range: 0.1 # (-t,t) in 3D 22 | z_translation_range: 0.1 # (-t,t) in 3D 23 | lam_perc: 1 24 | lam_flip: 0.5 25 | 26 | ## renderer 27 | rot_center_depth: 1.0 28 | fov: 10 # in degrees 29 | tex_cube_size: 2 30 | -------------------------------------------------------------------------------- /experiments/test_celeba.yml: -------------------------------------------------------------------------------- 1 | ## test celeba 2 | ## trainer 3 | run_test: true 4 | batch_size: 64 5 | checkpoint_dir: results/celeba 6 | checkpoint_name: checkpoint030.pth 7 | test_result_dir: results/celeba/test_results_checkpoint030 8 | 9 | ## dataloader 10 | num_workers: 4 11 | image_size: 64 12 | load_gt_depth: false 13 | test_data_dir: data/celeba_cropped/test 14 | 15 | ## model 16 | model_name: unsup3d_celeba 17 | min_depth: 0.9 18 | max_depth: 1.1 19 | xyz_rotation_range: 60 # (-r,r) in degrees 20 | xy_translation_range: 0.1 # (-t,t) in 3D 21 | z_translation_range: 0 # (-t,t) in 3D 22 | lam_perc: 1 23 | lam_flip: 0.5 24 | 25 | ## renderer 26 | rot_center_depth: 1.0 27 | fov: 10 # in degrees 28 | tex_cube_size: 2 29 | -------------------------------------------------------------------------------- /experiments/test_syncar.yml: -------------------------------------------------------------------------------- 1 | ## test syncar 2 | ## trainer 3 | run_test: true 4 | batch_size: 64 5 | checkpoint_dir: results/syncar 6 | checkpoint_name: checkpoint100.pth 7 | test_result_dir: results/syncar/test_results_checkpoint100 8 | 9 | ## dataloader 10 | 
num_workers: 4 11 | image_size: 64 12 | crop: [8, 14, 100, 100] 13 | load_gt_depth: false 14 | test_data_dir: data/syncar/test 15 | 16 | ## model 17 | model_name: unsup3d_syncar 18 | min_depth: 0.9 19 | max_depth: 1.1 20 | min_amb_light: 0. 21 | max_amb_light: 1. 22 | min_diff_light: 0.5 23 | max_diff_light: 1. 24 | xyz_rotation_range: 60 # (-r,r) in degrees 25 | xy_translation_range: 0.1 # (-t,t) in 3D 26 | z_translation_range: 0 # (-t,t) in 3D 27 | use_conf_map: false 28 | lam_perc: 0.01 29 | lam_flip: 1 30 | lam_depth_sm: 0.1 31 | 32 | ## renderer 33 | rot_center_depth: 1.0 34 | fov: 10 # in degrees 35 | tex_cube_size: 2 36 | -------------------------------------------------------------------------------- /experiments/test_synface.yml: -------------------------------------------------------------------------------- 1 | ## test synface 2 | ## trainer 3 | run_test: true 4 | batch_size: 64 5 | checkpoint_dir: results/synface 6 | checkpoint_name: checkpoint030.pth 7 | test_result_dir: results/synface/test_results_checkpoint030 8 | 9 | ## dataloader 10 | num_workers: 4 11 | image_size: 64 12 | crop: 170 13 | load_gt_depth: true 14 | paired_data_dir_names: ['image', 'depth'] 15 | paired_data_filename_diff: ['image', 'depth'] 16 | test_data_dir: data/synface/test 17 | 18 | ## model 19 | model_name: unsup3d_synface 20 | min_depth: 0.9 21 | max_depth: 1.1 22 | xyz_rotation_range: 60 # (-r,r) in degrees 23 | xy_translation_range: 0.1 # (-t,t) in 3D 24 | z_translation_range: 0 # (-t,t) in 3D 25 | lam_perc: 1 26 | lam_flip: 0.5 27 | 28 | ## renderer 29 | rot_center_depth: 1.0 30 | fov: 10 # in degrees 31 | tex_cube_size: 2 32 | -------------------------------------------------------------------------------- /experiments/train_cat.yml: -------------------------------------------------------------------------------- 1 | ## train cat 2 | ## trainer 3 | run_train: true 4 | num_epochs: 200 5 | batch_size: 64 6 | checkpoint_dir: results/cat 7 | save_checkpoint_freq: 10 8 | keep_num_checkpoint: 2 9 | resume: true 10 | use_logger: true 11 | log_freq: 500 12 | 13 | ## dataloader 14 | num_workers: 4 15 | image_size: 64 16 | crop: 170 17 | load_gt_depth: false 18 | train_val_data_dir: data/cat_combined 19 | 20 | ## model 21 | model_name: unsup3d_cat 22 | min_depth: 0.9 23 | max_depth: 1.1 24 | xyz_rotation_range: 60 # (-r,r) in degrees 25 | xy_translation_range: 0.1 # (-t,t) in 3D 26 | z_translation_range: 0.1 # (-t,t) in 3D 27 | lam_perc: 1 28 | lam_flip: 0.5 29 | lam_flip_start_epoch: 10 30 | lr: 0.0001 31 | 32 | ## renderer 33 | rot_center_depth: 1.0 34 | fov: 10 # in degrees 35 | tex_cube_size: 2 36 | -------------------------------------------------------------------------------- /experiments/train_celeba.yml: -------------------------------------------------------------------------------- 1 | ## train celeba 2 | ## trainer 3 | run_train: true 4 | num_epochs: 30 5 | batch_size: 64 6 | checkpoint_dir: results/celeba 7 | save_checkpoint_freq: 1 8 | keep_num_checkpoint: 2 9 | resume: true 10 | use_logger: true 11 | log_freq: 500 12 | 13 | ## dataloader 14 | num_workers: 4 15 | image_size: 64 16 | load_gt_depth: false 17 | train_val_data_dir: data/celeba_cropped 18 | 19 | ## model 20 | model_name: unsup3d_celeba 21 | min_depth: 0.9 22 | max_depth: 1.1 23 | xyz_rotation_range: 60 # (-r,r) in degrees 24 | xy_translation_range: 0.1 # (-t,t) in 3D 25 | z_translation_range: 0 # (-t,t) in 3D 26 | lam_perc: 1 27 | lam_flip: 0.5 28 | lr: 0.0001 29 | 30 | ## renderer 31 | rot_center_depth: 1.0 32 | fov: 10 # in 
degrees 33 | tex_cube_size: 2 34 | -------------------------------------------------------------------------------- /experiments/train_syncar.yml: -------------------------------------------------------------------------------- 1 | ## train syncar 2 | ## trainer 3 | run_train: true 4 | num_epochs: 100 5 | batch_size: 64 6 | checkpoint_dir: results/syncar 7 | save_checkpoint_freq: 10 8 | keep_num_checkpoint: 2 9 | resume: true 10 | use_logger: true 11 | log_freq: 500 12 | 13 | ## dataloader 14 | num_workers: 4 15 | image_size: 64 16 | crop: [8, 14, 100, 100] 17 | load_gt_depth: false 18 | train_val_data_dir: data/syncar 19 | 20 | ## model 21 | model_name: unsup3d_syncar 22 | min_depth: 0.9 23 | max_depth: 1.1 24 | border_depth: 1.08 25 | min_amb_light: 0. 26 | max_amb_light: 1. 27 | min_diff_light: 0.5 28 | max_diff_light: 1. 29 | xyz_rotation_range: 60 # (-r,r) in degrees 30 | xy_translation_range: 0.1 # (-t,t) in 3D 31 | z_translation_range: 0 # (-t,t) in 3D 32 | use_conf_map: false 33 | lam_perc: 0.01 34 | lam_flip: 1 35 | lam_depth_sm: 0.1 36 | lr: 0.0001 37 | 38 | ## renderer 39 | rot_center_depth: 1.0 40 | fov: 10 # in degrees 41 | tex_cube_size: 2 42 | -------------------------------------------------------------------------------- /experiments/train_synface.yml: -------------------------------------------------------------------------------- 1 | ## train synface 2 | ## trainer 3 | run_train: true 4 | num_epochs: 30 5 | batch_size: 64 6 | checkpoint_dir: results/synface 7 | save_checkpoint_freq: 1 8 | keep_num_checkpoint: 2 9 | resume: true 10 | use_logger: true 11 | log_freq: 500 12 | 13 | ## dataloader 14 | num_workers: 4 15 | image_size: 64 16 | crop: 170 17 | load_gt_depth: true 18 | paired_data_dir_names: ['image', 'depth'] 19 | paired_data_filename_diff: ['image', 'depth'] 20 | train_val_data_dir: data/synface 21 | 22 | ## model 23 | model_name: unsup3d_synface 24 | min_depth: 0.9 25 | max_depth: 1.1 26 | xyz_rotation_range: 60 # (-r,r) in degrees 27 | xy_translation_range: 0.1 # (-t,t) in 3D 28 | z_translation_range: 0 # (-t,t) in 3D 29 | lam_perc: 1 30 | lam_flip: 0.5 31 | lr: 0.0001 32 | 33 | ## renderer 34 | rot_center_depth: 1.0 35 | fov: 10 # in degrees 36 | tex_cube_size: 2 37 | -------------------------------------------------------------------------------- /img/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/img/teaser.jpg -------------------------------------------------------------------------------- /pretrained/download_pretrained_cat.sh: -------------------------------------------------------------------------------- 1 | echo "----------------------- downloading pretrained model on cat dataset -----------------------" 2 | curl -o pretrained_cat.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_cat.zip" && unzip pretrained_cat.zip 3 | -------------------------------------------------------------------------------- /pretrained/download_pretrained_celeba.sh: -------------------------------------------------------------------------------- 1 | echo "----------------------- downloading pretrained model on celeba face dataset -----------------------" 2 | curl -o pretrained_celeba.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_celeba.zip" && unzip pretrained_celeba.zip 3 | -------------------------------------------------------------------------------- /pretrained/download_pretrained_syncar.sh:
-------------------------------------------------------------------------------- 1 | echo "----------------------- downloading pretrained model on synthetic car dataset -----------------------" 2 | curl -o pretrained_syncar.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_syncar.zip" && unzip pretrained_syncar.zip 3 | -------------------------------------------------------------------------------- /pretrained/download_pretrained_synface.sh: -------------------------------------------------------------------------------- 1 | echo "----------------------- downloading pretrained model on synthetic face dataset -----------------------" 2 | curl -o pretrained_synface.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_synface.zip" && unzip pretrained_synface.zip 3 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from unsup3d import setup_runtime, Trainer, Unsup3D 4 | 5 | 6 | ## runtime arguments 7 | parser = argparse.ArgumentParser(description='Training configurations.') 8 | parser.add_argument('--config', default=None, type=str, help='Specify a config file path') 9 | parser.add_argument('--gpu', default=None, type=int, help='Specify a GPU device') 10 | parser.add_argument('--num_workers', default=4, type=int, help='Specify the number of worker threads for data loaders') 11 | parser.add_argument('--seed', default=0, type=int, help='Specify a random seed') 12 | args = parser.parse_args() 13 | 14 | ## set up 15 | cfgs = setup_runtime(args) 16 | trainer = Trainer(cfgs, Unsup3D) 17 | run_train = cfgs.get('run_train', False) 18 | run_test = cfgs.get('run_test', False) 19 | 20 | ## run 21 | if run_train: 22 | trainer.train() 23 | if run_test: 24 | trainer.test() 25 | -------------------------------------------------------------------------------- /unsup3d/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import setup_runtime 2 | from .trainer import Trainer 3 | from .model import Unsup3D 4 | -------------------------------------------------------------------------------- /unsup3d/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.transforms as tfs 3 | import torch.utils.data 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def get_data_loaders(cfgs): 9 | batch_size = cfgs.get('batch_size', 64) 10 | num_workers = cfgs.get('num_workers', 4) 11 | image_size = cfgs.get('image_size', 64) 12 | crop = cfgs.get('crop', None) 13 | 14 | run_train = cfgs.get('run_train', False) 15 | train_val_data_dir = cfgs.get('train_val_data_dir', './data') 16 | run_test = cfgs.get('run_test', False) 17 | test_data_dir = cfgs.get('test_data_dir', './data/test') 18 | 19 | load_gt_depth = cfgs.get('load_gt_depth', False) 20 | AB_dnames = cfgs.get('paired_data_dir_names', ['A', 'B']) 21 | AB_fnames = cfgs.get('paired_data_filename_diff', None) 22 | 23 | train_loader = val_loader = test_loader = None 24 | if load_gt_depth: 25 | get_loader = lambda **kargs: get_paired_image_loader(**kargs, batch_size=batch_size, image_size=image_size, crop=crop, AB_dnames=AB_dnames, AB_fnames=AB_fnames) 26 | else: 27 | get_loader = lambda **kargs: get_image_loader(**kargs, batch_size=batch_size, image_size=image_size, crop=crop) 28 | 29 | if run_train: 30 | train_data_dir = 
os.path.join(train_val_data_dir, "train") 31 | val_data_dir = os.path.join(train_val_data_dir, "val") 32 | assert os.path.isdir(train_data_dir), "Training data directory does not exist: %s" %train_data_dir 33 | assert os.path.isdir(val_data_dir), "Validation data directory does not exist: %s" %val_data_dir 34 | print(f"Loading training data from {train_data_dir}") 35 | train_loader = get_loader(data_dir=train_data_dir, is_validation=False) 36 | print(f"Loading validation data from {val_data_dir}") 37 | val_loader = get_loader(data_dir=val_data_dir, is_validation=True) 38 | if run_test: 39 | assert os.path.isdir(test_data_dir), "Testing data directory does not exist: %s" %test_data_dir 40 | print(f"Loading testing data from {test_data_dir}") 41 | test_loader = get_loader(data_dir=test_data_dir, is_validation=True) 42 | 43 | return train_loader, val_loader, test_loader 44 | 45 | 46 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp') 47 | def is_image_file(filename): 48 | return filename.lower().endswith(IMG_EXTENSIONS) 49 | 50 | 51 | ## simple image dataset ## 52 | def make_dataset(dir): 53 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 54 | 55 | images = [] 56 | for root, _, fnames in sorted(os.walk(dir)): 57 | for fname in sorted(fnames): 58 | if is_image_file(fname): 59 | fpath = os.path.join(root, fname) 60 | images.append(fpath) 61 | return images 62 | 63 | 64 | class ImageDataset(torch.utils.data.Dataset): 65 | def __init__(self, data_dir, image_size=256, crop=None, is_validation=False): 66 | super(ImageDataset, self).__init__() 67 | self.root = data_dir 68 | self.paths = make_dataset(data_dir) 69 | self.size = len(self.paths) 70 | self.image_size = image_size 71 | self.crop = crop 72 | self.is_validation = is_validation 73 | 74 | def transform(self, img, hflip=False): 75 | if self.crop is not None: 76 | if isinstance(self.crop, int): 77 | img = tfs.CenterCrop(self.crop)(img) 78 | else: 79 | assert len(self.crop) == 4, 'Crop size must be an integer for center crop, or a list of 4 integers (y0,x0,h,w)' 80 | img = tfs.functional.crop(img, *self.crop) 81 | img = tfs.functional.resize(img, (self.image_size, self.image_size)) 82 | if hflip: 83 | img = tfs.functional.hflip(img) 84 | return tfs.functional.to_tensor(img) 85 | 86 | def __getitem__(self, index): 87 | fpath = self.paths[index % self.size] 88 | img = Image.open(fpath).convert('RGB') 89 | hflip = not self.is_validation and np.random.rand()>0.5 90 | return self.transform(img, hflip=hflip) 91 | 92 | def __len__(self): 93 | return self.size 94 | 95 | def name(self): 96 | return 'ImageDataset' 97 | 98 | 99 | def get_image_loader(data_dir, is_validation=False, 100 | batch_size=256, num_workers=4, image_size=256, crop=None): 101 | 102 | dataset = ImageDataset(data_dir, image_size=image_size, crop=crop, is_validation=is_validation) 103 | loader = torch.utils.data.DataLoader( 104 | dataset, 105 | batch_size=batch_size, 106 | shuffle=not is_validation, 107 | num_workers=num_workers, 108 | pin_memory=True 109 | ) 110 | return loader 111 | 112 | 113 | ## paired AB image dataset ## 114 | def make_paied_dataset(dir, AB_dnames=None, AB_fnames=None): 115 | A_dname, B_dname = AB_dnames or ('A', 'B') 116 | dir_A = os.path.join(dir, A_dname) 117 | dir_B = os.path.join(dir, B_dname) 118 | assert os.path.isdir(dir_A), '%s is not a valid directory' % dir_A 119 | assert os.path.isdir(dir_B), '%s is not a valid directory' % dir_B 120 | 121 | images = [] 122 | for root_A, _, fnames_A in 
sorted(os.walk(dir_A)): 123 | for fname_A in sorted(fnames_A): 124 | if is_image_file(fname_A): 125 | path_A = os.path.join(root_A, fname_A) 126 | root_B = root_A.replace(dir_A, dir_B, 1) 127 | if AB_fnames is not None: 128 | fname_B = fname_A.replace(*AB_fnames) 129 | else: 130 | fname_B = fname_A 131 | path_B = os.path.join(root_B, fname_B) 132 | if os.path.isfile(path_B): 133 | images.append((path_A, path_B)) 134 | return images 135 | 136 | 137 | class PairedDataset(torch.utils.data.Dataset): 138 | def __init__(self, data_dir, image_size=256, crop=None, is_validation=False, AB_dnames=None, AB_fnames=None): 139 | super(PairedDataset, self).__init__() 140 | self.root = data_dir 141 | self.paths = make_paied_dataset(data_dir, AB_dnames=AB_dnames, AB_fnames=AB_fnames) 142 | self.size = len(self.paths) 143 | self.image_size = image_size 144 | self.crop = crop 145 | self.is_validation = is_validation 146 | 147 | def transform(self, img, hflip=False): 148 | if self.crop is not None: 149 | if isinstance(self.crop, int): 150 | img = tfs.CenterCrop(self.crop)(img) 151 | else: 152 | assert len(self.crop) == 4, 'Crop size must be an integer for center crop, or a list of 4 integers (y0,x0,h,w)' 153 | img = tfs.functional.crop(img, *self.crop) 154 | img = tfs.functional.resize(img, (self.image_size, self.image_size)) 155 | if hflip: 156 | img = tfs.functional.hflip(img) 157 | return tfs.functional.to_tensor(img) 158 | 159 | def __getitem__(self, index): 160 | path_A, path_B = self.paths[index % self.size] 161 | img_A = Image.open(path_A).convert('RGB') 162 | img_B = Image.open(path_B).convert('RGB') 163 | hflip = not self.is_validation and np.random.rand()>0.5 164 | return self.transform(img_A, hflip=hflip), self.transform(img_B, hflip=hflip) 165 | 166 | def __len__(self): 167 | return self.size 168 | 169 | def name(self): 170 | return 'PairedDataset' 171 | 172 | 173 | def get_paired_image_loader(data_dir, is_validation=False, 174 | batch_size=256, num_workers=4, image_size=256, crop=None, AB_dnames=None, AB_fnames=None): 175 | 176 | dataset = PairedDataset(data_dir, image_size=image_size, crop=crop, \ 177 | is_validation=is_validation, AB_dnames=AB_dnames, AB_fnames=AB_fnames) 178 | loader = torch.utils.data.DataLoader( 179 | dataset, 180 | batch_size=batch_size, 181 | shuffle=not is_validation, 182 | num_workers=num_workers, 183 | pin_memory=True 184 | ) 185 | return loader 186 | -------------------------------------------------------------------------------- /unsup3d/meters.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import torch 5 | import operator 6 | from functools import reduce 7 | import matplotlib.pyplot as plt 8 | import collections 9 | from .utils import xmkdir 10 | 11 | 12 | class TotalAverage(): 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.last_value = 0. 18 | self.mass = 0. 19 | self.sum = 0. 
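# A minimal usage sketch of this mass-weighted accumulator (the batch sizes and loss values below
# are made up for illustration): get() returns sum(value_i * mass_i) / sum(mass_i).
#
#   avg = TotalAverage()
#   avg.update(0.5, mass=64)   # e.g. a batch of 64 samples with mean loss 0.5
#   avg.update(0.3, mass=32)   # a batch of 32 samples with mean loss 0.3
#   avg.get()                  # (0.5*64 + 0.3*32) / (64 + 32) ~= 0.433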
20 | 21 | def update(self, value, mass=1): 22 | self.last_value = value 23 | self.mass += mass 24 | self.sum += value * mass 25 | 26 | def get(self): 27 | return self.sum / self.mass 28 | 29 | class MovingAverage(): 30 | def __init__(self, inertia=0.9): 31 | self.inertia = inertia 32 | self.reset() 33 | self.last_value = None 34 | 35 | def reset(self): 36 | self.last_value = None 37 | self.average = None 38 | 39 | def update(self, value, mass=1): 40 | self.last_value = value 41 | if self.average is None: 42 | self.average = value 43 | else: 44 | self.average = self.inertia * self.average + (1 - self.inertia) * value 45 | 46 | def get(self): 47 | return self.average 48 | 49 | class MetricsTrace(): 50 | def __init__(self): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.data = {} 55 | 56 | def append(self, dataset, metric): 57 | if dataset not in self.data: 58 | self.data[dataset] = [] 59 | self.data[dataset].append(metric.get_data_dict()) 60 | 61 | def load(self, path): 62 | """Load the metrics trace from the specified JSON file.""" 63 | with open(path, 'r') as f: 64 | self.data = json.load(f) 65 | 66 | def save(self, path): 67 | """Save the metrics trace to the specified JSON file.""" 68 | if path is None: 69 | return 70 | xmkdir(os.path.dirname(path)) 71 | with open(path, 'w') as f: 72 | json.dump(self.data, f, indent=2) 73 | 74 | def plot(self, pdf_path=None): 75 | """Plots and optionally save as PDF the metrics trace.""" 76 | plot_metrics(self.data, pdf_path=pdf_path) 77 | 78 | def get(self): 79 | return self.data 80 | 81 | def __str__(self): 82 | pass 83 | 84 | class Metrics(): 85 | def __init__(self): 86 | self.iteration_time = MovingAverage(inertia=0.9) 87 | self.now = time.time() 88 | 89 | def update(self, prediction=None, ground_truth=None): 90 | self.iteration_time.update(time.time() - self.now) 91 | self.now = time.time() 92 | 93 | def get_data_dict(self): 94 | return {"objective" : self.objective.get(), "iteration_time" : self.iteration_time.get()} 95 | 96 | class StandardMetrics(Metrics): 97 | def __init__(self, m=None): 98 | super(StandardMetrics, self).__init__() 99 | self.metrics = m or {} 100 | self.speed = MovingAverage(inertia=0.9) 101 | 102 | def update(self, metric_dict, mass=1): 103 | super(StandardMetrics, self).update() 104 | for metric, val in metric_dict.items(): 105 | if torch.is_tensor(val): 106 | val = val.item() 107 | if metric not in self.metrics: 108 | self.metrics[metric] = TotalAverage() 109 | self.metrics[metric].update(val, mass) 110 | self.speed.update(mass / self.iteration_time.last_value) 111 | 112 | def get_data_dict(self): 113 | data_dict = {k: v.get() for k,v in self.metrics.items()} 114 | data_dict['speed'] = self.speed.get() 115 | return data_dict 116 | 117 | def __str__(self): 118 | pstr = '%7.1fHz\t' %self.speed.get() 119 | pstr += '\t'.join(['%s: %6.5f' %(k,v.get()) for k,v in self.metrics.items()]) 120 | return pstr 121 | 122 | def plot_metrics(stats, pdf_path=None, fig=1, datasets=None, metrics=None): 123 | """Plot metrics. `stats` should be a dictionary of type 124 | 125 | stats[dataset][t][metric][i] 126 | 127 | where dataset is the dataset name (e.g. `train` or `val`), t is an iteration number, 128 | metric is the name of a metric (e.g. `loss` or `top1`), and i is a loss dimension. 129 | 130 | Alternatively, if a loss has a single dimension, `stats[dataset][t][metric]` can 131 | be a scalar. 
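For example, with the metrics produced by StandardMetrics, `stats` might look like
(values are illustrative only):

    stats = {"train": [{"loss": 0.52, "speed": 110.4}, {"loss": 0.41, "speed": 112.0}],
             "val":   [{"loss": 0.61, "speed": 230.7}]}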
132 | 133 | The supported options are: 134 | 135 | - pdf_file: path to a PDF file to store the figure (default: None) 136 | - fig: MatPlotLib figure index (default: 1) 137 | - datasets: list of dataset names to plot (default: None) 138 | - metrics: list of metrics to plot (default: None) 139 | """ 140 | plt.figure(fig) 141 | plt.clf() 142 | linestyles = ['-', '--', '-.', ':'] 143 | datasets = list(stats.keys()) if datasets is None else datasets 144 | # Filter out empty datasets 145 | datasets = [d for d in datasets if len(stats[d]) > 0] 146 | duration = len(stats[datasets[0]]) 147 | metrics = list(stats[datasets[0]][0].keys()) if metrics is None else metrics 148 | for m, metric in enumerate(metrics): 149 | plt.subplot(len(metrics),1,m+1) 150 | legend_content = [] 151 | for d, dataset in enumerate(datasets): 152 | ls = linestyles[d % len(linestyles)] 153 | if isinstance(stats[dataset][0][metric], collections.Iterable): 154 | metric_dimension = len(stats[dataset][0][metric]) 155 | for sl in range(metric_dimension): 156 | x = [stats[dataset][t][metric][sl] for t in range(duration)] 157 | plt.plot(x, linestyle=ls) 158 | name = f'{dataset} {metric}[{sl}]' 159 | legend_content.append(name) 160 | else: 161 | x = [stats[dataset][t][metric] for t in range(duration)] 162 | plt.plot(x, linestyle=ls) 163 | name = f'{dataset} {metric}' 164 | legend_content.append(name) 165 | plt.legend(legend_content, loc=(1.04,0)) 166 | plt.grid(True) 167 | if pdf_path is not None: 168 | plt.savefig(pdf_path, format='pdf', bbox_inches='tight') 169 | plt.draw() 170 | plt.pause(0.0001) 171 | -------------------------------------------------------------------------------- /unsup3d/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import glob 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from . import networks 8 | from . import utils 9 | from .renderer import Renderer 10 | 11 | 12 | EPS = 1e-7 13 | 14 | 15 | class Unsup3D(): 16 | def __init__(self, cfgs): 17 | self.model_name = cfgs.get('model_name', self.__class__.__name__) 18 | self.device = cfgs.get('device', 'cpu') 19 | self.image_size = cfgs.get('image_size', 64) 20 | self.min_depth = cfgs.get('min_depth', 0.9) 21 | self.max_depth = cfgs.get('max_depth', 1.1) 22 | self.border_depth = cfgs.get('border_depth', (0.7*self.max_depth + 0.3*self.min_depth)) 23 | self.min_amb_light = cfgs.get('min_amb_light', 0.) 24 | self.max_amb_light = cfgs.get('max_amb_light', 1.) 25 | self.min_diff_light = cfgs.get('min_diff_light', 0.) 26 | self.max_diff_light = cfgs.get('max_diff_light', 1.) 
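# The depth and lighting ranges above are consumed by the *_rescaler lambdas defined below, which
# map a network output x in [-1,1] to [min, max] via x_rescaled = (1+x)/2*max + (1-x)/2*min,
# so x=-1 gives min, x=+1 gives max and x=0 gives the midpoint. As a worked example, with
# min_depth=0.9 and max_depth=1.1, a tanh output of 0.5 maps to 0.75*1.1 + 0.25*0.9 = 1.05.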
27 | self.xyz_rotation_range = cfgs.get('xyz_rotation_range', 60) 28 | self.xy_translation_range = cfgs.get('xy_translation_range', 0.1) 29 | self.z_translation_range = cfgs.get('z_translation_range', 0.1) 30 | self.use_conf_map = cfgs.get('use_conf_map', True) 31 | self.lam_perc = cfgs.get('lam_perc', 1) 32 | self.lam_flip = cfgs.get('lam_flip', 0.5) 33 | self.lam_flip_start_epoch = cfgs.get('lam_flip_start_epoch', 0) 34 | self.lam_depth_sm = cfgs.get('lam_depth_sm', 0) 35 | self.lr = cfgs.get('lr', 1e-4) 36 | self.load_gt_depth = cfgs.get('load_gt_depth', False) 37 | self.renderer = Renderer(cfgs) 38 | 39 | ## networks and optimizers 40 | self.netD = networks.EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None) 41 | self.netA = networks.EDDeconv(cin=3, cout=3, nf=64, zdim=256) 42 | self.netL = networks.Encoder(cin=3, cout=4, nf=32) 43 | self.netV = networks.Encoder(cin=3, cout=6, nf=32) 44 | if self.use_conf_map: 45 | self.netC = networks.ConfNet(cin=3, cout=2, nf=64, zdim=128) 46 | self.network_names = [k for k in vars(self) if 'net' in k] 47 | self.make_optimizer = lambda model: torch.optim.Adam( 48 | filter(lambda p: p.requires_grad, model.parameters()), 49 | lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4) 50 | 51 | ## other parameters 52 | self.PerceptualLoss = networks.PerceptualLoss(requires_grad=False) 53 | self.other_param_names = ['PerceptualLoss'] 54 | 55 | ## depth rescaler: -1~1 -> min_deph~max_deph 56 | self.depth_rescaler = lambda d : (1+d)/2 *self.max_depth + (1-d)/2 *self.min_depth 57 | self.amb_light_rescaler = lambda x : (1+x)/2 *self.max_amb_light + (1-x)/2 *self.min_amb_light 58 | self.diff_light_rescaler = lambda x : (1+x)/2 *self.max_diff_light + (1-x)/2 *self.min_diff_light 59 | 60 | def init_optimizers(self): 61 | self.optimizer_names = [] 62 | for net_name in self.network_names: 63 | optimizer = self.make_optimizer(getattr(self, net_name)) 64 | optim_name = net_name.replace('net','optimizer') 65 | setattr(self, optim_name, optimizer) 66 | self.optimizer_names += [optim_name] 67 | 68 | def load_model_state(self, cp): 69 | for k in cp: 70 | if k and k in self.network_names: 71 | getattr(self, k).load_state_dict(cp[k]) 72 | 73 | def load_optimizer_state(self, cp): 74 | for k in cp: 75 | if k and k in self.optimizer_names: 76 | getattr(self, k).load_state_dict(cp[k]) 77 | 78 | def get_model_state(self): 79 | states = {} 80 | for net_name in self.network_names: 81 | states[net_name] = getattr(self, net_name).state_dict() 82 | return states 83 | 84 | def get_optimizer_state(self): 85 | states = {} 86 | for optim_name in self.optimizer_names: 87 | states[optim_name] = getattr(self, optim_name).state_dict() 88 | return states 89 | 90 | def to_device(self, device): 91 | self.device = device 92 | for net_name in self.network_names: 93 | setattr(self, net_name, getattr(self, net_name).to(device)) 94 | if self.other_param_names: 95 | for param_name in self.other_param_names: 96 | setattr(self, param_name, getattr(self, param_name).to(device)) 97 | 98 | def set_train(self): 99 | for net_name in self.network_names: 100 | getattr(self, net_name).train() 101 | 102 | def set_eval(self): 103 | for net_name in self.network_names: 104 | getattr(self, net_name).eval() 105 | 106 | def photometric_loss(self, im1, im2, mask=None, conf_sigma=None): 107 | loss = (im1-im2).abs() 108 | if conf_sigma is not None: 109 | loss = loss *2**0.5 / (conf_sigma +EPS) + (conf_sigma +EPS).log() 110 | if mask is not None: 111 | mask = mask.expand_as(loss) 112 | loss = (loss * mask).sum() / 
mask.sum() 113 | else: 114 | loss = loss.mean() 115 | return loss 116 | 117 | def backward(self): 118 | for optim_name in self.optimizer_names: 119 | getattr(self, optim_name).zero_grad() 120 | self.loss_total.backward() 121 | for optim_name in self.optimizer_names: 122 | getattr(self, optim_name).step() 123 | 124 | def forward(self, input): 125 | """Feedforward once.""" 126 | if self.load_gt_depth: 127 | input, depth_gt = input 128 | self.input_im = input.to(self.device) *2.-1. 129 | b, c, h, w = self.input_im.shape 130 | 131 | ## predict canonical depth 132 | self.canon_depth_raw = self.netD(self.input_im).squeeze(1) # BxHxW 133 | self.canon_depth = self.canon_depth_raw - self.canon_depth_raw.view(b,-1).mean(1).view(b,1,1) 134 | self.canon_depth = self.canon_depth.tanh() 135 | self.canon_depth = self.depth_rescaler(self.canon_depth) 136 | 137 | ## optional depth smoothness loss (only used in synthetic car experiments) 138 | self.loss_depth_sm = ((self.canon_depth[:,:-1,:] - self.canon_depth[:,1:,:]) /(self.max_depth-self.min_depth)).abs().mean() 139 | self.loss_depth_sm += ((self.canon_depth[:,:,:-1] - self.canon_depth[:,:,1:]) /(self.max_depth-self.min_depth)).abs().mean() 140 | 141 | ## clamp border depth 142 | depth_border = torch.zeros(1,h,w-4).to(self.input_im.device) 143 | depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1) 144 | self.canon_depth = self.canon_depth*(1-depth_border) + depth_border *self.border_depth 145 | self.canon_depth = torch.cat([self.canon_depth, self.canon_depth.flip(2)], 0) # flip 146 | 147 | ## predict canonical albedo 148 | self.canon_albedo = self.netA(self.input_im) # Bx3xHxW 149 | self.canon_albedo = torch.cat([self.canon_albedo, self.canon_albedo.flip(3)], 0) # flip 150 | 151 | ## predict confidence map 152 | if self.use_conf_map: 153 | conf_sigma_l1, conf_sigma_percl = self.netC(self.input_im) # Bx2xHxW 154 | self.conf_sigma_l1 = conf_sigma_l1[:,:1] 155 | self.conf_sigma_l1_flip = conf_sigma_l1[:,1:] 156 | self.conf_sigma_percl = conf_sigma_percl[:,:1] 157 | self.conf_sigma_percl_flip = conf_sigma_percl[:,1:] 158 | else: 159 | self.conf_sigma_l1 = None 160 | self.conf_sigma_l1_flip = None 161 | self.conf_sigma_percl = None 162 | self.conf_sigma_percl_flip = None 163 | 164 | ## predict lighting 165 | canon_light = self.netL(self.input_im).repeat(2,1) # Bx4 166 | self.canon_light_a = self.amb_light_rescaler(canon_light[:,:1]) # ambience term 167 | self.canon_light_b = self.diff_light_rescaler(canon_light[:,1:2]) # diffuse term 168 | canon_light_dxy = canon_light[:,2:] 169 | self.canon_light_d = torch.cat([canon_light_dxy, torch.ones(b*2,1).to(self.input_im.device)], 1) 170 | self.canon_light_d = self.canon_light_d / ((self.canon_light_d**2).sum(1, keepdim=True))**0.5 # diffuse light direction 171 | 172 | ## shading 173 | self.canon_normal = self.renderer.get_normal_from_depth(self.canon_depth) 174 | self.canon_diffuse_shading = (self.canon_normal * self.canon_light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1) 175 | canon_shading = self.canon_light_a.view(-1,1,1,1) + self.canon_light_b.view(-1,1,1,1)*self.canon_diffuse_shading 176 | self.canon_im = (self.canon_albedo/2+0.5) * canon_shading *2-1 177 | 178 | ## predict viewpoint transformation 179 | self.view = self.netV(self.input_im).repeat(2,1) 180 | self.view = torch.cat([ 181 | self.view[:,:3] *math.pi/180 *self.xyz_rotation_range, 182 | self.view[:,3:5] *self.xy_translation_range, 183 | self.view[:,5:] *self.z_translation_range], 1) 184 | 185 | ## reconstruct input 
view 186 | self.renderer.set_transform_matrices(self.view) 187 | self.recon_depth = self.renderer.warp_canon_depth(self.canon_depth) 188 | self.recon_normal = self.renderer.get_normal_from_depth(self.recon_depth) 189 | grid_2d_from_canon = self.renderer.get_inv_warped_2d_grid(self.recon_depth) 190 | self.recon_im = nn.functional.grid_sample(self.canon_im, grid_2d_from_canon, mode='bilinear') 191 | 192 | margin = (self.max_depth - self.min_depth) /2 193 | recon_im_mask = (self.recon_depth < self.max_depth+margin).float() # invalid border pixels have been clamped at max_depth+margin 194 | recon_im_mask_both = recon_im_mask[:b] * recon_im_mask[b:] # both original and flip reconstruction 195 | recon_im_mask_both = recon_im_mask_both.repeat(2,1,1).unsqueeze(1).detach() 196 | self.recon_im = self.recon_im * recon_im_mask_both 197 | 198 | ## render symmetry axis 199 | canon_sym_axis = torch.zeros(h, w).to(self.input_im.device) 200 | canon_sym_axis[:, w//2-1:w//2+1] = 1 201 | self.recon_sym_axis = nn.functional.grid_sample(canon_sym_axis.repeat(b*2,1,1,1), grid_2d_from_canon, mode='bilinear') 202 | self.recon_sym_axis = self.recon_sym_axis * recon_im_mask_both 203 | green = torch.FloatTensor([-1,1,-1]).to(self.input_im.device).view(1,3,1,1) 204 | self.input_im_symline = (0.5*self.recon_sym_axis) *green + (1-0.5*self.recon_sym_axis) *self.input_im.repeat(2,1,1,1) 205 | 206 | ## loss function 207 | self.loss_l1_im = self.photometric_loss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_l1) 208 | self.loss_l1_im_flip = self.photometric_loss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_l1_flip) 209 | self.loss_perc_im = self.PerceptualLoss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_percl) 210 | self.loss_perc_im_flip = self.PerceptualLoss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_percl_flip) 211 | lam_flip = 1 if self.trainer.current_epoch < self.lam_flip_start_epoch else self.lam_flip 212 | self.loss_total = self.loss_l1_im + lam_flip*self.loss_l1_im_flip + self.lam_perc*(self.loss_perc_im + lam_flip*self.loss_perc_im_flip) + self.lam_depth_sm*self.loss_depth_sm 213 | 214 | metrics = {'loss': self.loss_total} 215 | 216 | ## compute accuracy if gt depth is available 217 | if self.load_gt_depth: 218 | self.depth_gt = depth_gt[:,0,:,:].to(self.input_im.device) 219 | self.depth_gt = (1-self.depth_gt)*2-1 220 | self.depth_gt = self.depth_rescaler(self.depth_gt) 221 | self.normal_gt = self.renderer.get_normal_from_depth(self.depth_gt) 222 | 223 | # mask out background 224 | mask_gt = (self.depth_gt < self.depth_gt.max()).float() 225 | mask_gt = (nn.functional.avg_pool2d(mask_gt.unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float() # erode by 1 pixel 226 | mask_pred = (nn.functional.avg_pool2d(recon_im_mask[:b].unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float() # erode by 1 pixel 227 | mask = mask_gt * mask_pred 228 | self.acc_mae_masked = ((self.recon_depth[:b] - self.depth_gt[:b]).abs() *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1) 229 | self.acc_mse_masked = (((self.recon_depth[:b] - self.depth_gt[:b])**2) *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1) 230 | self.sie_map_masked = utils.compute_sc_inv_err(self.recon_depth[:b].log(), self.depth_gt[:b].log(), mask=mask) 231 | self.acc_sie_masked = (self.sie_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1))**0.5 232 | self.norm_err_map_masked = utils.compute_angular_distance(self.recon_normal[:b], self.normal_gt[:b], mask=mask) 233 | self.acc_normal_masked =
self.norm_err_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1) 234 | 235 | metrics['SIE_masked'] = self.acc_sie_masked.mean() 236 | metrics['NorErr_masked'] = self.acc_normal_masked.mean() 237 | 238 | return metrics 239 | 240 | def visualize(self, logger, total_iter, max_bs=25): 241 | b, c, h, w = self.input_im.shape 242 | b0 = min(max_bs, b) 243 | 244 | ## render rotations 245 | with torch.no_grad(): 246 | v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b0,1) 247 | canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b0], self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5 # (B,T,C,H,W) 248 | canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b0].permute(0,3,1,2), self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5 # (B,T,C,H,W) 249 | 250 | input_im = self.input_im[:b0].detach().cpu().numpy() /2+0.5 251 | input_im_symline = self.input_im_symline[:b0].detach().cpu() /2.+0.5 252 | canon_albedo = self.canon_albedo[:b0].detach().cpu() /2.+0.5 253 | canon_im = self.canon_im[:b0].detach().cpu() /2.+0.5 254 | recon_im = self.recon_im[:b0].detach().cpu() /2.+0.5 255 | recon_im_flip = self.recon_im[b:b+b0].detach().cpu() /2.+0.5 256 | canon_depth_raw_hist = self.canon_depth_raw.detach().unsqueeze(1).cpu() 257 | canon_depth_raw = self.canon_depth_raw[:b0].detach().unsqueeze(1).cpu() /2.+0.5 258 | canon_depth = ((self.canon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1) 259 | recon_depth = ((self.recon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1) 260 | canon_diffuse_shading = self.canon_diffuse_shading[:b0].detach().cpu() 261 | canon_normal = self.canon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5 262 | recon_normal = self.recon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5 263 | if self.use_conf_map: 264 | conf_map_l1 = 1/(1+self.conf_sigma_l1[:b0].detach().cpu()+EPS) 265 | conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b0].detach().cpu()+EPS) 266 | conf_map_percl = 1/(1+self.conf_sigma_percl[:b0].detach().cpu()+EPS) 267 | conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b0].detach().cpu()+EPS) 268 | 269 | canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_im_rotate, 1)] # [(C,H,W)]*T 270 | canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0) # (1,T,C,H,W) 271 | canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_normal_rotate, 1)] # [(C,H,W)]*T 272 | canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0) # (1,T,C,H,W) 273 | 274 | ## write summary 275 | logger.add_scalar('Loss/loss_total', self.loss_total, total_iter) 276 | logger.add_scalar('Loss/loss_l1_im', self.loss_l1_im, total_iter) 277 | logger.add_scalar('Loss/loss_l1_im_flip', self.loss_l1_im_flip, total_iter) 278 | logger.add_scalar('Loss/loss_perc_im', self.loss_perc_im, total_iter) 279 | logger.add_scalar('Loss/loss_perc_im_flip', self.loss_perc_im_flip, total_iter) 280 | logger.add_scalar('Loss/loss_depth_sm', self.loss_depth_sm, total_iter) 281 | 282 | logger.add_histogram('Depth/canon_depth_raw_hist', canon_depth_raw_hist, total_iter) 283 | vlist = ['view_rx', 'view_ry', 'view_rz', 'view_tx', 'view_ty', 'view_tz'] 284 | for i in range(self.view.shape[1]): 285 | logger.add_histogram('View/'+vlist[i], self.view[:,i], total_iter) 286 | 
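# self.view was already scaled in forward(): columns 0-2 are rotations in radians (bounded by
# +/- xyz_rotation_range degrees), columns 3-4 are x/y translations and column 5 is the z
# translation. A small optional sketch (the *_deg tags below are hypothetical, not used elsewhere)
# for logging the rotations in degrees instead of radians:
#
#   view_deg = self.view[:, :3] * 180 / math.pi
#   for i, name in enumerate(['view_rx_deg', 'view_ry_deg', 'view_rz_deg']):
#       logger.add_histogram('View/' + name, view_deg[:, i], total_iter)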
logger.add_histogram('Light/canon_light_a', self.canon_light_a, total_iter) 287 | logger.add_histogram('Light/canon_light_b', self.canon_light_b, total_iter) 288 | llist = ['canon_light_dx', 'canon_light_dy', 'canon_light_dz'] 289 | for i in range(self.canon_light_d.shape[1]): 290 | logger.add_histogram('Light/'+llist[i], self.canon_light_d[:,i], total_iter) 291 | 292 | def log_grid_image(label, im, nrow=int(math.ceil(b0**0.5)), iter=total_iter): 293 | im_grid = torchvision.utils.make_grid(im, nrow=nrow) 294 | logger.add_image(label, im_grid, iter) 295 | 296 | log_grid_image('Image/input_image_symline', input_im_symline) 297 | log_grid_image('Image/canonical_albedo', canon_albedo) 298 | log_grid_image('Image/canonical_image', canon_im) 299 | log_grid_image('Image/recon_image', recon_im) 300 | log_grid_image('Image/recon_image_flip', recon_im_flip) 301 | log_grid_image('Image/recon_side', canon_im_rotate[:,0,:,:,:]) 302 | 303 | log_grid_image('Depth/canonical_depth_raw', canon_depth_raw) 304 | log_grid_image('Depth/canonical_depth', canon_depth) 305 | log_grid_image('Depth/recon_depth', recon_depth) 306 | log_grid_image('Depth/canonical_diffuse_shading', canon_diffuse_shading) 307 | log_grid_image('Depth/canonical_normal', canon_normal) 308 | log_grid_image('Depth/recon_normal', recon_normal) 309 | 310 | logger.add_histogram('Image/canonical_albedo_hist', canon_albedo, total_iter) 311 | logger.add_histogram('Image/canonical_diffuse_shading_hist', canon_diffuse_shading, total_iter) 312 | 313 | if self.use_conf_map: 314 | log_grid_image('Conf/conf_map_l1', conf_map_l1) 315 | logger.add_histogram('Conf/conf_sigma_l1_hist', self.conf_sigma_l1, total_iter) 316 | log_grid_image('Conf/conf_map_l1_flip', conf_map_l1_flip) 317 | logger.add_histogram('Conf/conf_sigma_l1_flip_hist', self.conf_sigma_l1_flip, total_iter) 318 | log_grid_image('Conf/conf_map_percl', conf_map_percl) 319 | logger.add_histogram('Conf/conf_sigma_percl_hist', self.conf_sigma_percl, total_iter) 320 | log_grid_image('Conf/conf_map_percl_flip', conf_map_percl_flip) 321 | logger.add_histogram('Conf/conf_sigma_percl_flip_hist', self.conf_sigma_percl_flip, total_iter) 322 | 323 | logger.add_video('Image_rotate/recon_rotate', canon_im_rotate_grid, total_iter, fps=4) 324 | logger.add_video('Image_rotate/canon_normal_rotate', canon_normal_rotate_grid, total_iter, fps=4) 325 | 326 | # visualize images and accuracy if gt is loaded 327 | if self.load_gt_depth: 328 | depth_gt = ((self.depth_gt[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1) 329 | normal_gt = self.normal_gt.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5 330 | sie_map_masked = self.sie_map_masked[:b0].detach().unsqueeze(1).cpu() *1000 331 | norm_err_map_masked = self.norm_err_map_masked[:b0].detach().unsqueeze(1).cpu() /100 332 | 333 | logger.add_scalar('Acc_masked/MAE_masked', self.acc_mae_masked.mean(), total_iter) 334 | logger.add_scalar('Acc_masked/MSE_masked', self.acc_mse_masked.mean(), total_iter) 335 | logger.add_scalar('Acc_masked/SIE_masked', self.acc_sie_masked.mean(), total_iter) 336 | logger.add_scalar('Acc_masked/NorErr_masked', self.acc_normal_masked.mean(), total_iter) 337 | 338 | log_grid_image('Depth_gt/depth_gt', depth_gt) 339 | log_grid_image('Depth_gt/normal_gt', normal_gt) 340 | log_grid_image('Depth_gt/sie_map_masked', sie_map_masked) 341 | log_grid_image('Depth_gt/norm_err_map_masked', norm_err_map_masked) 342 | 343 | def save_results(self, save_dir): 344 | b, c, h, w = self.input_im.shape 345 | 346 | with 
torch.no_grad(): 347 | v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b,1) 348 | canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b], self.canon_depth[:b], v_before=v0, maxr=90, nsample=15) # (B,T,C,H,W) 349 | canon_im_rotate = canon_im_rotate.clamp(-1,1).detach().cpu() /2+0.5 350 | canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b].permute(0,3,1,2), self.canon_depth[:b], v_before=v0, maxr=90, nsample=15) # (B,T,C,H,W) 351 | canon_normal_rotate = canon_normal_rotate.clamp(-1,1).detach().cpu() /2+0.5 352 | 353 | input_im = self.input_im[:b].detach().cpu().numpy() /2+0.5 354 | input_im_symline = self.input_im_symline.detach().cpu().numpy() /2.+0.5 355 | canon_albedo = self.canon_albedo[:b].detach().cpu().numpy() /2+0.5 356 | canon_im = self.canon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5 357 | recon_im = self.recon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5 358 | recon_im_flip = self.recon_im[b:].clamp(-1,1).detach().cpu().numpy() /2+0.5 359 | canon_depth = ((self.canon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy() 360 | recon_depth = ((self.recon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy() 361 | canon_diffuse_shading = self.canon_diffuse_shading[:b].detach().cpu().numpy() 362 | canon_normal = self.canon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5 363 | recon_normal = self.recon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5 364 | if self.use_conf_map: 365 | conf_map_l1 = 1/(1+self.conf_sigma_l1[:b].detach().cpu().numpy()+EPS) 366 | conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b].detach().cpu().numpy()+EPS) 367 | conf_map_percl = 1/(1+self.conf_sigma_percl[:b].detach().cpu().numpy()+EPS) 368 | conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b].detach().cpu().numpy()+EPS) 369 | canon_light = torch.cat([self.canon_light_a, self.canon_light_b, self.canon_light_d], 1)[:b].detach().cpu().numpy() 370 | view = self.view[:b].detach().cpu().numpy() 371 | 372 | canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_im_rotate,1)] # [(C,H,W)]*T 373 | canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0).numpy() # (1,T,C,H,W) 374 | canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_normal_rotate,1)] # [(C,H,W)]*T 375 | canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0).numpy() # (1,T,C,H,W) 376 | 377 | sep_folder = True 378 | utils.save_images(save_dir, input_im, suffix='input_image', sep_folder=sep_folder) 379 | utils.save_images(save_dir, input_im_symline, suffix='input_image_symline', sep_folder=sep_folder) 380 | utils.save_images(save_dir, canon_albedo, suffix='canonical_albedo', sep_folder=sep_folder) 381 | utils.save_images(save_dir, canon_im, suffix='canonical_image', sep_folder=sep_folder) 382 | utils.save_images(save_dir, recon_im, suffix='recon_image', sep_folder=sep_folder) 383 | utils.save_images(save_dir, recon_im_flip, suffix='recon_image_flip', sep_folder=sep_folder) 384 | utils.save_images(save_dir, canon_depth, suffix='canonical_depth', sep_folder=sep_folder) 385 | utils.save_images(save_dir, recon_depth, suffix='recon_depth', sep_folder=sep_folder) 386 | utils.save_images(save_dir, canon_diffuse_shading, suffix='canonical_diffuse_shading', sep_folder=sep_folder) 387 | 
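# The depth maps saved here are normalized to [0,1] as (d - min_depth) / (max_depth - min_depth).
# A minimal sketch of recovering metric depth from a saved map, assuming (as in demo/utils.py's
# save_image) that depth images are written as 16-bit PNGs; the file name and the 0.9/1.1 range
# below come from the default config values and are illustrative only:
#
#   import cv2
#   d16 = cv2.imread('canonical_depth.png', cv2.IMREAD_UNCHANGED)   # uint16 in [0, 65535]
#   d = d16.astype('float32') / 65535. * (1.1 - 0.9) + 0.9          # back to [min_depth, max_depth]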
utils.save_images(save_dir, canon_normal, suffix='canonical_normal', sep_folder=sep_folder) 388 | utils.save_images(save_dir, recon_normal, suffix='recon_normal', sep_folder=sep_folder) 389 | if self.use_conf_map: 390 | utils.save_images(save_dir, conf_map_l1, suffix='conf_map_l1', sep_folder=sep_folder) 391 | utils.save_images(save_dir, conf_map_l1_flip, suffix='conf_map_l1_flip', sep_folder=sep_folder) 392 | utils.save_images(save_dir, conf_map_percl, suffix='conf_map_percl', sep_folder=sep_folder) 393 | utils.save_images(save_dir, conf_map_percl_flip, suffix='conf_map_percl_flip', sep_folder=sep_folder) 394 | utils.save_txt(save_dir, canon_light, suffix='canonical_light', sep_folder=sep_folder) 395 | utils.save_txt(save_dir, view, suffix='viewpoint', sep_folder=sep_folder) 396 | 397 | utils.save_videos(save_dir, canon_im_rotate_grid, suffix='image_video', sep_folder=sep_folder, cycle=True) 398 | utils.save_videos(save_dir, canon_normal_rotate_grid, suffix='normal_video', sep_folder=sep_folder, cycle=True) 399 | 400 | # save scores if gt is loaded 401 | if self.load_gt_depth: 402 | depth_gt = ((self.depth_gt[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy() 403 | normal_gt = self.normal_gt[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5 404 | utils.save_images(save_dir, depth_gt, suffix='depth_gt', sep_folder=sep_folder) 405 | utils.save_images(save_dir, normal_gt, suffix='normal_gt', sep_folder=sep_folder) 406 | 407 | all_scores = torch.stack([ 408 | self.acc_mae_masked.detach().cpu(), 409 | self.acc_mse_masked.detach().cpu(), 410 | self.acc_sie_masked.detach().cpu(), 411 | self.acc_normal_masked.detach().cpu()], 1) 412 | if not hasattr(self, 'all_scores'): 413 | self.all_scores = torch.FloatTensor() 414 | self.all_scores = torch.cat([self.all_scores, all_scores], 0) 415 | 416 | def save_scores(self, path): 417 | # save scores if gt is loaded 418 | if self.load_gt_depth: 419 | header = 'MAE_masked, \ 420 | MSE_masked, \ 421 | SIE_masked, \ 422 | NorErr_masked' 423 | mean = self.all_scores.mean(0) 424 | std = self.all_scores.std(0) 425 | header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean]) 426 | header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std]) 427 | utils.save_scores(path, self.all_scores, header=header) 428 | -------------------------------------------------------------------------------- /unsup3d/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | EPS = 1e-7 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, cin, cout, nf=64, activation=nn.Tanh): 11 | super(Encoder, self).__init__() 12 | network = [ 13 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(nf*8, cout, kernel_size=1, stride=1, padding=0, bias=False)] 24 | if activation is not None: 25 | network += [activation()] 26 | self.network = nn.Sequential(*network) 27 | 28 | def 
forward(self, input): 29 | return self.network(input).reshape(input.size(0),-1) 30 | 31 | 32 | class EDDeconv(nn.Module): 33 | def __init__(self, cin, cout, zdim=128, nf=64, activation=nn.Tanh): 34 | super(EDDeconv, self).__init__() 35 | ## downsampling 36 | network = [ 37 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 38 | nn.GroupNorm(16, nf), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 41 | nn.GroupNorm(16*2, nf*2), 42 | nn.LeakyReLU(0.2, inplace=True), 43 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 44 | nn.GroupNorm(16*4, nf*4), 45 | nn.LeakyReLU(0.2, inplace=True), 46 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 49 | nn.ReLU(inplace=True)] 50 | ## upsampling 51 | network += [ 52 | nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 1x1 -> 4x4 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=False), 55 | nn.ReLU(inplace=True), 56 | nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8 57 | nn.GroupNorm(16*4, nf*4), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=False), 60 | nn.GroupNorm(16*4, nf*4), 61 | nn.ReLU(inplace=True), 62 | nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16 63 | nn.GroupNorm(16*2, nf*2), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=False), 66 | nn.GroupNorm(16*2, nf*2), 67 | nn.ReLU(inplace=True), 68 | nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32 69 | nn.GroupNorm(16, nf), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False), 72 | nn.GroupNorm(16, nf), 73 | nn.ReLU(inplace=True), 74 | nn.Upsample(scale_factor=2, mode='nearest'), # 32x32 -> 64x64 75 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False), 76 | nn.GroupNorm(16, nf), 77 | nn.ReLU(inplace=True), 78 | nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=False), 79 | nn.GroupNorm(16, nf), 80 | nn.ReLU(inplace=True), 81 | nn.Conv2d(nf, cout, kernel_size=5, stride=1, padding=2, bias=False)] 82 | if activation is not None: 83 | network += [activation()] 84 | self.network = nn.Sequential(*network) 85 | 86 | def forward(self, input): 87 | return self.network(input) 88 | 89 | 90 | class ConfNet(nn.Module): 91 | def __init__(self, cin, cout, zdim=128, nf=64): 92 | super(ConfNet, self).__init__() 93 | ## downsampling 94 | network = [ 95 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 96 | nn.GroupNorm(16, nf), 97 | nn.LeakyReLU(0.2, inplace=True), 98 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 99 | nn.GroupNorm(16*2, nf*2), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 102 | nn.GroupNorm(16*4, nf*4), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 105 | nn.LeakyReLU(0.2, inplace=True), 106 | nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 
-> 1x1 107 | nn.ReLU(inplace=True)] 108 | ## upsampling 109 | network += [ 110 | nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, padding=0, bias=False), # 1x1 -> 4x4 111 | nn.ReLU(inplace=True), 112 | nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8 113 | nn.GroupNorm(16*4, nf*4), 114 | nn.ReLU(inplace=True), 115 | nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16 116 | nn.GroupNorm(16*2, nf*2), 117 | nn.ReLU(inplace=True)] 118 | self.network = nn.Sequential(*network) 119 | 120 | out_net1 = [ 121 | nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32 122 | nn.GroupNorm(16, nf), 123 | nn.ReLU(inplace=True), 124 | nn.ConvTranspose2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 64x64 125 | nn.GroupNorm(16, nf), 126 | nn.ReLU(inplace=True), 127 | nn.Conv2d(nf, 2, kernel_size=5, stride=1, padding=2, bias=False), # 64x64 128 | nn.Softplus()] 129 | self.out_net1 = nn.Sequential(*out_net1) 130 | 131 | out_net2 = [nn.Conv2d(nf*2, 2, kernel_size=3, stride=1, padding=1, bias=False), # 16x16 132 | nn.Softplus()] 133 | self.out_net2 = nn.Sequential(*out_net2) 134 | 135 | def forward(self, input): 136 | out = self.network(input) 137 | return self.out_net1(out), self.out_net2(out) 138 | 139 | 140 | class PerceptualLoss(nn.Module): 141 | def __init__(self, requires_grad=False): 142 | super(PerceptualLoss, self).__init__() 143 | mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406]) 144 | std_rgb = torch.FloatTensor([0.229, 0.224, 0.225]) 145 | self.register_buffer('mean_rgb', mean_rgb) 146 | self.register_buffer('std_rgb', std_rgb) 147 | 148 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features 149 | self.slice1 = nn.Sequential() 150 | self.slice2 = nn.Sequential() 151 | self.slice3 = nn.Sequential() 152 | self.slice4 = nn.Sequential() 153 | for x in range(4): 154 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 155 | for x in range(4, 9): 156 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 157 | for x in range(9, 16): 158 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 159 | for x in range(16, 23): 160 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 161 | if not requires_grad: 162 | for param in self.parameters(): 163 | param.requires_grad = False 164 | 165 | def normalize(self, x): 166 | out = x/2 + 0.5 167 | out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1) 168 | return out 169 | 170 | def __call__(self, im1, im2, mask=None, conf_sigma=None): 171 | im = torch.cat([im1,im2], 0) 172 | im = self.normalize(im) # normalize input 173 | 174 | ## compute features 175 | feats = [] 176 | f = self.slice1(im) 177 | feats += [torch.chunk(f, 2, dim=0)] 178 | f = self.slice2(f) 179 | feats += [torch.chunk(f, 2, dim=0)] 180 | f = self.slice3(f) 181 | feats += [torch.chunk(f, 2, dim=0)] 182 | f = self.slice4(f) 183 | feats += [torch.chunk(f, 2, dim=0)] 184 | 185 | losses = [] 186 | for f1, f2 in feats[2:3]: # use relu3_3 features only 187 | loss = (f1-f2)**2 188 | if conf_sigma is not None: 189 | loss = loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log() 190 | if mask is not None: 191 | b, c, h, w = loss.shape 192 | _, _, hm, wm = mask.shape 193 | sh, sw = hm//h, wm//w 194 | mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss) 195 | loss = (loss * mask0).sum() / mask0.sum() 196 | else: 197 | loss = loss.mean() 198 | losses += [loss] 
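# note: feats[2:3] above keeps only the relu3_3 features, so `losses` holds a single term and the sum below returns that one scalar loss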
199 | return sum(losses) 200 | -------------------------------------------------------------------------------- /unsup3d/renderer/__init__.py: -------------------------------------------------------------------------------- 1 | from .renderer import Renderer 2 | -------------------------------------------------------------------------------- /unsup3d/renderer/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import neural_renderer as nr 4 | from .utils import * 5 | 6 | 7 | EPS = 1e-7 8 | 9 | 10 | class Renderer(): 11 | def __init__(self, cfgs): 12 | self.device = cfgs.get('device', 'cpu') 13 | self.image_size = cfgs.get('image_size', 64) 14 | self.min_depth = cfgs.get('min_depth', 0.9) 15 | self.max_depth = cfgs.get('max_depth', 1.1) 16 | self.rot_center_depth = cfgs.get('rot_center_depth', (self.min_depth+self.max_depth)/2) 17 | self.fov = cfgs.get('fov', 10) 18 | self.tex_cube_size = cfgs.get('tex_cube_size', 2) 19 | self.renderer_min_depth = cfgs.get('renderer_min_depth', 0.1) 20 | self.renderer_max_depth = cfgs.get('renderer_max_depth', 10.) 21 | 22 | #### camera intrinsics 23 | # (u) (x) 24 | # d * K^-1 (v) = (y) 25 | # (1) (z) 26 | 27 | ## renderer for visualization 28 | R = [[[1.,0.,0.], 29 | [0.,1.,0.], 30 | [0.,0.,1.]]] 31 | R = torch.FloatTensor(R).to(self.device) 32 | t = torch.zeros(1,3, dtype=torch.float32).to(self.device) 33 | fx = (self.image_size-1)/2/(math.tan(self.fov/2 *math.pi/180)) 34 | fy = (self.image_size-1)/2/(math.tan(self.fov/2 *math.pi/180)) 35 | cx = (self.image_size-1)/2 36 | cy = (self.image_size-1)/2 37 | K = [[fx, 0., cx], 38 | [0., fy, cy], 39 | [0., 0., 1.]] 40 | K = torch.FloatTensor(K).to(self.device) 41 | self.inv_K = torch.inverse(K).unsqueeze(0) 42 | self.K = K.unsqueeze(0) 43 | self.renderer = nr.Renderer(camera_mode='projection', 44 | light_intensity_ambient=1.0, 45 | light_intensity_directional=0., 46 | K=self.K, R=R, t=t, 47 | near=self.renderer_min_depth, far=self.renderer_max_depth, 48 | image_size=self.image_size, orig_size=self.image_size, 49 | fill_back=True, 50 | background_color=[1,1,1]) 51 | 52 | def set_transform_matrices(self, view): 53 | self.rot_mat, self.trans_xyz = get_transform_matrices(view) 54 | 55 | def rotate_pts(self, pts, rot_mat): 56 | centroid = torch.FloatTensor([0.,0.,self.rot_center_depth]).to(pts.device).view(1,1,3) 57 | pts = pts - centroid # move to centroid 58 | pts = pts.matmul(rot_mat.transpose(2,1)) # rotate 59 | pts = pts + centroid # move back 60 | return pts 61 | 62 | def translate_pts(self, pts, trans_xyz): 63 | return pts + trans_xyz 64 | 65 | def depth_to_3d_grid(self, depth): 66 | b, h, w = depth.shape 67 | grid_2d = get_grid(b, h, w, normalize=False).to(depth.device) # Nxhxwx2 68 | depth = depth.unsqueeze(-1) 69 | grid_3d = torch.cat((grid_2d, torch.ones_like(depth)), dim=3) 70 | grid_3d = grid_3d.matmul(self.inv_K.to(depth.device).transpose(2,1)) * depth 71 | return grid_3d 72 | 73 | def grid_3d_to_2d(self, grid_3d): 74 | b, h, w, _ = grid_3d.shape 75 | grid_2d = grid_3d / grid_3d[...,2:] 76 | grid_2d = grid_2d.matmul(self.K.to(grid_3d.device).transpose(2,1))[:,:,:,:2] 77 | WH = torch.FloatTensor([w-1, h-1]).to(grid_3d.device).view(1,1,1,2) 78 | grid_2d = grid_2d / WH *2.-1. 
# normalize to -1~1 79 | return grid_2d 80 | 81 | def get_warped_3d_grid(self, depth): 82 | b, h, w = depth.shape 83 | grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3) 84 | grid_3d = self.rotate_pts(grid_3d, self.rot_mat) 85 | grid_3d = self.translate_pts(grid_3d, self.trans_xyz) 86 | return grid_3d.reshape(b,h,w,3) # return 3d vertices 87 | 88 | def get_inv_warped_3d_grid(self, depth): 89 | b, h, w = depth.shape 90 | grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3) 91 | grid_3d = self.translate_pts(grid_3d, -self.trans_xyz) 92 | grid_3d = self.rotate_pts(grid_3d, self.rot_mat.transpose(2,1)) 93 | return grid_3d.reshape(b,h,w,3) # return 3d vertices 94 | 95 | def get_warped_2d_grid(self, depth): 96 | b, h, w = depth.shape 97 | grid_3d = self.get_warped_3d_grid(depth) 98 | grid_2d = self.grid_3d_to_2d(grid_3d) 99 | return grid_2d 100 | 101 | def get_inv_warped_2d_grid(self, depth): 102 | b, h, w = depth.shape 103 | grid_3d = self.get_inv_warped_3d_grid(depth) 104 | grid_2d = self.grid_3d_to_2d(grid_3d) 105 | return grid_2d 106 | 107 | def warp_canon_depth(self, canon_depth): 108 | b, h, w = canon_depth.shape 109 | grid_3d = self.get_warped_3d_grid(canon_depth).reshape(b,-1,3) 110 | faces = get_face_idx(b, h, w).to(canon_depth.device) 111 | warped_depth = self.renderer.render_depth(grid_3d, faces) 112 | 113 | # allow some margin out of valid range 114 | margin = (self.max_depth - self.min_depth) /2 115 | warped_depth = warped_depth.clamp(min=self.min_depth-margin, max=self.max_depth+margin) 116 | return warped_depth 117 | 118 | def get_normal_from_depth(self, depth): 119 | b, h, w = depth.shape 120 | grid_3d = self.depth_to_3d_grid(depth) 121 | 122 | tu = grid_3d[:,1:-1,2:] - grid_3d[:,1:-1,:-2] 123 | tv = grid_3d[:,2:,1:-1] - grid_3d[:,:-2,1:-1] 124 | normal = tu.cross(tv, dim=3) 125 | 126 | zero = torch.FloatTensor([0,0,1]).to(depth.device) 127 | normal = torch.cat([zero.repeat(b,h-2,1,1), normal, zero.repeat(b,h-2,1,1)], 2) 128 | normal = torch.cat([zero.repeat(b,1,w,1), normal, zero.repeat(b,1,w,1)], 1) 129 | normal = normal / (((normal**2).sum(3, keepdim=True))**0.5 + EPS) 130 | return normal 131 | 132 | def render_yaw(self, im, depth, v_before=None, v_after=None, rotations=None, maxr=90, nsample=9, crop_mesh=None): 133 | b, c, h, w = im.shape 134 | grid_3d = self.depth_to_3d_grid(depth) 135 | 136 | if crop_mesh is not None: 137 | top, bottom, left, right = crop_mesh # pixels from border to be cropped 138 | if top > 0: 139 | grid_3d[:,:top,:,1] = grid_3d[:,top:top+1,:,1].repeat(1,top,1) 140 | grid_3d[:,:top,:,2] = grid_3d[:,top:top+1,:,2].repeat(1,top,1) 141 | if bottom > 0: 142 | grid_3d[:,-bottom:,:,1] = grid_3d[:,-bottom-1:-bottom,:,1].repeat(1,bottom,1) 143 | grid_3d[:,-bottom:,:,2] = grid_3d[:,-bottom-1:-bottom,:,2].repeat(1,bottom,1) 144 | if left > 0: 145 | grid_3d[:,:,:left,0] = grid_3d[:,:,left:left+1,0].repeat(1,1,left) 146 | grid_3d[:,:,:left,2] = grid_3d[:,:,left:left+1,2].repeat(1,1,left) 147 | if right > 0: 148 | grid_3d[:,:,-right:,0] = grid_3d[:,:,-right-1:-right,0].repeat(1,1,right) 149 | grid_3d[:,:,-right:,2] = grid_3d[:,:,-right-1:-right,2].repeat(1,1,right) 150 | 151 | grid_3d = grid_3d.reshape(b,-1,3) 152 | im_trans = [] 153 | 154 | # inverse warp 155 | if v_before is not None: 156 | rot_mat, trans_xyz = get_transform_matrices(v_before) 157 | grid_3d = self.translate_pts(grid_3d, -trans_xyz) 158 | grid_3d = self.rotate_pts(grid_3d, rot_mat.transpose(2,1)) 159 | 160 | if rotations is None: 161 | rotations = torch.linspace(-math.pi/180*maxr, 
math.pi/180*maxr, nsample) 162 | for i, ri in enumerate(rotations): 163 | ri = torch.FloatTensor([0, ri, 0]).to(im.device).view(1,3) 164 | rot_mat_i, _ = get_transform_matrices(ri) 165 | grid_3d_i = self.rotate_pts(grid_3d, rot_mat_i.repeat(b,1,1)) 166 | 167 | if v_after is not None: 168 | if len(v_after.shape) == 3: 169 | v_after_i = v_after[i] 170 | else: 171 | v_after_i = v_after 172 | rot_mat, trans_xyz = get_transform_matrices(v_after_i) 173 | grid_3d_i = self.rotate_pts(grid_3d_i, rot_mat) 174 | grid_3d_i = self.translate_pts(grid_3d_i, trans_xyz) 175 | 176 | faces = get_face_idx(b, h, w).to(im.device) 177 | textures = get_textures_from_im(im, tx_size=self.tex_cube_size) 178 | warped_images = self.renderer.render_rgb(grid_3d_i, faces, textures).clamp(min=-1., max=1.) 179 | im_trans += [warped_images] 180 | return torch.stack(im_trans, 1) # b x t x c x h x w 181 | -------------------------------------------------------------------------------- /unsup3d/renderer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mm_normalize(x, min=0, max=1): 5 | x_min = x.min() 6 | x_max = x.max() 7 | x_range = x_max - x_min 8 | x_z = (x - x_min) / x_range 9 | x_out = x_z * (max - min) + min 10 | return x_out 11 | 12 | 13 | def rand_range(size, min, max): 14 | return torch.rand(size)*(max-min)+min 15 | 16 | 17 | def rand_posneg_range(size, min, max): 18 | i = (torch.rand(size) > 0.5).type(torch.float)*2.-1. 19 | return i*rand_range(size, min, max) 20 | 21 | 22 | def get_grid(b, H, W, normalize=True): 23 | if normalize: 24 | h_range = torch.linspace(-1,1,H) 25 | w_range = torch.linspace(-1,1,W) 26 | else: 27 | h_range = torch.arange(0,H) 28 | w_range = torch.arange(0,W) 29 | grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b,1,1,1).flip(3).float() # flip h,w to x,y 30 | return grid 31 | 32 | 33 | def get_rotation_matrix(tx, ty, tz): 34 | m_x = torch.zeros((len(tx), 3, 3)).to(tx.device) 35 | m_y = torch.zeros((len(tx), 3, 3)).to(tx.device) 36 | m_z = torch.zeros((len(tx), 3, 3)).to(tx.device) 37 | 38 | m_x[:, 1, 1], m_x[:, 1, 2] = tx.cos(), -tx.sin() 39 | m_x[:, 2, 1], m_x[:, 2, 2] = tx.sin(), tx.cos() 40 | m_x[:, 0, 0] = 1 41 | 42 | m_y[:, 0, 0], m_y[:, 0, 2] = ty.cos(), ty.sin() 43 | m_y[:, 2, 0], m_y[:, 2, 2] = -ty.sin(), ty.cos() 44 | m_y[:, 1, 1] = 1 45 | 46 | m_z[:, 0, 0], m_z[:, 0, 1] = tz.cos(), -tz.sin() 47 | m_z[:, 1, 0], m_z[:, 1, 1] = tz.sin(), tz.cos() 48 | m_z[:, 2, 2] = 1 49 | return torch.matmul(m_z, torch.matmul(m_y, m_x)) 50 | 51 | 52 | def get_transform_matrices(view): 53 | b = view.size(0) 54 | if view.size(1) == 6: 55 | rx = view[:,0] 56 | ry = view[:,1] 57 | rz = view[:,2] 58 | trans_xyz = view[:,3:].reshape(b,1,3) 59 | elif view.size(1) == 5: 60 | rx = view[:,0] 61 | ry = view[:,1] 62 | rz = view[:,2] 63 | delta_xy = view[:,3:].reshape(b,1,2) 64 | trans_xyz = torch.cat([delta_xy, torch.zeros(b,1,1).to(view.device)], 2) 65 | elif view.size(1) == 3: 66 | rx = view[:,0] 67 | ry = view[:,1] 68 | rz = view[:,2] 69 | trans_xyz = torch.zeros(b,1,3).to(view.device) 70 | rot_mat = get_rotation_matrix(rx, ry, rz) 71 | return rot_mat, trans_xyz 72 | 73 | 74 | def get_face_idx(b, h, w): 75 | idx_map = torch.arange(h*w).reshape(h,w) 76 | faces1 = torch.stack([idx_map[:h-1,:w-1], idx_map[1:,:w-1], idx_map[:h-1,1:]], -1).reshape(-1,3) 77 | faces2 = torch.stack([idx_map[:h-1,1:], idx_map[1:,:w-1], idx_map[1:,1:]], -1).reshape(-1,3) 78 | return torch.cat([faces1,faces2], 0).repeat(b,1,1).int() 79 | 80 | 81 | 
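# Builds one 2x2x2 texture cube per triangle from its 3 vertex colors: each of the 8 cube corners is a fixed linear combination (a row of `coeffs`) of the vertex colors, producing the (b, n, 2, 2, 2, c) per-face texture layout consumed by neural_renderer.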
def vcolor_to_texture_cube(vcolors): 82 | # input bxcxnx3 83 | b, c, n, f = vcolors.shape 84 | coeffs = torch.FloatTensor( 85 | [[ 0.5, 0.5, 0.5], 86 | [ 0. , 0. , 1. ], 87 | [ 0. , 1. , 0. ], 88 | [-0.5, 0.5, 0.5], 89 | [ 1. , 0. , 0. ], 90 | [ 0.5, -0.5, 0.5], 91 | [ 0.5, 0.5, -0.5], 92 | [ 0. , 0. , 0. ]]).to(vcolors.device) 93 | return coeffs.matmul(vcolors.permute(0,2,3,1)).reshape(b,n,2,2,2,c) 94 | 95 | 96 | def get_textures_from_im(im, tx_size=1): 97 | b, c, h, w = im.shape 98 | if tx_size == 1: 99 | textures = torch.cat([im[:,:,:h-1,:w-1].reshape(b,c,-1), im[:,:,1:,1:].reshape(b,c,-1)], 2) 100 | textures = textures.transpose(2,1).reshape(b,-1,1,1,1,c) 101 | elif tx_size == 2: 102 | textures1 = torch.stack([im[:,:,:h-1,:w-1], im[:,:,:h-1,1:], im[:,:,1:,:w-1]], -1).reshape(b,c,-1,3) 103 | textures2 = torch.stack([im[:,:,1:,:w-1], im[:,:,:h-1,1:], im[:,:,1:,1:]], -1).reshape(b,c,-1,3) 104 | textures = vcolor_to_texture_cube(torch.cat([textures1, textures2], 2)) # bxnx2x2x2xc 105 | else: 106 | raise NotImplementedError("Currently support texture size of 1 or 2 only.") 107 | return textures 108 | -------------------------------------------------------------------------------- /unsup3d/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from datetime import datetime 4 | import numpy as np 5 | import torch 6 | from . import meters 7 | from . import utils 8 | from .dataloaders import get_data_loaders 9 | 10 | 11 | class Trainer(): 12 | def __init__(self, cfgs, model): 13 | self.device = cfgs.get('device', 'cpu') 14 | self.num_epochs = cfgs.get('num_epochs', 30) 15 | self.batch_size = cfgs.get('batch_size', 64) 16 | self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') 17 | self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1) 18 | self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2) # -1 for keeping all checkpoints 19 | self.resume = cfgs.get('resume', True) 20 | self.use_logger = cfgs.get('use_logger', True) 21 | self.log_freq = cfgs.get('log_freq', 1000) 22 | self.archive_code = cfgs.get('archive_code', True) 23 | self.checkpoint_name = cfgs.get('checkpoint_name', None) 24 | self.test_result_dir = cfgs.get('test_result_dir', None) 25 | self.cfgs = cfgs 26 | 27 | self.metrics_trace = meters.MetricsTrace() 28 | self.make_metrics = lambda m=None: meters.StandardMetrics(m) 29 | self.model = model(cfgs) 30 | self.model.trainer = self 31 | self.train_loader, self.val_loader, self.test_loader = get_data_loaders(cfgs) 32 | 33 | def load_checkpoint(self, optim=True): 34 | """Search the specified/latest checkpoint in checkpoint_dir and load the model and optimizer.""" 35 | if self.checkpoint_name is not None: 36 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name) 37 | else: 38 | checkpoints = sorted(glob.glob(os.path.join(self.checkpoint_dir, '*.pth'))) 39 | if len(checkpoints) == 0: 40 | return 0 41 | checkpoint_path = checkpoints[-1] 42 | self.checkpoint_name = os.path.basename(checkpoint_path) 43 | print(f"Loading checkpoint from {checkpoint_path}") 44 | cp = torch.load(checkpoint_path, map_location=self.device) 45 | self.model.load_model_state(cp) 46 | if optim: 47 | self.model.load_optimizer_state(cp) 48 | self.metrics_trace = cp['metrics_trace'] 49 | epoch = cp['epoch'] 50 | return epoch 51 | 52 | def save_checkpoint(self, epoch, optim=True): 53 | """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch.""" 54 | 
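# create the checkpoint directory if needed, then bundle the model weights, optional optimizer state, metrics trace and epoch index into a single state dict before saving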
utils.xmkdir(self.checkpoint_dir) 55 | checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth') 56 | state_dict = self.model.get_model_state() 57 | if optim: 58 | optimizer_state = self.model.get_optimizer_state() 59 | state_dict = {**state_dict, **optimizer_state} 60 | state_dict['metrics_trace'] = self.metrics_trace 61 | state_dict['epoch'] = epoch 62 | print(f"Saving checkpoint to {checkpoint_path}") 63 | torch.save(state_dict, checkpoint_path) 64 | if self.keep_num_checkpoint > 0: 65 | utils.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint) 66 | 67 | def save_clean_checkpoint(self, path): 68 | """Save model state only to specified path.""" 69 | torch.save(self.model.get_model_state(), path) 70 | 71 | def test(self): 72 | """Perform testing.""" 73 | self.model.to_device(self.device) 74 | self.current_epoch = self.load_checkpoint(optim=False) 75 | if self.test_result_dir is None: 76 | self.test_result_dir = os.path.join(self.checkpoint_dir, f'test_results_{self.checkpoint_name}'.replace('.pth','')) 77 | print(f"Saving testing results to {self.test_result_dir}") 78 | 79 | with torch.no_grad(): 80 | m = self.run_epoch(self.test_loader, epoch=self.current_epoch, is_test=True) 81 | 82 | score_path = os.path.join(self.test_result_dir, 'eval_scores.txt') 83 | self.model.save_scores(score_path) 84 | 85 | def train(self): 86 | """Perform training.""" 87 | ## archive code and configs 88 | if self.archive_code: 89 | utils.archive_code(os.path.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py', '.yml']) 90 | utils.dump_yaml(os.path.join(self.checkpoint_dir, 'configs.yml'), self.cfgs) 91 | 92 | ## initialize 93 | start_epoch = 0 94 | self.metrics_trace.reset() 95 | self.train_iter_per_epoch = len(self.train_loader) 96 | self.model.to_device(self.device) 97 | self.model.init_optimizers() 98 | 99 | ## resume from checkpoint 100 | if self.resume: 101 | start_epoch = self.load_checkpoint(optim=True) 102 | 103 | ## initialize tensorboardX logger 104 | if self.use_logger: 105 | from tensorboardX import SummaryWriter 106 | self.logger = SummaryWriter(os.path.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S"))) 107 | 108 | ## cache one batch for visualization 109 | self.viz_input = self.val_loader.__iter__().__next__() 110 | 111 | ## run epochs 112 | print(f"{self.model.model_name}: optimizing to {self.num_epochs} epochs") 113 | for epoch in range(start_epoch, self.num_epochs): 114 | self.current_epoch = epoch 115 | metrics = self.run_epoch(self.train_loader, epoch) 116 | self.metrics_trace.append("train", metrics) 117 | 118 | with torch.no_grad(): 119 | metrics = self.run_epoch(self.val_loader, epoch, is_validation=True) 120 | self.metrics_trace.append("val", metrics) 121 | 122 | if (epoch+1) % self.save_checkpoint_freq == 0: 123 | self.save_checkpoint(epoch+1, optim=True) 124 | self.metrics_trace.plot(pdf_path=os.path.join(self.checkpoint_dir, 'metrics.pdf')) 125 | self.metrics_trace.save(os.path.join(self.checkpoint_dir, 'metrics.json')) 126 | 127 | print(f"Training completed after {epoch+1} epochs.") 128 | 129 | def run_epoch(self, loader, epoch=0, is_validation=False, is_test=False): 130 | """Run one epoch.""" 131 | is_train = not is_validation and not is_test 132 | metrics = self.make_metrics() 133 | 134 | if is_train: 135 | print(f"Starting training epoch {epoch}") 136 | self.model.set_train() 137 | else: 138 | print(f"Starting validation epoch {epoch}") 139 | self.model.set_eval() 140 | 141 | for iter, input in 
enumerate(loader): 142 | m = self.model.forward(input) 143 | if is_train: 144 | self.model.backward() 145 | elif is_test: 146 | self.model.save_results(self.test_result_dir) 147 | 148 | metrics.update(m, self.batch_size) 149 | print(f"{'T' if is_train else 'V'}{epoch:02}/{iter:05}/{metrics}") 150 | 151 | if self.use_logger and is_train: 152 | total_iter = iter + epoch*self.train_iter_per_epoch 153 | if total_iter % self.log_freq == 0: 154 | self.model.forward(self.viz_input) 155 | self.model.visualize(self.logger, total_iter=total_iter, max_bs=25) 156 | return metrics 157 | -------------------------------------------------------------------------------- /unsup3d/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import yaml 5 | import random 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import zipfile 10 | 11 | 12 | def setup_runtime(args): 13 | """Load configs, initialize CUDA, CuDNN and the random seeds.""" 14 | 15 | # Setup CUDA 16 | cuda_device_id = args.gpu 17 | if cuda_device_id is not None: 18 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 19 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device_id) 20 | if torch.cuda.is_available(): 21 | torch.backends.cudnn.enabled = True 22 | torch.backends.cudnn.benchmark = True 23 | torch.backends.cudnn.deterministic = True 24 | 25 | # Setup random seeds for reproducibility 26 | random.seed(args.seed) 27 | np.random.seed(args.seed) 28 | torch.manual_seed(args.seed) 29 | if torch.cuda.is_available(): 30 | torch.cuda.manual_seed_all(args.seed) 31 | 32 | ## Load config 33 | cfgs = {} 34 | if args.config is not None and os.path.isfile(args.config): 35 | cfgs = load_yaml(args.config) 36 | 37 | cfgs['config'] = args.config 38 | cfgs['seed'] = args.seed 39 | cfgs['num_workers'] = args.num_workers 40 | cfgs['device'] = 'cuda:0' if torch.cuda.is_available() and cuda_device_id is not None else 'cpu' 41 | 42 | print(f"Environment: GPU {cuda_device_id} seed {args.seed} number of workers {args.num_workers}") 43 | return cfgs 44 | 45 | 46 | def load_yaml(path): 47 | print(f"Loading configs from {path}") 48 | with open(path, 'r') as f: 49 | return yaml.safe_load(f) 50 | 51 | 52 | def dump_yaml(path, cfgs): 53 | print(f"Saving configs to {path}") 54 | xmkdir(os.path.dirname(path)) 55 | with open(path, 'w') as f: 56 | return yaml.safe_dump(cfgs, f) 57 | 58 | 59 | def xmkdir(path): 60 | """Create directory PATH recursively if it does not exist.""" 61 | os.makedirs(path, exist_ok=True) 62 | 63 | 64 | def clean_checkpoint(checkpoint_dir, keep_num=2): 65 | if keep_num > 0: 66 | names = list(sorted( 67 | glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.pth')) 68 | )) 69 | if len(names) > keep_num: 70 | for name in names[:-keep_num]: 71 | print(f"Deleting obslete checkpoint file {name}") 72 | os.remove(name) 73 | 74 | 75 | def archive_code(arc_path, filetypes=['.py', '.yml']): 76 | print(f"Archiving code to {arc_path}") 77 | xmkdir(os.path.dirname(arc_path)) 78 | zipf = zipfile.ZipFile(arc_path, 'w', zipfile.ZIP_DEFLATED) 79 | cur_dir = os.getcwd() 80 | flist = [] 81 | for ftype in filetypes: 82 | flist.extend(glob.glob(os.path.join(cur_dir, '**', '*'+ftype), recursive=True)) 83 | [zipf.write(f, arcname=f.replace(cur_dir,'archived_code', 1)) for f in flist] 84 | zipf.close() 85 | 86 | 87 | def get_model_device(model): 88 | return next(model.parameters()).device 89 | 90 | 91 | def set_requires_grad(nets, requires_grad=False): 92 | if not isinstance(nets, list): 
93 | nets = [nets] 94 | for net in nets: 95 | if net is not None: 96 | for param in net.parameters(): 97 | param.requires_grad = requires_grad 98 | 99 | 100 | def draw_bbox(im, size): 101 | b, c, h, w = im.shape 102 | h2, w2 = (h-size)//2, (w-size)//2 103 | marker = np.tile(np.array([[1.],[0.],[0.]]), (1,size)) 104 | marker = torch.FloatTensor(marker) 105 | im[:, :, h2, w2:w2+size] = marker 106 | im[:, :, h2+size, w2:w2+size] = marker 107 | im[:, :, h2:h2+size, w2] = marker 108 | im[:, :, h2:h2+size, w2+size] = marker 109 | return im 110 | 111 | 112 | def save_videos(out_fold, imgs, prefix='', suffix='', sep_folder=True, ext='.mp4', cycle=False): 113 | if sep_folder: 114 | out_fold = os.path.join(out_fold, suffix) 115 | xmkdir(out_fold) 116 | prefix = prefix + '_' if prefix else '' 117 | suffix = '_' + suffix if suffix else '' 118 | offset = len(glob.glob(os.path.join(out_fold, prefix+'*'+suffix+ext))) +1 119 | 120 | imgs = imgs.transpose(0,1,3,4,2) # BxTxCxHxW -> BxTxHxWxC 121 | for i, fs in enumerate(imgs): 122 | if cycle: 123 | fs = np.concatenate([fs, fs[::-1]], 0) 124 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 125 | # fourcc = cv2.VideoWriter_fourcc(*'avc1') 126 | vid = cv2.VideoWriter(os.path.join(out_fold, prefix+'%05d'%(i+offset)+suffix+ext), fourcc, 5, (fs.shape[2], fs.shape[1])) 127 | [vid.write(np.uint8(f[...,::-1]*255.)) for f in fs] 128 | vid.release() 129 | 130 | 131 | def save_images(out_fold, imgs, prefix='', suffix='', sep_folder=True, ext='.png'): 132 | if sep_folder: 133 | out_fold = os.path.join(out_fold, suffix) 134 | xmkdir(out_fold) 135 | prefix = prefix + '_' if prefix else '' 136 | suffix = '_' + suffix if suffix else '' 137 | offset = len(glob.glob(os.path.join(out_fold, prefix+'*'+suffix+ext))) +1 138 | 139 | imgs = imgs.transpose(0,2,3,1) 140 | for i, img in enumerate(imgs): 141 | if 'depth' in suffix: 142 | im_out = np.uint16(img[...,::-1]*65535.) 143 | else: 144 | im_out = np.uint8(img[...,::-1]*255.) 145 | cv2.imwrite(os.path.join(out_fold, prefix+'%05d'%(i+offset)+suffix+ext), im_out) 146 | 147 | 148 | def save_txt(out_fold, data, prefix='', suffix='', sep_folder=True, ext='.txt'): 149 | if sep_folder: 150 | out_fold = os.path.join(out_fold, suffix) 151 | xmkdir(out_fold) 152 | prefix = prefix + '_' if prefix else '' 153 | suffix = '_' + suffix if suffix else '' 154 | offset = len(glob.glob(os.path.join(out_fold, prefix+'*'+suffix+ext))) +1 155 | 156 | [np.savetxt(os.path.join(out_fold, prefix+'%05d'%(i+offset)+suffix+ext), d, fmt='%.6f', delimiter=', ') for i,d in enumerate(data)] 157 | 158 | 159 | def compute_sc_inv_err(d_pred, d_gt, mask=None): 160 | b = d_pred.size(0) 161 | diff = d_pred - d_gt 162 | if mask is not None: 163 | diff = diff * mask 164 | avg = diff.view(b, -1).sum(1) / (mask.view(b, -1).sum(1)) 165 | score = (diff - avg.view(b,1,1))**2 * mask 166 | else: 167 | avg = diff.view(b, -1).mean(1) 168 | score = (diff - avg.view(b,1,1))**2 169 | return score # masked error maps 170 | 171 | 172 | def compute_angular_distance(n1, n2, mask=None): 173 | dist = (n1*n2).sum(3).clamp(-1,1).acos() /np.pi*180 174 | return dist*mask if mask is not None else dist 175 | 176 | 177 | def save_scores(out_path, scores, header=''): 178 | print('Saving scores to %s' %out_path) 179 | np.savetxt(out_path, scores, fmt='%.8f', delimiter=',\t', header=header) 180 | --------------------------------------------------------------------------------