├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── celeba_crop.py
│   ├── celeba_crop_bbox.txt
│   ├── download_cat.sh
│   ├── download_syncar.sh
│   └── download_synface.sh
├── demo
│   ├── demo.py
│   ├── images
│   │   ├── cat_face
│   │   │   ├── 001_cat.png
│   │   │   ├── 002_cat.png
│   │   │   ├── 003_cat.png
│   │   │   ├── 004_cat.png
│   │   │   ├── 005_cat.png
│   │   │   ├── 006_cat.png
│   │   │   ├── 007_cat.png
│   │   │   ├── 008_cat.png
│   │   │   ├── 009_cat.png
│   │   │   ├── 010_cat.png
│   │   │   ├── 011_cat.png
│   │   │   ├── 012_cat.png
│   │   │   ├── 013_cat.png
│   │   │   ├── 014_cat.png
│   │   │   ├── 015_cat.png
│   │   │   ├── 016_cat.png
│   │   │   ├── 017_cat.png
│   │   │   ├── 018_cat.png
│   │   │   ├── 019_cat.png
│   │   │   ├── 020_cat.png
│   │   │   ├── 021_cat.png
│   │   │   ├── 022_cat.png
│   │   │   ├── 023_cat.png
│   │   │   ├── 024_abstract.png
│   │   │   ├── 025_abstract.png
│   │   │   ├── 026_abstract.png
│   │   │   ├── 027_abstract.png
│   │   │   ├── 028_abstract.png
│   │   │   ├── 029_abstract.png
│   │   │   ├── 030_abstract.png
│   │   │   ├── 031_abstract.png
│   │   │   ├── 032_abstract.png
│   │   │   ├── 033_abstract.png
│   │   │   ├── 034_abstract.png
│   │   │   ├── 035_abstract.png
│   │   │   ├── 036_abstract.png
│   │   │   ├── 037_abstract.png
│   │   │   ├── 038_abstract.png
│   │   │   ├── 039_abstract.png
│   │   │   ├── 040_abstract.png
│   │   │   ├── 041_abstract.png
│   │   │   ├── 042_abstract.png
│   │   │   ├── 043_abstract.png
│   │   │   ├── 044_abstract.png
│   │   │   └── 045_abstract.png
│   │   └── human_face
│   │       ├── 001_face.png
│   │       ├── 002_face.png
│   │       ├── 003_face.png
│   │       ├── 004_face.png
│   │       ├── 005_face.png
│   │       ├── 006_face.png
│   │       ├── 007_face.png
│   │       ├── 008_face.png
│   │       ├── 009_face.png
│   │       ├── 010_face.png
│   │       ├── 011_face.png
│   │       ├── 012_face.png
│   │       ├── 013_face.png
│   │       ├── 014_face.png
│   │       ├── 015_face.png
│   │       ├── 016_paint.png
│   │       ├── 017_paint.png
│   │       ├── 018_paint.png
│   │       ├── 019_paint.png
│   │       ├── 020_paint.png
│   │       ├── 021_paint.png
│   │       ├── 022_paint.png
│   │       ├── 023_paint.png
│   │       ├── 024_paint.png
│   │       ├── 025_paint.png
│   │       ├── 026_paint.png
│   │       ├── 027_paint.png
│   │       ├── 028_paint.png
│   │       ├── 029_paint.png
│   │       ├── 030_paint.png
│   │       ├── 031_abstract.png
│   │       ├── 032_abstract.png
│   │       ├── 033_abstract.png
│   │       ├── 034_abstract.png
│   │       ├── 035_abstract.png
│   │       ├── 036_abstract.png
│   │       ├── 037_abstract.png
│   │       ├── 038_abstract.png
│   │       ├── 039_abstract.png
│   │       ├── 040_abstract.png
│   │       ├── 041_abstract.png
│   │       ├── 042_abstract.png
│   │       ├── 043_abstract.png
│   │       ├── 044_abstract.png
│   │       └── 045_abstract.png
│   └── utils.py
├── environment.yml
├── experiments
│   ├── test_cat.yml
│   ├── test_celeba.yml
│   ├── test_syncar.yml
│   ├── test_synface.yml
│   ├── train_cat.yml
│   ├── train_celeba.yml
│   ├── train_syncar.yml
│   └── train_synface.yml
├── img
│   └── teaser.jpg
├── pretrained
│   ├── download_pretrained_cat.sh
│   ├── download_pretrained_celeba.sh
│   ├── download_pretrained_syncar.sh
│   └── download_pretrained_synface.sh
├── run.py
└── unsup3d
    ├── __init__.py
    ├── dataloaders.py
    ├── meters.py
    ├── model.py
    ├── networks.py
    ├── renderer
    │   ├── __init__.py
    │   ├── renderer.py
    │   └── utils.py
    ├── trainer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | data/*/
3 | pretrained/*/
4 | results
5 | neural_renderer
6 | *.zip
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 elliottwu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Learning of Probably Symmetric Deformable 3D Objects from Images in the Wild
2 | #### [Demo](http://www.robots.ox.ac.uk/~vgg/blog/unsupervised-learning-of-probably-symmetric-deformable-3d-objects-from-images-in-the-wild.html) | [Project Page](https://elliottwu.com/projects/unsup3d/) | [Video](https://www.youtube.com/watch?v=5rPJyrU-WE4) | [Paper](https://arxiv.org/abs/1911.11130)
3 | [Shangzhe Wu](https://elliottwu.com/), [Christian Rupprecht](https://chrirupp.github.io/), [Andrea Vedaldi](http://www.robots.ox.ac.uk/~vedaldi/), Visual Geometry Group, University of Oxford. In CVPR 2020 (Best Paper Award).
4 |
5 |
6 |
7 | We propose a method to learn weakly symmetric deformable 3D object categories from raw single-view images, without ground-truth 3D, multiple views, 2D/3D keypoints, prior shape models or any other supervision.
8 |
9 |
10 | ## Setup (with [Anaconda](https://www.anaconda.com/))
11 |
12 | ### 1. Install dependencies:
13 | ```
14 | conda env create -f environment.yml
15 | ```
16 | OR manually:
17 | ```
18 | conda install -c conda-forge scikit-image matplotlib opencv moviepy pyyaml tensorboardX
19 | ```
20 |
21 |
22 | ### 2. Install [PyTorch](https://pytorch.org/):
23 | ```
24 | conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=9.2 -c pytorch
25 | ```
26 | *Note*: The code is tested with PyTorch 1.2.0 and CUDA 9.2 on CentOS 7. A GPU version is required for training and testing, since the [neural_renderer](https://github.com/daniilidis-group/neural_renderer) package only has a GPU implementation. You can still run the demo without a GPU.
27 |
28 |
29 | ### 3. Install [neural_renderer](https://github.com/daniilidis-group/neural_renderer):
30 | This package is required for training and testing, and optional for the demo. It requires a GPU device and GPU-enabled PyTorch.
31 | ```
32 | pip install neural_renderer_pytorch
33 | ```
34 |
35 | *Note*: It may fail if you have a GCC version below 5. If you do not want to upgrade your GCC, one alternative solution is to use conda's GCC and compile the package from source. For example:
36 | ```
37 | conda install gxx_linux-64=7.3
38 | git clone https://github.com/daniilidis-group/neural_renderer.git
39 | cd neural_renderer
40 | python setup.py install
41 | ```
42 |
43 |
44 | ### 4. (For demo only) Install [facenet-pytorch](https://github.com/timesler/facenet-pytorch):
45 | This package is optional for the demo. It allows automatic human face detection.
46 | ```
47 | pip install facenet-pytorch
48 | ```
49 |
50 |
51 | ## Datasets
52 | 1. [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) face dataset. Please download the original images (`img_celeba.7z`) from their [website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and run `celeba_crop.py` in `data/` to crop the images.
53 | 2. Synthetic face dataset generated using [Basel Face Model](https://faces.dmi.unibas.ch/bfm/). This can be downloaded using the script `download_synface.sh` provided in `data/`.
54 | 3. Cat face dataset composed of [Cat Head Dataset](http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd) and [Oxford-IIIT Pet Dataset](http://www.robots.ox.ac.uk/~vgg/data/pets/) ([license](https://creativecommons.org/licenses/by-sa/4.0/)). This can be downloaded using the script `download_cat.sh` provided in `data/`.
55 | 4. Synthetic car dataset generated from [ShapeNet](https://shapenet.org/) cars. The images are rendered with random viewpoints from the top, with the cars primarily oriented vertically. This can be downloaded using the script `download_syncar.sh` provided in `data/`.
56 |
57 | Please remember to cite the corresponding papers if you use these datasets.
58 |
59 |
60 | ## Pretrained Models
61 | Download pretrained models using the scripts provided in `pretrained/`, e.g.:
62 | ```
63 | cd pretrained && sh download_pretrained_celeba.sh
64 | ```
65 |
66 |
67 | ## Demo
68 | ```
69 | python -m demo.demo --input demo/images/human_face --result demo/results/human_face --checkpoint pretrained/pretrained_celeba/checkpoint030.pth
70 | ```
71 |
72 | *Options*:
73 | - `--gpu`: enable GPU
74 | - `--detect_human_face`: enable automatic human face detection and cropping using [MTCNN](https://arxiv.org/abs/1604.02878) provided in [facenet-pytorch](https://github.com/timesler/facenet-pytorch). This only works on human face images. You will need to manually crop the images for other objects.
75 | - `--render_video`: render 3D animations using [neural_renderer](https://github.com/daniilidis-group/neural_renderer) (GPU is required)
76 |
77 |
78 | ## Training and Testing
79 | Check the configuration files in `experiments/` and run experiments, e.g.:
80 | ```
81 | python run.py --config experiments/train_celeba.yml --gpu 0 --num_workers 4
82 | ```
83 |
84 |
85 | ## Citation
86 | ```
87 | @InProceedings{Wu_2020_CVPR,
88 | author = {Shangzhe Wu and Christian Rupprecht and Andrea Vedaldi},
89 | title = {Unsupervised Learning of Probably Symmetric Deformable 3D Objects from Images in the Wild},
90 | booktitle = {CVPR},
91 | year = {2020}
92 | }
93 | ```
94 |
--------------------------------------------------------------------------------
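
The setup steps in the README above can be verified before training. The following is a minimal sanity-check sketch, not part of the repository: it assumes the conda environment from steps 1–3 and only confirms that PyTorch sees a CUDA device and that `neural_renderer` imports.

```
# A minimal sanity-check sketch (not part of the repo): confirm PyTorch and,
# optionally, neural_renderer are usable in the environment set up above.
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())  # needed for training/testing

try:
    import neural_renderer  # GPU-only package; optional for the CPU demo
    print("neural_renderer imported successfully")
except ImportError:
    print("neural_renderer not installed (required for training/testing and --render_video)")
```
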
/data/celeba_crop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | im_dir = './img_celeba'
7 | out_dir = './celeba_cropped'
8 | bbox_fpath = './celeba_crop_bbox.txt'
9 | out_im_size = 128
10 |
11 | im_list = np.loadtxt(bbox_fpath, dtype='str')
12 | total_num = len(im_list)
13 | split_dict = {'0': 'train',
14 | '1': 'val',
15 | '2': 'test'}
16 |
17 | for i, row in enumerate(im_list):
18 | if i%1000 == 0:
19 | print(f'{i}/{total_num}')
20 |
21 | fname = row[0]
22 | split = row[1]
23 | x0, y0, w, h = row[2:].astype(int)
24 | im = cv2.imread(os.path.join(im_dir, fname))
25 | im_pad = cv2.copyMakeBorder(im, h, h, w, w, cv2.BORDER_REPLICATE) # allow cropping outside by replicating borders
26 | im_crop = im_pad[y0+h:y0+h*2, x0+w:x0+w*2]
27 | im_crop = cv2.resize(im_crop, (out_im_size,out_im_size))
28 |
29 | out_folder = os.path.join(out_dir, split_dict[split])
30 | os.makedirs(out_folder, exist_ok=True)
31 | cv2.imwrite(os.path.join(out_folder, fname), im_crop)
32 |
--------------------------------------------------------------------------------
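
For reference, the row format that `celeba_crop.py` above reads from `celeba_crop_bbox.txt` can be illustrated as follows. This is a sketch, not part of the repository, and the numeric values in the example row are hypothetical.

```
# Illustration only (not part of the repo): one hypothetical row of
# celeba_crop_bbox.txt as parsed by celeba_crop.py above. The format is
# "<filename> <split> <x0> <y0> <w> <h>", with split 0/1/2 mapped to
# train/val/test and (x0, y0, w, h) giving the face crop box in pixels.
import numpy as np

row = np.array(['000001.jpg', '0', '95', '71', '226', '226'])  # made-up values
fname, split = row[0], row[1]
x0, y0, w, h = row[2:].astype(int)

# celeba_crop.py pads the image by (h, w) on each side with BORDER_REPLICATE,
# which shifts the crop box by (w, h); the slice it then takes,
#   im_pad[y0+h : y0+2*h, x0+w : x0+2*w],
# is an h-by-w crop that may safely extend beyond the original image borders.
print(fname, split, x0, y0, w, h)
```
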
/data/download_cat.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading cat dataset -----------------------"
2 | curl -o cat_combined.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/cat_combined.zip" && unzip cat_combined.zip
3 |
--------------------------------------------------------------------------------
/data/download_syncar.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading synthetic car dataset -----------------------"
2 | curl -o syncar.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/syncar.zip" && unzip syncar.zip
3 |
--------------------------------------------------------------------------------
/data/download_synface.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading synthetic face dataset -----------------------"
2 | curl -o synface.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/synface.zip" && unzip synface.zip
3 |
--------------------------------------------------------------------------------
/demo/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from PIL import Image
4 | import torch
5 | import torch.nn as nn
6 | from .utils import *
7 |
8 |
9 | EPS = 1e-7
10 |
11 |
12 | class Demo():
13 | def __init__(self, args):
14 | ## configs
15 | self.device = 'cuda:0' if args.gpu else 'cpu'
16 | self.checkpoint_path = args.checkpoint
17 | self.detect_human_face = args.detect_human_face
18 | self.render_video = args.render_video
19 | self.output_size = args.output_size
20 | self.image_size = 64
21 | self.min_depth = 0.9
22 | self.max_depth = 1.1
23 | self.border_depth = 1.05
24 | self.xyz_rotation_range = 60
25 | self.xy_translation_range = 0.1
26 | self.z_translation_range = 0
27 | self.fov = 10 # in degrees
28 |
29 | self.depth_rescaler = lambda d : (1+d)/2 *self.max_depth + (1-d)/2 *self.min_depth # (-1,1) => (min_depth,max_depth)
30 | self.depth_inv_rescaler = lambda d : (d-self.min_depth) / (self.max_depth-self.min_depth) # (min_depth,max_depth) => (0,1)
31 |
32 | fx = (self.image_size-1)/2/(np.tan(self.fov/2 *np.pi/180))
33 | fy = (self.image_size-1)/2/(np.tan(self.fov/2 *np.pi/180))
34 | cx = (self.image_size-1)/2
35 | cy = (self.image_size-1)/2
36 | K = [[fx, 0., cx],
37 | [0., fy, cy],
38 | [0., 0., 1.]]
39 | K = torch.FloatTensor(K).to(self.device)
40 | self.inv_K = torch.inverse(K).unsqueeze(0)
41 | self.K = K.unsqueeze(0)
42 |
43 | ## NN models
44 | self.netD = EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None)
45 | self.netA = EDDeconv(cin=3, cout=3, nf=64, zdim=256)
46 | self.netL = Encoder(cin=3, cout=4, nf=32)
47 | self.netV = Encoder(cin=3, cout=6, nf=32)
48 |
49 | self.netD = self.netD.to(self.device)
50 | self.netA = self.netA.to(self.device)
51 | self.netL = self.netL.to(self.device)
52 | self.netV = self.netV.to(self.device)
53 | self.load_checkpoint()
54 |
55 | self.netD.eval()
56 | self.netA.eval()
57 | self.netL.eval()
58 | self.netV.eval()
59 |
60 | ## face detecter
61 | if self.detect_human_face:
62 | from facenet_pytorch import MTCNN
63 | self.face_detector = MTCNN(select_largest=True, device=self.device)
64 |
65 | ## renderer
66 | if self.render_video:
67 | from unsup3d.renderer import Renderer
68 | assert 'cuda' in self.device, 'A GPU device is required for rendering because the neural_renderer only has GPU implementation.'
69 | cfgs = {
70 | 'device': self.device,
71 | 'image_size': self.output_size,
72 | 'min_depth': self.min_depth,
73 | 'max_depth': self.max_depth,
74 | 'fov': self.fov,
75 | }
76 | self.renderer = Renderer(cfgs)
77 |
78 | def load_checkpoint(self):
79 | print(f"Loading checkpoint from {self.checkpoint_path}")
80 | cp = torch.load(self.checkpoint_path, map_location=self.device)
81 | self.netD.load_state_dict(cp['netD'])
82 | self.netA.load_state_dict(cp['netA'])
83 | self.netL.load_state_dict(cp['netL'])
84 | self.netV.load_state_dict(cp['netV'])
85 |
86 | def depth_to_3d_grid(self, depth, inv_K=None):
87 | if inv_K is None:
88 | inv_K = self.inv_K
89 | b, h, w = depth.shape
90 | grid_2d = get_grid(b, h, w, normalize=False).to(depth.device) # Nxhxwx2
91 | depth = depth.unsqueeze(-1)
92 | grid_3d = torch.cat((grid_2d, torch.ones_like(depth)), dim=3)
93 | grid_3d = grid_3d.matmul(inv_K.transpose(2,1)) * depth
94 | return grid_3d
95 |
96 | def get_normal_from_depth(self, depth):
97 | b, h, w = depth.shape
98 | grid_3d = self.depth_to_3d_grid(depth)
99 |
100 | tu = grid_3d[:,1:-1,2:] - grid_3d[:,1:-1,:-2]
101 | tv = grid_3d[:,2:,1:-1] - grid_3d[:,:-2,1:-1]
102 | normal = tu.cross(tv, dim=3)
103 |
104 | zero = normal.new_tensor([0,0,1])
105 | normal = torch.cat([zero.repeat(b,h-2,1,1), normal, zero.repeat(b,h-2,1,1)], 2)
106 | normal = torch.cat([zero.repeat(b,1,w,1), normal, zero.repeat(b,1,w,1)], 1)
107 | normal = normal / (((normal**2).sum(3, keepdim=True))**0.5 + EPS)
108 | return normal
109 |
110 | def detect_face(self, im):
111 | print("Detecting face using MTCNN face detector")
112 | try:
113 | bboxes, prob = self.face_detector.detect(im)
114 | w0, h0, w1, h1 = bboxes[0]
115 | except:
116 | print("Could not detect faces in the image")
117 | return None
118 |
119 | hc, wc = (h0+h1)/2, (w0+w1)/2
120 | crop = int(((h1-h0) + (w1-w0)) /2/2 *1.1)
121 | im = np.pad(im, ((crop,crop),(crop,crop),(0,0)), mode='edge') # allow cropping outside by replicating borders
122 | h0 = int(hc-crop+crop + crop*0.15)
123 | w0 = int(wc-crop+crop)
124 | return im[h0:h0+crop*2, w0:w0+crop*2]
125 |
126 | def run(self, pil_im):
127 | im = np.uint8(pil_im)
128 |
129 | ## face detection
130 | if self.detect_human_face:
131 | im = self.detect_face(im)
132 | if im is None:
133 | return -1
134 |
135 | h, w, _ = im.shape
136 | im = torch.FloatTensor(im /255.).permute(2,0,1).unsqueeze(0)
137 | # resize to 128 first if too large, to avoid bilinear downsampling artifacts
138 | if h > self.image_size*4 and w > self.image_size*4:
139 | im = nn.functional.interpolate(im, (self.image_size*2, self.image_size*2), mode='bilinear', align_corners=False)
140 | im = nn.functional.interpolate(im, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
141 |
142 | with torch.no_grad():
143 | self.input_im = im.to(self.device) *2.-1.
144 | b, c, h, w = self.input_im.shape
145 |
146 | ## predict canonical depth
147 | self.canon_depth_raw = self.netD(self.input_im).squeeze(1) # BxHxW
148 | self.canon_depth = self.canon_depth_raw - self.canon_depth_raw.view(b,-1).mean(1).view(b,1,1)
149 | self.canon_depth = self.canon_depth.tanh()
150 | self.canon_depth = self.depth_rescaler(self.canon_depth)
151 |
152 | ## clamp border depth
153 | depth_border = torch.zeros(1,h,w-4).to(self.input_im.device)
154 | depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1)
155 | self.canon_depth = self.canon_depth*(1-depth_border) + depth_border *self.border_depth
156 |
157 | ## predict canonical albedo
158 | self.canon_albedo = self.netA(self.input_im) # Bx3xHxW
159 |
160 | ## predict lighting
161 | canon_light = self.netL(self.input_im) # Bx4
162 | self.canon_light_a = canon_light[:,:1] /2+0.5 # ambience term
163 | self.canon_light_b = canon_light[:,1:2] /2+0.5 # diffuse term
164 | canon_light_dxy = canon_light[:,2:]
165 | self.canon_light_d = torch.cat([canon_light_dxy, torch.ones(b,1).to(self.input_im.device)], 1)
166 | self.canon_light_d = self.canon_light_d / ((self.canon_light_d**2).sum(1, keepdim=True))**0.5 # diffuse light direction
167 |
168 | ## shading
169 | self.canon_normal = self.get_normal_from_depth(self.canon_depth)
170 | self.canon_diffuse_shading = (self.canon_normal * self.canon_light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1)
171 | canon_shading = self.canon_light_a.view(-1,1,1,1) + self.canon_light_b.view(-1,1,1,1)*self.canon_diffuse_shading
172 | self.canon_im = (self.canon_albedo/2+0.5) * canon_shading *2-1
173 |
174 | ## predict viewpoint transformation
175 | self.view = self.netV(self.input_im)
176 | self.view = torch.cat([
177 | self.view[:,:3] *np.pi/180 *self.xyz_rotation_range,
178 | self.view[:,3:5] *self.xy_translation_range,
179 | self.view[:,5:] *self.z_translation_range], 1)
180 |
181 | ## export to obj strings
182 | vertices = self.depth_to_3d_grid(self.canon_depth) # BxHxWx3
183 | self.objs, self.mtls = export_to_obj_string(vertices, self.canon_normal)
184 |
185 | ## resize to output size
186 | self.canon_depth = nn.functional.interpolate(self.canon_depth.unsqueeze(1), (self.output_size, self.output_size), mode='bilinear', align_corners=False).squeeze(1)
187 | self.canon_normal = nn.functional.interpolate(self.canon_normal.permute(0,3,1,2), (self.output_size, self.output_size), mode='bilinear', align_corners=False).permute(0,2,3,1)
188 | self.canon_normal = self.canon_normal / (self.canon_normal**2).sum(3, keepdim=True)**0.5
189 | self.canon_diffuse_shading = nn.functional.interpolate(self.canon_diffuse_shading, (self.output_size, self.output_size), mode='bilinear', align_corners=False)
190 | self.canon_albedo = nn.functional.interpolate(self.canon_albedo, (self.output_size, self.output_size), mode='bilinear', align_corners=False)
191 | self.canon_im = nn.functional.interpolate(self.canon_im, (self.output_size, self.output_size), mode='bilinear', align_corners=False)
192 |
193 | if self.render_video:
194 | self.render_animation()
195 |
196 | def render_animation(self):
197 | print(f"Rendering video animations")
198 | b, h, w = self.canon_depth.shape
199 |
200 | ## morph from target view to canonical
201 | morph_frames = 15
202 | view_zero = torch.FloatTensor([0.15*np.pi/180*60, 0,0,0,0,0]).to(self.canon_depth.device)
203 | morph_s = torch.linspace(0, 1, morph_frames).to(self.canon_depth.device)
204 | view_morph = morph_s.view(-1,1,1) * view_zero.view(1,1,-1) + (1-morph_s.view(-1,1,1)) * self.view.unsqueeze(0) # TxBx6
205 |
206 | ## yaw from canonical to both sides
207 | yaw_frames = 80
208 | yaw_rotations = np.linspace(-np.pi/2, np.pi/2, yaw_frames)
209 | # yaw_rotations = np.concatenate([yaw_rotations[40:], yaw_rotations[::-1], yaw_rotations[:40]], 0)
210 |
211 | ## whole rotation sequence
212 | view_after = torch.cat([view_morph, view_zero.repeat(yaw_frames, b, 1)], 0)
213 | yaw_rotations = np.concatenate([np.zeros(morph_frames), yaw_rotations], 0)
214 |
215 | def rearrange_frames(frames):
216 | morph_seq = frames[:, :morph_frames]
217 | yaw_seq = frames[:, morph_frames:]
218 | out_seq = torch.cat([
219 | morph_seq[:,:1].repeat(1,5,1,1,1),
220 | morph_seq,
221 | morph_seq[:,-1:].repeat(1,5,1,1,1),
222 | yaw_seq[:, yaw_frames//2:],
223 | yaw_seq.flip(1),
224 | yaw_seq[:, :yaw_frames//2],
225 | morph_seq[:,-1:].repeat(1,5,1,1,1),
226 | morph_seq.flip(1),
227 | morph_seq[:,:1].repeat(1,5,1,1,1),
228 | ], 1)
229 | return out_seq
230 |
231 | ## textureless shape
232 | front_light = torch.FloatTensor([0,0,1]).to(self.canon_depth.device)
233 | canon_shape_im = (self.canon_normal * front_light.view(1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1)
234 | canon_shape_im = canon_shape_im.repeat(1,3,1,1) *0.7
235 | shape_animation = self.renderer.render_yaw(canon_shape_im, self.canon_depth, v_after=view_after, rotations=yaw_rotations) # BxTxCxHxW
236 | self.shape_animation = rearrange_frames(shape_animation)
237 |
238 | ## normal map
239 | canon_normal_im = self.canon_normal.permute(0,3,1,2) /2+0.5
240 | normal_animation = self.renderer.render_yaw(canon_normal_im, self.canon_depth, v_after=view_after, rotations=yaw_rotations) # BxTxCxHxW
241 | self.normal_animation = rearrange_frames(normal_animation)
242 |
243 | ## textured
244 | texture_animation = self.renderer.render_yaw(self.canon_im /2+0.5, self.canon_depth, v_after=view_after, rotations=yaw_rotations) # BxTxCxHxW
245 | self.texture_animation = rearrange_frames(texture_animation)
246 |
247 | def save_results(self, save_dir):
248 | print(f"Saving results to {save_dir}")
249 | save_image(save_dir, self.input_im[0]/2+0.5, 'input_image')
250 | save_image(save_dir, self.depth_inv_rescaler(self.canon_depth)[0].repeat(3,1,1), 'canonical_depth')
251 | save_image(save_dir, self.canon_normal[0].permute(2,0,1)/2+0.5, 'canonical_normal')
252 | save_image(save_dir, self.canon_diffuse_shading[0].repeat(3,1,1), 'canonical_diffuse_shading')
253 | save_image(save_dir, self.canon_albedo[0]/2+0.5, 'canonical_albedo')
254 | save_image(save_dir, self.canon_im[0].clamp(-1,1)/2+0.5, 'canonical_image')
255 |
256 | with open(os.path.join(save_dir, 'result.mtl'), "w") as f:
257 | f.write(self.mtls[0].replace('$TXTFILE', './canonical_image.png'))
258 | with open(os.path.join(save_dir, 'result.obj'), "w") as f:
259 | f.write(self.objs[0].replace('$MTLFILE', './result.mtl'))
260 |
261 | if self.render_video:
262 | save_video(save_dir, self.shape_animation[0], 'shape_animation')
263 | save_video(save_dir, self.normal_animation[0], 'normal_animation')
264 | save_video(save_dir, self.texture_animation[0], 'texture_animation')
265 |
266 |
267 | if __name__ == "__main__":
268 | parser = argparse.ArgumentParser(description='Demo configurations.')
269 | parser.add_argument('--input', default='./demo/images/human_face', type=str, help='Path to the directory containing input images')
270 | parser.add_argument('--result', default='./demo/results/human_face', type=str, help='Path to the directory for saving results')
271 | parser.add_argument('--checkpoint', default='./pretrained/pretrained_celeba/checkpoint030.pth', type=str, help='Path to the checkpoint file')
272 | parser.add_argument('--output_size', default=128, type=int, help='Output image size')
273 | parser.add_argument('--gpu', default=False, action='store_true', help='Enable GPU')
274 | parser.add_argument('--detect_human_face', default=False, action='store_true', help='Enable automatic human face detection. This does not detect cat faces.')
275 | parser.add_argument('--render_video', default=False, action='store_true', help='Render 3D animations to video')
276 | args = parser.parse_args()
277 |
278 | input_dir = args.input
279 | result_dir = args.result
280 | model = Demo(args)
281 | im_list = [os.path.join(input_dir, f) for f in sorted(os.listdir(input_dir)) if is_image_file(f)]
282 |
283 | for im_path in im_list:
284 | print(f"Processing {im_path}")
285 | pil_im = Image.open(im_path).convert('RGB')
286 | result_code = model.run(pil_im)
287 | if result_code == -1:
288 | print(f"Failed! Skipping {im_path}")
289 | continue
290 |
291 | save_dir = os.path.join(result_dir, os.path.splitext(os.path.basename(im_path))[0])
292 | model.save_results(save_dir)
293 |
--------------------------------------------------------------------------------
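
The demo above is normally launched via `python -m demo.demo ...` as described in the README. The following sketch (not part of the repository) shows how the same `Demo` class might be driven programmatically from the repository root; it assumes the CelebA checkpoint has already been downloaded via `pretrained/download_pretrained_celeba.sh` and mirrors the argparse defaults in `__main__`.

```
from argparse import Namespace
from PIL import Image

from demo.demo import Demo  # assumes the repository root is on PYTHONPATH

# Mirror the CLI defaults from demo/demo.py's argparse setup.
args = Namespace(
    input='./demo/images/human_face',      # kept for parity with the CLI; unused by Demo itself
    result='./demo/results/human_face',
    checkpoint='./pretrained/pretrained_celeba/checkpoint030.pth',
    output_size=128,
    gpu=False,                # set True (with CUDA) if rendering videos
    detect_human_face=False,  # requires facenet-pytorch
    render_video=False,       # requires neural_renderer and a GPU
)

model = Demo(args)
pil_im = Image.open('./demo/images/human_face/001_face.png').convert('RGB')
if model.run(pil_im) != -1:   # run() returns -1 when face detection fails
    model.save_results('./demo/results/human_face/001_face')
```
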
/demo/images/cat_face/001_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/001_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/002_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/002_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/003_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/003_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/004_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/004_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/005_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/005_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/006_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/006_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/007_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/007_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/008_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/008_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/009_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/009_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/010_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/010_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/011_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/011_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/012_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/012_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/013_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/013_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/014_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/014_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/015_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/015_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/016_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/016_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/017_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/017_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/018_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/018_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/019_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/019_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/020_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/020_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/021_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/021_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/022_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/022_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/023_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/023_cat.png
--------------------------------------------------------------------------------
/demo/images/cat_face/024_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/024_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/025_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/025_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/026_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/026_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/027_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/027_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/028_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/028_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/029_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/029_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/030_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/030_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/031_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/031_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/032_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/032_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/033_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/033_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/034_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/034_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/035_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/035_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/036_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/036_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/037_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/037_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/038_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/038_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/039_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/039_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/040_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/040_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/041_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/041_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/042_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/042_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/043_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/043_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/044_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/044_abstract.png
--------------------------------------------------------------------------------
/demo/images/cat_face/045_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/cat_face/045_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/001_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/001_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/002_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/002_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/003_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/003_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/004_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/004_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/005_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/005_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/006_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/006_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/007_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/007_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/008_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/008_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/009_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/009_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/010_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/010_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/011_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/011_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/012_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/012_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/013_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/013_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/014_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/014_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/015_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/015_face.png
--------------------------------------------------------------------------------
/demo/images/human_face/016_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/016_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/017_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/017_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/018_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/018_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/019_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/019_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/020_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/020_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/021_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/021_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/022_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/022_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/023_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/023_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/024_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/024_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/025_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/025_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/026_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/026_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/027_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/027_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/028_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/028_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/029_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/029_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/030_paint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/030_paint.png
--------------------------------------------------------------------------------
/demo/images/human_face/031_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/031_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/032_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/032_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/033_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/033_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/034_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/034_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/035_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/035_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/036_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/036_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/037_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/037_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/038_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/038_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/039_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/039_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/040_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/040_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/041_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/041_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/042_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/042_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/043_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/043_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/044_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/044_abstract.png
--------------------------------------------------------------------------------
/demo/images/human_face/045_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/demo/images/human_face/045_abstract.png
--------------------------------------------------------------------------------
/demo/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class Encoder(nn.Module):
9 | def __init__(self, cin, cout, nf=64, activation=nn.Tanh):
10 | super(Encoder, self).__init__()
11 | network = [
12 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
13 | nn.ReLU(inplace=True),
14 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
19 | nn.ReLU(inplace=True),
20 | nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
21 | nn.ReLU(inplace=True),
22 | nn.Conv2d(nf*8, cout, kernel_size=1, stride=1, padding=0, bias=False)]
23 | if activation is not None:
24 | network += [activation()]
25 | self.network = nn.Sequential(*network)
26 |
27 | def forward(self, input):
28 | return self.network(input).reshape(input.size(0),-1)
29 |
30 |
31 | class EDDeconv(nn.Module):
32 | def __init__(self, cin, cout, zdim=128, nf=64, activation=nn.Tanh):
33 | super(EDDeconv, self).__init__()
34 | ## downsampling
35 | network = [
36 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
37 | nn.GroupNorm(16, nf),
38 | nn.LeakyReLU(0.2, inplace=True),
39 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
40 | nn.GroupNorm(16*2, nf*2),
41 | nn.LeakyReLU(0.2, inplace=True),
42 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
43 | nn.GroupNorm(16*4, nf*4),
44 | nn.LeakyReLU(0.2, inplace=True),
45 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
46 | nn.LeakyReLU(0.2, inplace=True),
47 | nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
48 | nn.ReLU(inplace=True)]
49 | ## upsampling
50 | network += [
51 | nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 1x1 -> 4x4
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=False),
54 | nn.ReLU(inplace=True),
55 | nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8
56 | nn.GroupNorm(16*4, nf*4),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=False),
59 | nn.GroupNorm(16*4, nf*4),
60 | nn.ReLU(inplace=True),
61 | nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16
62 | nn.GroupNorm(16*2, nf*2),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=False),
65 | nn.GroupNorm(16*2, nf*2),
66 | nn.ReLU(inplace=True),
67 | nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32
68 | nn.GroupNorm(16, nf),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False),
71 | nn.GroupNorm(16, nf),
72 | nn.ReLU(inplace=True),
73 | nn.Upsample(scale_factor=2, mode='nearest'), # 32x32 -> 64x64
74 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False),
75 | nn.GroupNorm(16, nf),
76 | nn.ReLU(inplace=True),
77 | nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=False),
78 | nn.GroupNorm(16, nf),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(nf, cout, kernel_size=5, stride=1, padding=2, bias=False)]
81 | if activation is not None:
82 | network += [activation()]
83 | self.network = nn.Sequential(*network)
84 |
85 | def forward(self, input):
86 | return self.network(input)
87 |
88 |
89 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
90 | def is_image_file(filename):
91 | return filename.lower().endswith(IMG_EXTENSIONS)
92 |
93 |
94 | def save_video(out_fold, frames, fname='image', ext='.mp4', cycle=False):
95 | os.makedirs(out_fold, exist_ok=True)
96 | frames = frames.detach().cpu().numpy().transpose(0,2,3,1) # TxCxHxW -> TxHxWxC
97 | if cycle:
98 | frames = np.concatenate([frames, frames[::-1]], 0)
99 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
100 | # fourcc = cv2.VideoWriter_fourcc(*'avc1')
101 | vid = cv2.VideoWriter(os.path.join(out_fold, fname+ext), fourcc, 25, (frames.shape[2], frames.shape[1]))
102 | for f in frames: vid.write(np.uint8(f[...,::-1]*255.))
103 | vid.release()
104 |
105 |
106 | def save_image(out_fold, img, fname='image', ext='.png'):
107 | os.makedirs(out_fold, exist_ok=True)
108 | img = img.detach().cpu().numpy().transpose(1,2,0)
109 | if 'depth' in fname:
110 | im_out = np.uint16(img*65535.)
111 | else:
112 | im_out = np.uint8(img*255.)
113 | cv2.imwrite(os.path.join(out_fold, fname+ext), im_out[:,:,::-1])
114 |
115 |
116 | def get_grid(b, H, W, normalize=True):
117 | if normalize:
118 | h_range = torch.linspace(-1,1,H)
119 | w_range = torch.linspace(-1,1,W)
120 | else:
121 | h_range = torch.arange(0,H)
122 | w_range = torch.arange(0,W)
123 | grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b,1,1,1).flip(3).float() # flip h,w to x,y
124 | return grid
125 |
126 |
127 | def export_to_obj_string(vertices, normal):
128 | b, h, w, _ = vertices.shape
129 | vertices[:,:,:,1:2] = -1*vertices[:,:,:,1:2] # flip y
130 | vertices[:,:,:,2:3] = 1-vertices[:,:,:,2:3] # flip and shift z
131 | vertices *= 100
132 | vertices_center = nn.functional.avg_pool2d(vertices.permute(0,3,1,2), 2, stride=1).permute(0,2,3,1)
133 | vertices = torch.cat([vertices.view(b,h*w,3), vertices_center.view(b,(h-1)*(w-1),3)], 1)
134 |
135 | vertice_textures = get_grid(b, h, w, normalize=True) # BxHxWx2
136 | vertice_textures[:,:,:,1:2] = -1*vertice_textures[:,:,:,1:2] # flip y
137 | vertice_textures_center = nn.functional.avg_pool2d(vertice_textures.permute(0,3,1,2), 2, stride=1).permute(0,2,3,1)
138 | vertice_textures = torch.cat([vertice_textures.view(b,h*w,2), vertice_textures_center.view(b,(h-1)*(w-1),2)], 1) /2+0.5 # Bx(H*W+(H-1)*(W-1))x2, [0,1]
139 |
140 | vertice_normals = normal.clone()
141 | vertice_normals[:,:,:,0:1] = -1*vertice_normals[:,:,:,0:1]
142 | vertice_normals_center = nn.functional.avg_pool2d(vertice_normals.permute(0,3,1,2), 2, stride=1).permute(0,2,3,1)
143 | vertice_normals_center = vertice_normals_center / (vertice_normals_center**2).sum(3, keepdim=True)**0.5
144 | vertice_normals = torch.cat([vertice_normals.view(b,h*w,3), vertice_normals_center.view(b,(h-1)*(w-1),3)], 1) # Bx(H*W+(H-1)*(W-1))x3
145 |
146 | idx_map = torch.arange(h*w).reshape(h,w)
147 | idx_map_center = torch.arange((h-1)*(w-1)).reshape(h-1,w-1)
148 | faces1 = torch.stack([idx_map[:h-1,:w-1], idx_map[1:,:w-1], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x3
149 | faces2 = torch.stack([idx_map[1:,:w-1], idx_map[1:,1:], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x3
150 | faces3 = torch.stack([idx_map[1:,1:], idx_map[:h-1,1:], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x3
151 | faces4 = torch.stack([idx_map[:h-1,1:], idx_map[:h-1,:w-1], idx_map_center+h*w], -1).reshape(-1,3).repeat(b,1,1).int() # Bx((H-1)*(W-1))x3
152 | faces = torch.cat([faces1, faces2, faces3, faces4], 1)
153 |
154 | objs = []
155 | mtls = []
156 | for bi in range(b):
157 | obj = "# OBJ File:"
158 | obj += "\n\nmtllib $MTLFILE"
159 | obj += "\n\n# vertices:"
160 | for v in vertices[bi]:
161 | obj += "\nv " + " ".join(["%.4f"%x for x in v])
162 | obj += "\n\n# vertice textures:"
163 | for vt in vertice_textures[bi]:
164 | obj += "\nvt " + " ".join(["%.4f"%x for x in vt])
165 | obj += "\n\n# vertice normals:"
166 | for vn in vertice_normals[bi]:
167 | obj += "\nvn " + " ".join(["%.4f"%x for x in vn])
168 | obj += "\n\n# faces:"
169 | obj += "\n\nusemtl tex"
170 | for f in faces[bi]:
171 | obj += "\nf " + " ".join(["%d/%d/%d"%(x+1,x+1,x+1) for x in f])
172 | objs += [obj]
173 |
174 | mtl = "newmtl tex"
175 | mtl += "\nKa 1.0000 1.0000 1.0000"
176 | mtl += "\nKd 1.0000 1.0000 1.0000"
177 | mtl += "\nKs 0.0000 0.0000 0.0000"
178 | mtl += "\nd 1.0"
179 | mtl += "\nillum 0"
180 | mtl += "\nmap_Kd $TXTFILE"
181 | mtls += [mtl]
182 | return objs, mtls
183 |
--------------------------------------------------------------------------------
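The two demo networks above can be shape-checked in isolation. A minimal sketch, assuming it is run from inside demo/ so that utils.py is importable directly (the packaged entry point is demo/demo.py); for a 64x64 input, Encoder flattens to a cout-dimensional code and EDDeconv returns a map of the same spatial size:

    import torch
    from utils import Encoder, EDDeconv   # demo/utils.py, when run from inside demo/

    x = torch.randn(2, 3, 64, 64)                    # a batch of two 64x64 RGB images
    enc = Encoder(cin=3, cout=6, nf=32)              # e.g. a viewpoint-style encoder
    dec = EDDeconv(cin=3, cout=1, zdim=128, nf=64)   # e.g. a depth-style encoder-decoder
    print(enc(x).shape)    # torch.Size([2, 6])
    print(dec(x).shape)    # torch.Size([2, 1, 64, 64])
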
/environment.yml:
--------------------------------------------------------------------------------
1 | name: unsup3d
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - bzip2=1.0.8=h516909a_2
8 | - ca-certificates=2020.4.5.1=hecc5488_0
9 | - cairo=1.16.0=h18b612c_1001
10 | - certifi=2020.4.5.1=py37hc8dfbb8_0
11 | - cffi=1.14.0=py37hd463f26_0
12 | - chardet=3.0.4=py37hc8dfbb8_1006
13 | - cloudpickle=1.3.0=py_0
14 | - cryptography=2.8=py37hb09aad4_2
15 | - cycler=0.10.0=py_2
16 | - cytoolz=0.10.1=py37h516909a_0
17 | - dask-core=2.14.0=py_0
18 | - dbus=1.13.6=he372182_0
19 | - decorator=4.4.2=py_0
20 | - expat=2.2.9=he1b5a44_2
21 | - ffmpeg=4.1.3=h167e202_0
22 | - fontconfig=2.13.1=he4413a7_1000
23 | - freetype=2.10.1=he06d7ca_0
24 | - gettext=0.19.8.1=hc5be6a0_1002
25 | - giflib=5.2.1=h516909a_2
26 | - glib=2.58.3=py37he00f558_1003
27 | - gmp=6.2.0=he1b5a44_2
28 | - gnutls=3.6.5=hd3a4fd2_1002
29 | - graphite2=1.3.13=he1b5a44_1001
30 | - gst-plugins-base=1.14.5=h0935bb2_2
31 | - gstreamer=1.14.5=h36ae1b5_2
32 | - harfbuzz=2.4.0=h37c48d4_1
33 | - hdf5=1.10.5=nompi_h3c11f04_1104
34 | - icu=58.2=hf484d3e_1000
35 | - idna=2.9=py_1
36 | - imageio=2.8.0=py_0
37 | - imageio-ffmpeg=0.4.1=py_0
38 | - jasper=1.900.1=h07fcdf6_1006
39 | - jpeg=9c=h14c3975_1001
40 | - kiwisolver=1.2.0=py37h99015e2_0
41 | - lame=3.100=h14c3975_1001
42 | - ld_impl_linux-64=2.33.1=h53a641e_7
43 | - libblas=3.8.0=14_openblas
44 | - libcblas=3.8.0=14_openblas
45 | - libedit=3.1.20181209=hc058e9b_0
46 | - libffi=3.2.1=hd88cf55_4
47 | - libgcc-ng=9.1.0=hdf63c60_0
48 | - libgfortran-ng=7.3.0=hdf63c60_5
49 | - libiconv=1.15=h516909a_1006
50 | - liblapack=3.8.0=14_openblas
51 | - liblapacke=3.8.0=14_openblas
52 | - libopenblas=0.3.7=h5ec1e0e_6
53 | - libpng=1.6.37=hed695b0_1
54 | - libprotobuf=3.11.4=h8b12597_0
55 | - libstdcxx-ng=9.1.0=hdf63c60_0
56 | - libtiff=4.1.0=hc3755c2_3
57 | - libuuid=2.32.1=h14c3975_1000
58 | - libwebp=1.0.2=h56121f0_5
59 | - libxcb=1.13=h14c3975_1002
60 | - libxml2=2.9.9=h13577e0_2
61 | - lz4-c=1.8.3=he1b5a44_1001
62 | - matplotlib=3.1.3=py37_0
63 | - matplotlib-base=3.1.3=py37hef1b27d_0
64 | - moviepy=1.0.1=py_0
65 | - ncurses=6.2=he6710b0_0
66 | - nettle=3.4.1=h1bed415_1002
67 | - networkx=2.4=py_1
68 | - numpy=1.18.1=py37h8960a57_1
69 | - olefile=0.46=py_0
70 | - opencv=4.1.1=py37ha799480_1
71 | - openh264=1.8.0=hdbcaa40_1000
72 | - openssl=1.1.1f=h516909a_0
73 | - pcre=8.44=he1b5a44_0
74 | - pillow=5.3.0=py37h00a061d_1000
75 | - pip=20.0.2=py37_1
76 | - pixman=0.38.0=h516909a_1003
77 | - proglog=0.1.9=py_0
78 | - protobuf=3.11.4=py37h3340039_1
79 | - pthread-stubs=0.4=h14c3975_1001
80 | - pycparser=2.20=py_0
81 | - pyopenssl=19.1.0=py_1
82 | - pyparsing=2.4.7=pyh9f0ad1d_0
83 | - pyqt=5.9.2=py37hcca6a23_4
84 | - pysocks=1.7.1=py37hc8dfbb8_1
85 | - python=3.7.7=hcf32534_0_cpython
86 | - python-dateutil=2.8.1=py_0
87 | - python_abi=3.7=1_cp37m
88 | - pywavelets=1.1.1=py37hc1659b7_0
89 | - pyyaml=5.3.1=py37h8f50634_0
90 | - qt=5.9.7=h52cfd70_2
91 | - readline=8.0=h7b6447c_0
92 | - requests=2.23.0=pyh8c360ce_2
93 | - scikit-image=0.16.2=py37hb3f55d8_0
94 | - scipy=1.4.1=py37ha3d9a3c_2
95 | - setuptools=46.1.3=py37_0
96 | - sip=4.19.8=py37hf484d3e_1000
97 | - six=1.14.0=py_1
98 | - sqlite=3.31.1=h7b6447c_0
99 | - tensorboardx=2.0=py_0
100 | - tk=8.6.8=hbc83047_0
101 | - toolz=0.10.0=py_0
102 | - tornado=6.0.4=py37h8f50634_1
103 | - tqdm=4.45.0=pyh9f0ad1d_0
104 | - urllib3=1.25.7=py37hc8dfbb8_1
105 | - wheel=0.34.2=py37_0
106 | - x264=1!152.20180806=h14c3975_0
107 | - xorg-kbproto=1.0.7=h14c3975_1002
108 | - xorg-libice=1.0.10=h516909a_0
109 | - xorg-libsm=1.2.3=h84519dc_1000
110 | - xorg-libx11=1.6.9=h516909a_0
111 | - xorg-libxau=1.0.9=h14c3975_0
112 | - xorg-libxdmcp=1.1.3=h516909a_0
113 | - xorg-libxext=1.3.4=h516909a_0
114 | - xorg-libxrender=0.9.10=h516909a_1002
115 | - xorg-renderproto=0.11.1=h14c3975_1002
116 | - xorg-xextproto=7.3.0=h14c3975_1002
117 | - xorg-xproto=7.0.31=h14c3975_1007
118 | - xz=5.2.4=h14c3975_4
119 | - yaml=0.2.2=h516909a_1
120 | - zlib=1.2.11=h7b6447c_3
121 | - zstd=1.4.4=h3b9ef0a_2
122 |
--------------------------------------------------------------------------------
/experiments/test_cat.yml:
--------------------------------------------------------------------------------
1 | ## test cat
2 | ## trainer
3 | run_test: true
4 | batch_size: 64
5 | checkpoint_dir: results/cat
6 | checkpoint_name: checkpoint100.pth
7 | test_result_dir: results/cat/test_results_checkpoint100
8 |
9 | ## dataloader
10 | num_workers: 4
11 | image_size: 64
12 | crop: 170
13 | load_gt_depth: false
14 | test_data_dir: data/cat_combined/test
15 |
16 | ## model
17 | model_name: unsup3d_cat
18 | min_depth: 0.9
19 | max_depth: 1.1
20 | xyz_rotation_range: 60 # (-r,r) in degrees
21 | xy_translation_range: 0.1 # (-t,t) in 3D
22 | z_translation_range: 0.1 # (-t,t) in 3D
23 | lam_perc: 1
24 | lam_flip: 0.5
25 |
26 | ## renderer
27 | rot_center_depth: 1.0
28 | fov: 10 # in degrees
29 | tex_cube_size: 2
30 |
--------------------------------------------------------------------------------
/experiments/test_celeba.yml:
--------------------------------------------------------------------------------
1 | ## test celeba
2 | ## trainer
3 | run_test: true
4 | batch_size: 64
5 | checkpoint_dir: results/celeba
6 | checkpoint_name: checkpoint030.pth
7 | test_result_dir: results/celeba/test_results_checkpoint030
8 |
9 | ## dataloader
10 | num_workers: 4
11 | image_size: 64
12 | load_gt_depth: false
13 | test_data_dir: data/celeba_cropped/test
14 |
15 | ## model
16 | model_name: unsup3d_celeba
17 | min_depth: 0.9
18 | max_depth: 1.1
19 | xyz_rotation_range: 60 # (-r,r) in degrees
20 | xy_translation_range: 0.1 # (-t,t) in 3D
21 | z_translation_range: 0 # (-t,t) in 3D
22 | lam_perc: 1
23 | lam_flip: 0.5
24 |
25 | ## renderer
26 | rot_center_depth: 1.0
27 | fov: 10 # in degrees
28 | tex_cube_size: 2
29 |
--------------------------------------------------------------------------------
/experiments/test_syncar.yml:
--------------------------------------------------------------------------------
1 | ## test syncar
2 | ## trainer
3 | run_test: true
4 | batch_size: 64
5 | checkpoint_dir: results/syncar
6 | checkpoint_name: checkpoint100.pth
7 | test_result_dir: results/syncar/test_results_checkpoint100
8 |
9 | ## dataloader
10 | num_workers: 4
11 | image_size: 64
12 | crop: [8, 14, 100, 100]
13 | load_gt_depth: false
14 | test_data_dir: data/syncar/test
15 |
16 | ## model
17 | model_name: unsup3d_syncar
18 | min_depth: 0.9
19 | max_depth: 1.1
20 | min_amb_light: 0.
21 | max_amb_light: 1.
22 | min_diff_light: 0.5
23 | max_diff_light: 1.
24 | xyz_rotation_range: 60 # (-r,r) in degrees
25 | xy_translation_range: 0.1 # (-t,t) in 3D
26 | z_translation_range: 0 # (-t,t) in 3D
27 | use_conf_map: false
28 | lam_perc: 0.01
29 | lam_flip: 1
30 | lam_depth_sm: 0.1
31 |
32 | ## renderer
33 | rot_center_depth: 1.0
34 | fov: 10 # in degrees
35 | tex_cube_size: 2
36 |
--------------------------------------------------------------------------------
/experiments/test_synface.yml:
--------------------------------------------------------------------------------
1 | ## test synface
2 | ## trainer
3 | run_test: true
4 | batch_size: 64
5 | checkpoint_dir: results/synface
6 | checkpoint_name: checkpoint030.pth
7 | test_result_dir: results/synface/test_results_checkpoint030
8 |
9 | ## dataloader
10 | num_workers: 4
11 | image_size: 64
12 | crop: 170
13 | load_gt_depth: true
14 | paired_data_dir_names: ['image', 'depth']
15 | paired_data_filename_diff: ['image', 'depth']
16 | test_data_dir: data/synface/test
17 |
18 | ## model
19 | model_name: unsup3d_synface
20 | min_depth: 0.9
21 | max_depth: 1.1
22 | xyz_rotation_range: 60 # (-r,r) in degrees
23 | xy_translation_range: 0.1 # (-t,t) in 3D
24 | z_translation_range: 0 # (-t,t) in 3D
25 | lam_perc: 1
26 | lam_flip: 0.5
27 |
28 | ## renderer
29 | rot_center_depth: 1.0
30 | fov: 10 # in degrees
31 | tex_cube_size: 2
32 |
--------------------------------------------------------------------------------
/experiments/train_cat.yml:
--------------------------------------------------------------------------------
1 | ## train cat
2 | ## trainer
3 | run_train: true
4 | num_epochs: 200
5 | batch_size: 64
6 | checkpoint_dir: results/cat
7 | save_checkpoint_freq: 10
8 | keep_num_checkpoint: 2
9 | resume: true
10 | use_logger: true
11 | log_freq: 500
12 |
13 | ## dataloader
14 | num_workers: 4
15 | image_size: 64
16 | crop: 170
17 | load_gt_depth: false
18 | train_val_data_dir: data/cat_combined
19 |
20 | ## model
21 | model_name: unsup3d_cat
22 | min_depth: 0.9
23 | max_depth: 1.1
24 | xyz_rotation_range: 60 # (-r,r) in degrees
25 | xy_translation_range: 0.1 # (-t,t) in 3D
26 | z_translation_range: 0.1 # (-t,t) in 3D
27 | lam_perc: 1
28 | lam_flip: 0.5
29 | lam_flip_start_epoch: 10
30 | lr: 0.0001
31 |
32 | ## renderer
33 | rot_center_depth: 1.0
34 | fov: 10 # in degrees
35 | tex_cube_size: 2
36 |
--------------------------------------------------------------------------------
/experiments/train_celeba.yml:
--------------------------------------------------------------------------------
1 | ## train celeba
2 | ## trainer
3 | run_train: true
4 | num_epochs: 30
5 | batch_size: 64
6 | checkpoint_dir: results/celeba
7 | save_checkpoint_freq: 1
8 | keep_num_checkpoint: 2
9 | resume: true
10 | use_logger: true
11 | log_freq: 500
12 |
13 | ## dataloader
14 | num_workers: 4
15 | image_size: 64
16 | load_gt_depth: false
17 | train_val_data_dir: data/celeba_cropped
18 |
19 | ## model
20 | model_name: unsup3d_celeba
21 | min_depth: 0.9
22 | max_depth: 1.1
23 | xyz_rotation_range: 60 # (-r,r) in degrees
24 | xy_translation_range: 0.1 # (-t,t) in 3D
25 | z_translation_range: 0 # (-t,t) in 3D
26 | lam_perc: 1
27 | lam_flip: 0.5
28 | lr: 0.0001
29 |
30 | ## renderer
31 | rot_center_depth: 1.0
32 | fov: 10 # in degrees
33 | tex_cube_size: 2
34 |
--------------------------------------------------------------------------------
/experiments/train_syncar.yml:
--------------------------------------------------------------------------------
1 | ## train syncar
2 | ## trainer
3 | run_train: true
4 | num_epochs: 100
5 | batch_size: 64
6 | checkpoint_dir: results/syncar
7 | save_checkpoint_freq: 10
8 | keep_num_checkpoint: 2
9 | resume: true
10 | use_logger: true
11 | log_freq: 500
12 |
13 | ## dataloader
14 | num_workers: 4
15 | image_size: 64
16 | crop: [8, 14, 100, 100]
17 | load_gt_depth: false
18 | train_val_data_dir: data/syncar
19 |
20 | ## model
21 | model_name: unsup3d_syncar
22 | min_depth: 0.9
23 | max_depth: 1.1
24 | border_depth: 1.08
25 | min_amb_light: 0.
26 | max_amb_light: 1.
27 | min_diff_light: 0.5
28 | max_diff_light: 1.
29 | xyz_rotation_range: 60 # (-r,r) in degrees
30 | xy_translation_range: 0.1 # (-t,t) in 3D
31 | z_translation_range: 0 # (-t,t) in 3D
32 | use_conf_map: false
33 | lam_perc: 0.01
34 | lam_flip: 1
35 | lam_depth_sm: 0.1
36 | lr: 0.0001
37 |
38 | ## renderer
39 | rot_center_depth: 1.0
40 | fov: 10 # in degrees
41 | tex_cube_size: 2
42 |
--------------------------------------------------------------------------------
/experiments/train_synface.yml:
--------------------------------------------------------------------------------
1 | ## train synface
2 | ## trainer
3 | run_train: true
4 | num_epochs: 30
5 | batch_size: 64
6 | checkpoint_dir: results/synface
7 | save_checkpoint_freq: 1
8 | keep_num_checkpoint: 2
9 | resume: true
10 | use_logger: true
11 | log_freq: 500
12 |
13 | ## dataloader
14 | num_workers: 4
15 | image_size: 64
16 | crop: 170
17 | load_gt_depth: true
18 | paired_data_dir_names: ['image', 'depth']
19 | paired_data_filename_diff: ['image', 'depth']
20 | train_val_data_dir: data/synface
21 |
22 | ## model
23 | model_name: unsup3d_synface
24 | min_depth: 0.9
25 | max_depth: 1.1
26 | xyz_rotation_range: 60 # (-r,r) in degrees
27 | xy_translation_range: 0.1 # (-t,t) in 3D
28 | z_translation_range: 0 # (-t,t) in 3D
29 | lam_perc: 1
30 | lam_flip: 0.5
31 | lr: 0.0001
32 |
33 | ## renderer
34 | rot_center_depth: 1.0
35 | fov: 10 # in degrees
36 | tex_cube_size: 2
37 |
--------------------------------------------------------------------------------
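The experiment files above are flat YAML dictionaries; every key is consumed through cfgs.get(key, default) by the trainer, dataloaders, model and renderer (see unsup3d/dataloaders.py and unsup3d/model.py below), so missing keys silently fall back to the defaults coded there. A minimal sketch of inspecting one file with PyYAML, which is pinned in environment.yml (the project's own setup_runtime additionally merges CLI arguments):

    import yaml

    with open('experiments/train_celeba.yml') as f:
        cfgs = yaml.safe_load(f)           # a plain dict of config keys

    print(cfgs['model_name'])              # unsup3d_celeba
    print(cfgs.get('batch_size', 64))      # 64
    print(cfgs.get('crop', None))          # None: the celeba images are pre-cropped
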
/img/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/unsup3d/dc961410d61684561f19525c2f7e9ee6f4dacb91/img/teaser.jpg
--------------------------------------------------------------------------------
/pretrained/download_pretrained_cat.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading pretrained model on cat dataset -----------------------"
2 | curl -o pretrained_cat.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_cat.zip" && unzip pretrained_cat.zip
3 |
--------------------------------------------------------------------------------
/pretrained/download_pretrained_celeba.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading pretrained model on celeba face dataset -----------------------"
2 | curl -o pretrained_celeba.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_celeba.zip" && unzip pretrained_celeba.zip
3 |
--------------------------------------------------------------------------------
/pretrained/download_pretrained_syncar.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading pretrained model on synthetic car dataset -----------------------"
2 | curl -o pretrained_syncar.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_syncar.zip" && unzip pretrained_syncar.zip
3 |
--------------------------------------------------------------------------------
/pretrained/download_pretrained_synface.sh:
--------------------------------------------------------------------------------
1 | echo "----------------------- downloading pretrained model on synthetic face dataset -----------------------"
2 | curl -o pretrained_synface.zip "https://www.robots.ox.ac.uk/~vgg/research/unsup3d/data/pretrained_synface.zip" && unzip pretrained_synface.zip
3 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from unsup3d import setup_runtime, Trainer, Unsup3D
4 |
5 |
6 | ## runtime arguments
7 | parser = argparse.ArgumentParser(description='Training configurations.')
8 | parser.add_argument('--config', default=None, type=str, help='Specify a config file path')
9 | parser.add_argument('--gpu', default=None, type=int, help='Specify a GPU device')
10 | parser.add_argument('--num_workers', default=4, type=int, help='Specify the number of worker threads for data loaders')
11 | parser.add_argument('--seed', default=0, type=int, help='Specify a random seed')
12 | args = parser.parse_args()
13 |
14 | ## set up
15 | cfgs = setup_runtime(args)
16 | trainer = Trainer(cfgs, Unsup3D)
17 | run_train = cfgs.get('run_train', False)
18 | run_test = cfgs.get('run_test', False)
19 |
20 | ## run
21 | if run_train:
22 | trainer.train()
23 | if run_test:
24 | trainer.test()
25 |
--------------------------------------------------------------------------------
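The script above is driven entirely by a config file plus a few CLI flags (e.g. python run.py --config experiments/train_celeba.yml --gpu 0). A hypothetical programmatic equivalent that bypasses argparse, assuming setup_runtime accepts any object exposing the same four attributes the parser defines:

    import argparse
    from unsup3d import setup_runtime, Trainer, Unsup3D

    args = argparse.Namespace(config='experiments/train_celeba.yml',
                              gpu=0, num_workers=4, seed=0)
    cfgs = setup_runtime(args)             # merges the YAML config with the runtime arguments
    trainer = Trainer(cfgs, Unsup3D)
    trainer.train()                        # or trainer.test() with a test_*.yml config
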
/unsup3d/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import setup_runtime
2 | from .trainer import Trainer
3 | from .model import Unsup3D
4 |
--------------------------------------------------------------------------------
/unsup3d/dataloaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision.transforms as tfs
3 | import torch.utils.data
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def get_data_loaders(cfgs):
9 | batch_size = cfgs.get('batch_size', 64)
10 | num_workers = cfgs.get('num_workers', 4)
11 | image_size = cfgs.get('image_size', 64)
12 | crop = cfgs.get('crop', None)
13 |
14 | run_train = cfgs.get('run_train', False)
15 | train_val_data_dir = cfgs.get('train_val_data_dir', './data')
16 | run_test = cfgs.get('run_test', False)
17 | test_data_dir = cfgs.get('test_data_dir', './data/test')
18 |
19 | load_gt_depth = cfgs.get('load_gt_depth', False)
20 | AB_dnames = cfgs.get('paired_data_dir_names', ['A', 'B'])
21 | AB_fnames = cfgs.get('paired_data_filename_diff', None)
22 |
23 | train_loader = val_loader = test_loader = None
24 | if load_gt_depth:
25 | get_loader = lambda **kargs: get_paired_image_loader(**kargs, batch_size=batch_size, num_workers=num_workers, image_size=image_size, crop=crop, AB_dnames=AB_dnames, AB_fnames=AB_fnames)
26 | else:
27 | get_loader = lambda **kargs: get_image_loader(**kargs, batch_size=batch_size, num_workers=num_workers, image_size=image_size, crop=crop)
28 |
29 | if run_train:
30 | train_data_dir = os.path.join(train_val_data_dir, "train")
31 | val_data_dir = os.path.join(train_val_data_dir, "val")
32 | assert os.path.isdir(train_data_dir), "Training data directory does not exist: %s" %train_data_dir
33 | assert os.path.isdir(val_data_dir), "Validation data directory does not exist: %s" %val_data_dir
34 | print(f"Loading training data from {train_data_dir}")
35 | train_loader = get_loader(data_dir=train_data_dir, is_validation=False)
36 | print(f"Loading validation data from {val_data_dir}")
37 | val_loader = get_loader(data_dir=val_data_dir, is_validation=True)
38 | if run_test:
39 | assert os.path.isdir(test_data_dir), "Testing data directory does not exist: %s" %test_data_dir
40 | print(f"Loading testing data from {test_data_dir}")
41 | test_loader = get_loader(data_dir=test_data_dir, is_validation=True)
42 |
43 | return train_loader, val_loader, test_loader
44 |
45 |
46 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
47 | def is_image_file(filename):
48 | return filename.lower().endswith(IMG_EXTENSIONS)
49 |
50 |
51 | ## simple image dataset ##
52 | def make_dataset(dir):
53 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
54 |
55 | images = []
56 | for root, _, fnames in sorted(os.walk(dir)):
57 | for fname in sorted(fnames):
58 | if is_image_file(fname):
59 | fpath = os.path.join(root, fname)
60 | images.append(fpath)
61 | return images
62 |
63 |
64 | class ImageDataset(torch.utils.data.Dataset):
65 | def __init__(self, data_dir, image_size=256, crop=None, is_validation=False):
66 | super(ImageDataset, self).__init__()
67 | self.root = data_dir
68 | self.paths = make_dataset(data_dir)
69 | self.size = len(self.paths)
70 | self.image_size = image_size
71 | self.crop = crop
72 | self.is_validation = is_validation
73 |
74 | def transform(self, img, hflip=False):
75 | if self.crop is not None:
76 | if isinstance(self.crop, int):
77 | img = tfs.CenterCrop(self.crop)(img)
78 | else:
79 | assert len(self.crop) == 4, 'Crop size must be an integer for center crop, or a list of 4 integers (y0,x0,h,w)'
80 | img = tfs.functional.crop(img, *self.crop)
81 | img = tfs.functional.resize(img, (self.image_size, self.image_size))
82 | if hflip:
83 | img = tfs.functional.hflip(img)
84 | return tfs.functional.to_tensor(img)
85 |
86 | def __getitem__(self, index):
87 | fpath = self.paths[index % self.size]
88 | img = Image.open(fpath).convert('RGB')
89 | hflip = not self.is_validation and np.random.rand()>0.5
90 | return self.transform(img, hflip=hflip)
91 |
92 | def __len__(self):
93 | return self.size
94 |
95 | def name(self):
96 | return 'ImageDataset'
97 |
98 |
99 | def get_image_loader(data_dir, is_validation=False,
100 | batch_size=256, num_workers=4, image_size=256, crop=None):
101 |
102 | dataset = ImageDataset(data_dir, image_size=image_size, crop=crop, is_validation=is_validation)
103 | loader = torch.utils.data.DataLoader(
104 | dataset,
105 | batch_size=batch_size,
106 | shuffle=not is_validation,
107 | num_workers=num_workers,
108 | pin_memory=True
109 | )
110 | return loader
111 |
112 |
113 | ## paired AB image dataset ##
114 | def make_paired_dataset(dir, AB_dnames=None, AB_fnames=None):
115 | A_dname, B_dname = AB_dnames or ('A', 'B')
116 | dir_A = os.path.join(dir, A_dname)
117 | dir_B = os.path.join(dir, B_dname)
118 | assert os.path.isdir(dir_A), '%s is not a valid directory' % dir_A
119 | assert os.path.isdir(dir_B), '%s is not a valid directory' % dir_B
120 |
121 | images = []
122 | for root_A, _, fnames_A in sorted(os.walk(dir_A)):
123 | for fname_A in sorted(fnames_A):
124 | if is_image_file(fname_A):
125 | path_A = os.path.join(root_A, fname_A)
126 | root_B = root_A.replace(dir_A, dir_B, 1)
127 | if AB_fnames is not None:
128 | fname_B = fname_A.replace(*AB_fnames)
129 | else:
130 | fname_B = fname_A
131 | path_B = os.path.join(root_B, fname_B)
132 | if os.path.isfile(path_B):
133 | images.append((path_A, path_B))
134 | return images
135 |
136 |
137 | class PairedDataset(torch.utils.data.Dataset):
138 | def __init__(self, data_dir, image_size=256, crop=None, is_validation=False, AB_dnames=None, AB_fnames=None):
139 | super(PairedDataset, self).__init__()
140 | self.root = data_dir
141 | self.paths = make_paired_dataset(data_dir, AB_dnames=AB_dnames, AB_fnames=AB_fnames)
142 | self.size = len(self.paths)
143 | self.image_size = image_size
144 | self.crop = crop
145 | self.is_validation = is_validation
146 |
147 | def transform(self, img, hflip=False):
148 | if self.crop is not None:
149 | if isinstance(self.crop, int):
150 | img = tfs.CenterCrop(self.crop)(img)
151 | else:
152 | assert len(self.crop) == 4, 'Crop size must be an integer for center crop, or a list of 4 integers (y0,x0,h,w)'
153 | img = tfs.functional.crop(img, *self.crop)
154 | img = tfs.functional.resize(img, (self.image_size, self.image_size))
155 | if hflip:
156 | img = tfs.functional.hflip(img)
157 | return tfs.functional.to_tensor(img)
158 |
159 | def __getitem__(self, index):
160 | path_A, path_B = self.paths[index % self.size]
161 | img_A = Image.open(path_A).convert('RGB')
162 | img_B = Image.open(path_B).convert('RGB')
163 | hflip = not self.is_validation and np.random.rand()>0.5
164 | return self.transform(img_A, hflip=hflip), self.transform(img_B, hflip=hflip)
165 |
166 | def __len__(self):
167 | return self.size
168 |
169 | def name(self):
170 | return 'PairedDataset'
171 |
172 |
173 | def get_paired_image_loader(data_dir, is_validation=False,
174 | batch_size=256, num_workers=4, image_size=256, crop=None, AB_dnames=None, AB_fnames=None):
175 |
176 | dataset = PairedDataset(data_dir, image_size=image_size, crop=crop, \
177 | is_validation=is_validation, AB_dnames=AB_dnames, AB_fnames=AB_fnames)
178 | loader = torch.utils.data.DataLoader(
179 | dataset,
180 | batch_size=batch_size,
181 | shuffle=not is_validation,
182 | num_workers=num_workers,
183 | pin_memory=True
184 | )
185 | return loader
186 |
--------------------------------------------------------------------------------
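The loaders above can also be used directly, outside get_data_loaders. A sketch assuming data/celeba_cropped/train and data/synface/test exist as produced by the download scripts; images come back as (B, 3, image_size, image_size) tensors in [0, 1], with random horizontal flips applied only when is_validation is False:

    from unsup3d.dataloaders import get_image_loader, get_paired_image_loader

    loader = get_image_loader('data/celeba_cropped/train', is_validation=False,
                              batch_size=16, num_workers=4, image_size=64)
    images = next(iter(loader))            # torch.Size([16, 3, 64, 64])

    # paired image/depth loading, as configured for the synthetic face data
    paired = get_paired_image_loader('data/synface/test', is_validation=True,
                                     batch_size=16, image_size=64, crop=170,
                                     AB_dnames=['image', 'depth'],
                                     AB_fnames=['image', 'depth'])
    image_batch, depth_batch = next(iter(paired))
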
/unsup3d/meters.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | import torch
5 | import operator
6 | from functools import reduce
7 | import matplotlib.pyplot as plt
8 | import collections.abc
9 | from .utils import xmkdir
10 |
11 |
12 | class TotalAverage():
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.last_value = 0.
18 | self.mass = 0.
19 | self.sum = 0.
20 |
21 | def update(self, value, mass=1):
22 | self.last_value = value
23 | self.mass += mass
24 | self.sum += value * mass
25 |
26 | def get(self):
27 | return self.sum / self.mass
28 |
29 | class MovingAverage():
30 | def __init__(self, inertia=0.9):
31 | self.inertia = inertia
32 | self.reset()
33 | self.last_value = None
34 |
35 | def reset(self):
36 | self.last_value = None
37 | self.average = None
38 |
39 | def update(self, value, mass=1):
40 | self.last_value = value
41 | if self.average is None:
42 | self.average = value
43 | else:
44 | self.average = self.inertia * self.average + (1 - self.inertia) * value
45 |
46 | def get(self):
47 | return self.average
48 |
49 | class MetricsTrace():
50 | def __init__(self):
51 | self.reset()
52 |
53 | def reset(self):
54 | self.data = {}
55 |
56 | def append(self, dataset, metric):
57 | if dataset not in self.data:
58 | self.data[dataset] = []
59 | self.data[dataset].append(metric.get_data_dict())
60 |
61 | def load(self, path):
62 | """Load the metrics trace from the specified JSON file."""
63 | with open(path, 'r') as f:
64 | self.data = json.load(f)
65 |
66 | def save(self, path):
67 | """Save the metrics trace to the specified JSON file."""
68 | if path is None:
69 | return
70 | xmkdir(os.path.dirname(path))
71 | with open(path, 'w') as f:
72 | json.dump(self.data, f, indent=2)
73 |
74 | def plot(self, pdf_path=None):
75 | """Plots and optionally save as PDF the metrics trace."""
76 | plot_metrics(self.data, pdf_path=pdf_path)
77 |
78 | def get(self):
79 | return self.data
80 |
81 | def __str__(self):
82 | return str(self.data)
83 |
84 | class Metrics():
85 | def __init__(self):
86 | self.iteration_time = MovingAverage(inertia=0.9)
87 | self.now = time.time()
88 |
89 | def update(self, prediction=None, ground_truth=None):
90 | self.iteration_time.update(time.time() - self.now)
91 | self.now = time.time()
92 |
93 | def get_data_dict(self):
94 | return {"objective" : self.objective.get(), "iteration_time" : self.iteration_time.get()}
95 |
96 | class StandardMetrics(Metrics):
97 | def __init__(self, m=None):
98 | super(StandardMetrics, self).__init__()
99 | self.metrics = m or {}
100 | self.speed = MovingAverage(inertia=0.9)
101 |
102 | def update(self, metric_dict, mass=1):
103 | super(StandardMetrics, self).update()
104 | for metric, val in metric_dict.items():
105 | if torch.is_tensor(val):
106 | val = val.item()
107 | if metric not in self.metrics:
108 | self.metrics[metric] = TotalAverage()
109 | self.metrics[metric].update(val, mass)
110 | self.speed.update(mass / self.iteration_time.last_value)
111 |
112 | def get_data_dict(self):
113 | data_dict = {k: v.get() for k,v in self.metrics.items()}
114 | data_dict['speed'] = self.speed.get()
115 | return data_dict
116 |
117 | def __str__(self):
118 | pstr = '%7.1fHz\t' %self.speed.get()
119 | pstr += '\t'.join(['%s: %6.5f' %(k,v.get()) for k,v in self.metrics.items()])
120 | return pstr
121 |
122 | def plot_metrics(stats, pdf_path=None, fig=1, datasets=None, metrics=None):
123 | """Plot metrics. `stats` should be a dictionary of type
124 |
125 | stats[dataset][t][metric][i]
126 |
127 | where dataset is the dataset name (e.g. `train` or `val`), t is an iteration number,
128 | metric is the name of a metric (e.g. `loss` or `top1`), and i is a loss dimension.
129 |
130 | Alternatively, if a loss has a single dimension, `stats[dataset][t][metric]` can
131 | be a scalar.
132 |
133 | The supported options are:
134 |
135 | - pdf_path: path to a PDF file to store the figure (default: None)
136 | - fig: MatPlotLib figure index (default: 1)
137 | - datasets: list of dataset names to plot (default: None)
138 | - metrics: list of metrics to plot (default: None)
139 | """
140 | plt.figure(fig)
141 | plt.clf()
142 | linestyles = ['-', '--', '-.', ':']
143 | datasets = list(stats.keys()) if datasets is None else datasets
144 | # Filter out empty datasets
145 | datasets = [d for d in datasets if len(stats[d]) > 0]
146 | duration = len(stats[datasets[0]])
147 | metrics = list(stats[datasets[0]][0].keys()) if metrics is None else metrics
148 | for m, metric in enumerate(metrics):
149 | plt.subplot(len(metrics),1,m+1)
150 | legend_content = []
151 | for d, dataset in enumerate(datasets):
152 | ls = linestyles[d % len(linestyles)]
153 | if isinstance(stats[dataset][0][metric], collections.abc.Iterable):
154 | metric_dimension = len(stats[dataset][0][metric])
155 | for sl in range(metric_dimension):
156 | x = [stats[dataset][t][metric][sl] for t in range(duration)]
157 | plt.plot(x, linestyle=ls)
158 | name = f'{dataset} {metric}[{sl}]'
159 | legend_content.append(name)
160 | else:
161 | x = [stats[dataset][t][metric] for t in range(duration)]
162 | plt.plot(x, linestyle=ls)
163 | name = f'{dataset} {metric}'
164 | legend_content.append(name)
165 | plt.legend(legend_content, loc=(1.04,0))
166 | plt.grid(True)
167 | if pdf_path is not None:
168 | plt.savefig(pdf_path, format='pdf', bbox_inches='tight')
169 | plt.draw()
170 | plt.pause(0.0001)
171 |
--------------------------------------------------------------------------------
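The meters above are small self-contained accumulators: TotalAverage keeps a mass-weighted mean, MovingAverage an exponential moving average with the given inertia, and StandardMetrics wraps a dict of TotalAverages plus an iteration-speed estimate. A quick sketch of the two averaging helpers:

    from unsup3d.meters import TotalAverage, MovingAverage

    avg = TotalAverage()
    avg.update(0.5, mass=16)               # e.g. a batch loss, weighted by batch size
    avg.update(0.3, mass=16)
    print(avg.get())                       # 0.4

    ema = MovingAverage(inertia=0.9)
    ema.update(1.0)
    ema.update(0.0)
    print(ema.get())                       # 0.9 = 0.9*1.0 + 0.1*0.0
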
/unsup3d/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import glob
4 | import torch
5 | import torch.nn as nn
6 | import torchvision
7 | from . import networks
8 | from . import utils
9 | from .renderer import Renderer
10 |
11 |
12 | EPS = 1e-7
13 |
14 |
15 | class Unsup3D():
16 | def __init__(self, cfgs):
17 | self.model_name = cfgs.get('model_name', self.__class__.__name__)
18 | self.device = cfgs.get('device', 'cpu')
19 | self.image_size = cfgs.get('image_size', 64)
20 | self.min_depth = cfgs.get('min_depth', 0.9)
21 | self.max_depth = cfgs.get('max_depth', 1.1)
22 | self.border_depth = cfgs.get('border_depth', (0.7*self.max_depth + 0.3*self.min_depth))
23 | self.min_amb_light = cfgs.get('min_amb_light', 0.)
24 | self.max_amb_light = cfgs.get('max_amb_light', 1.)
25 | self.min_diff_light = cfgs.get('min_diff_light', 0.)
26 | self.max_diff_light = cfgs.get('max_diff_light', 1.)
27 | self.xyz_rotation_range = cfgs.get('xyz_rotation_range', 60)
28 | self.xy_translation_range = cfgs.get('xy_translation_range', 0.1)
29 | self.z_translation_range = cfgs.get('z_translation_range', 0.1)
30 | self.use_conf_map = cfgs.get('use_conf_map', True)
31 | self.lam_perc = cfgs.get('lam_perc', 1)
32 | self.lam_flip = cfgs.get('lam_flip', 0.5)
33 | self.lam_flip_start_epoch = cfgs.get('lam_flip_start_epoch', 0)
34 | self.lam_depth_sm = cfgs.get('lam_depth_sm', 0)
35 | self.lr = cfgs.get('lr', 1e-4)
36 | self.load_gt_depth = cfgs.get('load_gt_depth', False)
37 | self.renderer = Renderer(cfgs)
38 |
39 | ## networks and optimizers
40 | self.netD = networks.EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None)
41 | self.netA = networks.EDDeconv(cin=3, cout=3, nf=64, zdim=256)
42 | self.netL = networks.Encoder(cin=3, cout=4, nf=32)
43 | self.netV = networks.Encoder(cin=3, cout=6, nf=32)
44 | if self.use_conf_map:
45 | self.netC = networks.ConfNet(cin=3, cout=2, nf=64, zdim=128)
46 | self.network_names = [k for k in vars(self) if 'net' in k]
47 | self.make_optimizer = lambda model: torch.optim.Adam(
48 | filter(lambda p: p.requires_grad, model.parameters()),
49 | lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
50 |
51 | ## other parameters
52 | self.PerceptualLoss = networks.PerceptualLoss(requires_grad=False)
53 | self.other_param_names = ['PerceptualLoss']
54 |
55 | ## depth rescaler: -1~1 -> min_deph~max_deph
56 | self.depth_rescaler = lambda d : (1+d)/2 *self.max_depth + (1-d)/2 *self.min_depth
57 | self.amb_light_rescaler = lambda x : (1+x)/2 *self.max_amb_light + (1-x)/2 *self.min_amb_light
58 | self.diff_light_rescaler = lambda x : (1+x)/2 *self.max_diff_light + (1-x)/2 *self.min_diff_light
59 |
60 | def init_optimizers(self):
61 | self.optimizer_names = []
62 | for net_name in self.network_names:
63 | optimizer = self.make_optimizer(getattr(self, net_name))
64 | optim_name = net_name.replace('net','optimizer')
65 | setattr(self, optim_name, optimizer)
66 | self.optimizer_names += [optim_name]
67 |
68 | def load_model_state(self, cp):
69 | for k in cp:
70 | if k and k in self.network_names:
71 | getattr(self, k).load_state_dict(cp[k])
72 |
73 | def load_optimizer_state(self, cp):
74 | for k in cp:
75 | if k and k in self.optimizer_names:
76 | getattr(self, k).load_state_dict(cp[k])
77 |
78 | def get_model_state(self):
79 | states = {}
80 | for net_name in self.network_names:
81 | states[net_name] = getattr(self, net_name).state_dict()
82 | return states
83 |
84 | def get_optimizer_state(self):
85 | states = {}
86 | for optim_name in self.optimizer_names:
87 | states[optim_name] = getattr(self, optim_name).state_dict()
88 | return states
89 |
90 | def to_device(self, device):
91 | self.device = device
92 | for net_name in self.network_names:
93 | setattr(self, net_name, getattr(self, net_name).to(device))
94 | if self.other_param_names:
95 | for param_name in self.other_param_names:
96 | setattr(self, param_name, getattr(self, param_name).to(device))
97 |
98 | def set_train(self):
99 | for net_name in self.network_names:
100 | getattr(self, net_name).train()
101 |
102 | def set_eval(self):
103 | for net_name in self.network_names:
104 | getattr(self, net_name).eval()
105 |
106 | def photometric_loss(self, im1, im2, mask=None, conf_sigma=None):
107 | loss = (im1-im2).abs()
108 | if conf_sigma is not None:
109 | loss = loss *2**0.5 / (conf_sigma +EPS) + (conf_sigma +EPS).log()
110 | if mask is not None:
111 | mask = mask.expand_as(loss)
112 | loss = (loss * mask).sum() / mask.sum()
113 | else:
114 | loss = loss.mean()
115 | return loss
116 |
117 | def backward(self):
118 | for optim_name in self.optimizer_names:
119 | getattr(self, optim_name).zero_grad()
120 | self.loss_total.backward()
121 | for optim_name in self.optimizer_names:
122 | getattr(self, optim_name).step()
123 |
124 | def forward(self, input):
125 | """Feedforward once."""
126 | if self.load_gt_depth:
127 | input, depth_gt = input
128 | self.input_im = input.to(self.device) *2.-1.
129 | b, c, h, w = self.input_im.shape
130 |
131 | ## predict canonical depth
132 | self.canon_depth_raw = self.netD(self.input_im).squeeze(1) # BxHxW
133 | self.canon_depth = self.canon_depth_raw - self.canon_depth_raw.view(b,-1).mean(1).view(b,1,1)
134 | self.canon_depth = self.canon_depth.tanh()
135 | self.canon_depth = self.depth_rescaler(self.canon_depth)
136 |
137 | ## optional depth smoothness loss (only used in synthetic car experiments)
138 | self.loss_depth_sm = ((self.canon_depth[:,:-1,:] - self.canon_depth[:,1:,:]) /(self.max_depth-self.min_depth)).abs().mean()
139 | self.loss_depth_sm += ((self.canon_depth[:,:,:-1] - self.canon_depth[:,:,1:]) /(self.max_depth-self.min_depth)).abs().mean()
140 |
141 | ## clamp border depth
142 | depth_border = torch.zeros(1,h,w-4).to(self.input_im.device)
143 | depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1)
144 | self.canon_depth = self.canon_depth*(1-depth_border) + depth_border *self.border_depth
145 | self.canon_depth = torch.cat([self.canon_depth, self.canon_depth.flip(2)], 0) # flip
146 |
147 | ## predict canonical albedo
148 | self.canon_albedo = self.netA(self.input_im) # Bx3xHxW
149 | self.canon_albedo = torch.cat([self.canon_albedo, self.canon_albedo.flip(3)], 0) # flip
150 |
151 | ## predict confidence map
152 | if self.use_conf_map:
153 | conf_sigma_l1, conf_sigma_percl = self.netC(self.input_im) # Bx2xHxW
154 | self.conf_sigma_l1 = conf_sigma_l1[:,:1]
155 | self.conf_sigma_l1_flip = conf_sigma_l1[:,1:]
156 | self.conf_sigma_percl = conf_sigma_percl[:,:1]
157 | self.conf_sigma_percl_flip = conf_sigma_percl[:,1:]
158 | else:
159 | self.conf_sigma_l1 = None
160 | self.conf_sigma_l1_flip = None
161 | self.conf_sigma_percl = None
162 | self.conf_sigma_percl_flip = None
163 |
164 | ## predict lighting
165 | canon_light = self.netL(self.input_im).repeat(2,1) # Bx4
166 | self.canon_light_a = self.amb_light_rescaler(canon_light[:,:1]) # ambience term
167 | self.canon_light_b = self.diff_light_rescaler(canon_light[:,1:2]) # diffuse term
168 | canon_light_dxy = canon_light[:,2:]
169 | self.canon_light_d = torch.cat([canon_light_dxy, torch.ones(b*2,1).to(self.input_im.device)], 1)
170 | self.canon_light_d = self.canon_light_d / ((self.canon_light_d**2).sum(1, keepdim=True))**0.5 # diffuse light direction
171 |
172 | ## shading
173 | self.canon_normal = self.renderer.get_normal_from_depth(self.canon_depth)
174 | self.canon_diffuse_shading = (self.canon_normal * self.canon_light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1)
175 | canon_shading = self.canon_light_a.view(-1,1,1,1) + self.canon_light_b.view(-1,1,1,1)*self.canon_diffuse_shading
176 | self.canon_im = (self.canon_albedo/2+0.5) * canon_shading *2-1
177 |
178 | ## predict viewpoint transformation
179 | self.view = self.netV(self.input_im).repeat(2,1)
180 | self.view = torch.cat([
181 | self.view[:,:3] *math.pi/180 *self.xyz_rotation_range,
182 | self.view[:,3:5] *self.xy_translation_range,
183 | self.view[:,5:] *self.z_translation_range], 1)
184 |
185 | ## reconstruct input view
186 | self.renderer.set_transform_matrices(self.view)
187 | self.recon_depth = self.renderer.warp_canon_depth(self.canon_depth)
188 | self.recon_normal = self.renderer.get_normal_from_depth(self.recon_depth)
189 | grid_2d_from_canon = self.renderer.get_inv_warped_2d_grid(self.recon_depth)
190 | self.recon_im = nn.functional.grid_sample(self.canon_im, grid_2d_from_canon, mode='bilinear')
191 |
192 | margin = (self.max_depth - self.min_depth) /2
193 | recon_im_mask = (self.recon_depth < self.max_depth+margin).float() # invalid border pixels have been clamped at max_depth+margin
194 | recon_im_mask_both = recon_im_mask[:b] * recon_im_mask[b:] # both original and flip reconstruction
195 | recon_im_mask_both = recon_im_mask_both.repeat(2,1,1).unsqueeze(1).detach()
196 | self.recon_im = self.recon_im * recon_im_mask_both
197 |
198 | ## render symmetry axis
199 | canon_sym_axis = torch.zeros(h, w).to(self.input_im.device)
200 | canon_sym_axis[:, w//2-1:w//2+1] = 1
201 | self.recon_sym_axis = nn.functional.grid_sample(canon_sym_axis.repeat(b*2,1,1,1), grid_2d_from_canon, mode='bilinear')
202 | self.recon_sym_axis = self.recon_sym_axis * recon_im_mask_both
203 | green = torch.FloatTensor([-1,1,-1]).to(self.input_im.device).view(1,3,1,1)
204 | self.input_im_symline = (0.5*self.recon_sym_axis) *green + (1-0.5*self.recon_sym_axis) *self.input_im.repeat(2,1,1,1)
205 |
206 | ## loss function
207 | self.loss_l1_im = self.photometric_loss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_l1)
208 | self.loss_l1_im_flip = self.photometric_loss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_l1_flip)
209 | self.loss_perc_im = self.PerceptualLoss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_percl)
210 | self.loss_perc_im_flip = self.PerceptualLoss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_percl_flip)
211 | lam_flip = 1 if self.trainer.current_epoch < self.lam_flip_start_epoch else self.lam_flip
212 | self.loss_total = self.loss_l1_im + lam_flip*self.loss_l1_im_flip + self.lam_perc*(self.loss_perc_im + lam_flip*self.loss_perc_im_flip) + self.lam_depth_sm*self.loss_depth_sm
213 |
214 | metrics = {'loss': self.loss_total}
215 |
216 | ## compute accuracy if gt depth is available
217 | if self.load_gt_depth:
218 | self.depth_gt = depth_gt[:,0,:,:].to(self.input_im.device)
219 | self.depth_gt = (1-self.depth_gt)*2-1
220 | self.depth_gt = self.depth_rescaler(self.depth_gt)
221 | self.normal_gt = self.renderer.get_normal_from_depth(self.depth_gt)
222 |
223 | # mask out background
224 | mask_gt = (self.depth_gt<self.depth_gt.max()).float()
225 | mask_gt = (nn.functional.avg_pool2d(mask_gt.unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float() # erode by 1 pixel
226 | mask_pred = (nn.functional.avg_pool2d(recon_im_mask[:b].unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float() # erode by 1 pixel
227 | mask = mask_gt * mask_pred
228 | self.acc_mae_masked = ((self.recon_depth[:b] - self.depth_gt[:b]).abs() *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1)
229 | self.acc_mse_masked = (((self.recon_depth[:b] - self.depth_gt[:b])**2) *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1)
230 | self.sie_map_masked = utils.compute_sc_inv_err(self.recon_depth[:b].log(), self.depth_gt[:b].log(), mask=mask)
231 | self.acc_sie_masked = (self.sie_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1))**0.5
232 | self.norm_err_map_masked = utils.compute_angular_distance(self.recon_normal[:b], self.normal_gt[:b], mask=mask)
233 | self.acc_normal_masked = self.norm_err_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1)
234 |
235 | metrics['SIE_masked'] = self.acc_sie_masked.mean()
236 | metrics['NorErr_masked'] = self.acc_normal_masked.mean()
237 |
238 | return metrics
239 |
240 | def visualize(self, logger, total_iter, max_bs=25):
241 | b, c, h, w = self.input_im.shape
242 | b0 = min(max_bs, b)
243 |
244 | ## render rotations
245 | with torch.no_grad():
246 | v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b0,1)
247 | canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b0], self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5 # (B,T,C,H,W)
248 | canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b0].permute(0,3,1,2), self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5 # (B,T,C,H,W)
249 |
250 | input_im = self.input_im[:b0].detach().cpu().numpy() /2+0.5
251 | input_im_symline = self.input_im_symline[:b0].detach().cpu() /2.+0.5
252 | canon_albedo = self.canon_albedo[:b0].detach().cpu() /2.+0.5
253 | canon_im = self.canon_im[:b0].detach().cpu() /2.+0.5
254 | recon_im = self.recon_im[:b0].detach().cpu() /2.+0.5
255 | recon_im_flip = self.recon_im[b:b+b0].detach().cpu() /2.+0.5
256 | canon_depth_raw_hist = self.canon_depth_raw.detach().unsqueeze(1).cpu()
257 | canon_depth_raw = self.canon_depth_raw[:b0].detach().unsqueeze(1).cpu() /2.+0.5
258 | canon_depth = ((self.canon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
259 | recon_depth = ((self.recon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
260 | canon_diffuse_shading = self.canon_diffuse_shading[:b0].detach().cpu()
261 | canon_normal = self.canon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
262 | recon_normal = self.recon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
263 | if self.use_conf_map:
264 | conf_map_l1 = 1/(1+self.conf_sigma_l1[:b0].detach().cpu()+EPS)
265 | conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b0].detach().cpu()+EPS)
266 | conf_map_percl = 1/(1+self.conf_sigma_percl[:b0].detach().cpu()+EPS)
267 | conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b0].detach().cpu()+EPS)
268 |
269 | canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_im_rotate, 1)] # [(C,H,W)]*T
270 | canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0) # (1,T,C,H,W)
271 | canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_normal_rotate, 1)] # [(C,H,W)]*T
272 | canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0) # (1,T,C,H,W)
273 |
274 | ## write summary
275 | logger.add_scalar('Loss/loss_total', self.loss_total, total_iter)
276 | logger.add_scalar('Loss/loss_l1_im', self.loss_l1_im, total_iter)
277 | logger.add_scalar('Loss/loss_l1_im_flip', self.loss_l1_im_flip, total_iter)
278 | logger.add_scalar('Loss/loss_perc_im', self.loss_perc_im, total_iter)
279 | logger.add_scalar('Loss/loss_perc_im_flip', self.loss_perc_im_flip, total_iter)
280 | logger.add_scalar('Loss/loss_depth_sm', self.loss_depth_sm, total_iter)
281 |
282 | logger.add_histogram('Depth/canon_depth_raw_hist', canon_depth_raw_hist, total_iter)
283 | vlist = ['view_rx', 'view_ry', 'view_rz', 'view_tx', 'view_ty', 'view_tz']
284 | for i in range(self.view.shape[1]):
285 | logger.add_histogram('View/'+vlist[i], self.view[:,i], total_iter)
286 | logger.add_histogram('Light/canon_light_a', self.canon_light_a, total_iter)
287 | logger.add_histogram('Light/canon_light_b', self.canon_light_b, total_iter)
288 | llist = ['canon_light_dx', 'canon_light_dy', 'canon_light_dz']
289 | for i in range(self.canon_light_d.shape[1]):
290 | logger.add_histogram('Light/'+llist[i], self.canon_light_d[:,i], total_iter)
291 |
292 | def log_grid_image(label, im, nrow=int(math.ceil(b0**0.5)), iter=total_iter):
293 | im_grid = torchvision.utils.make_grid(im, nrow=nrow)
294 | logger.add_image(label, im_grid, iter)
295 |
296 | log_grid_image('Image/input_image_symline', input_im_symline)
297 | log_grid_image('Image/canonical_albedo', canon_albedo)
298 | log_grid_image('Image/canonical_image', canon_im)
299 | log_grid_image('Image/recon_image', recon_im)
300 | log_grid_image('Image/recon_image_flip', recon_im_flip)
301 | log_grid_image('Image/recon_side', canon_im_rotate[:,0,:,:,:])
302 |
303 | log_grid_image('Depth/canonical_depth_raw', canon_depth_raw)
304 | log_grid_image('Depth/canonical_depth', canon_depth)
305 | log_grid_image('Depth/recon_depth', recon_depth)
306 | log_grid_image('Depth/canonical_diffuse_shading', canon_diffuse_shading)
307 | log_grid_image('Depth/canonical_normal', canon_normal)
308 | log_grid_image('Depth/recon_normal', recon_normal)
309 |
310 | logger.add_histogram('Image/canonical_albedo_hist', canon_albedo, total_iter)
311 | logger.add_histogram('Image/canonical_diffuse_shading_hist', canon_diffuse_shading, total_iter)
312 |
313 | if self.use_conf_map:
314 | log_grid_image('Conf/conf_map_l1', conf_map_l1)
315 | logger.add_histogram('Conf/conf_sigma_l1_hist', self.conf_sigma_l1, total_iter)
316 | log_grid_image('Conf/conf_map_l1_flip', conf_map_l1_flip)
317 | logger.add_histogram('Conf/conf_sigma_l1_flip_hist', self.conf_sigma_l1_flip, total_iter)
318 | log_grid_image('Conf/conf_map_percl', conf_map_percl)
319 | logger.add_histogram('Conf/conf_sigma_percl_hist', self.conf_sigma_percl, total_iter)
320 | log_grid_image('Conf/conf_map_percl_flip', conf_map_percl_flip)
321 | logger.add_histogram('Conf/conf_sigma_percl_flip_hist', self.conf_sigma_percl_flip, total_iter)
322 |
323 | logger.add_video('Image_rotate/recon_rotate', canon_im_rotate_grid, total_iter, fps=4)
324 | logger.add_video('Image_rotate/canon_normal_rotate', canon_normal_rotate_grid, total_iter, fps=4)
325 |
326 | # visualize images and accuracy if gt is loaded
327 | if self.load_gt_depth:
328 | depth_gt = ((self.depth_gt[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
329 | normal_gt = self.normal_gt.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
330 | sie_map_masked = self.sie_map_masked[:b0].detach().unsqueeze(1).cpu() *1000
331 | norm_err_map_masked = self.norm_err_map_masked[:b0].detach().unsqueeze(1).cpu() /100
332 |
333 | logger.add_scalar('Acc_masked/MAE_masked', self.acc_mae_masked.mean(), total_iter)
334 | logger.add_scalar('Acc_masked/MSE_masked', self.acc_mse_masked.mean(), total_iter)
335 | logger.add_scalar('Acc_masked/SIE_masked', self.acc_sie_masked.mean(), total_iter)
336 | logger.add_scalar('Acc_masked/NorErr_masked', self.acc_normal_masked.mean(), total_iter)
337 |
338 | log_grid_image('Depth_gt/depth_gt', depth_gt)
339 | log_grid_image('Depth_gt/normal_gt', normal_gt)
340 | log_grid_image('Depth_gt/sie_map_masked', sie_map_masked)
341 | log_grid_image('Depth_gt/norm_err_map_masked', norm_err_map_masked)
342 |
343 | def save_results(self, save_dir):
344 | b, c, h, w = self.input_im.shape
345 |
346 | with torch.no_grad():
347 | v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b,1)
348 | canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b], self.canon_depth[:b], v_before=v0, maxr=90, nsample=15) # (B,T,C,H,W)
349 | canon_im_rotate = canon_im_rotate.clamp(-1,1).detach().cpu() /2+0.5
350 | canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b].permute(0,3,1,2), self.canon_depth[:b], v_before=v0, maxr=90, nsample=15) # (B,T,C,H,W)
351 | canon_normal_rotate = canon_normal_rotate.clamp(-1,1).detach().cpu() /2+0.5
352 |
353 | input_im = self.input_im[:b].detach().cpu().numpy() /2+0.5
354 | input_im_symline = self.input_im_symline.detach().cpu().numpy() /2.+0.5
355 | canon_albedo = self.canon_albedo[:b].detach().cpu().numpy() /2+0.5
356 | canon_im = self.canon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5
357 | recon_im = self.recon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5
358 | recon_im_flip = self.recon_im[b:].clamp(-1,1).detach().cpu().numpy() /2+0.5
359 | canon_depth = ((self.canon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
360 | recon_depth = ((self.recon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
361 | canon_diffuse_shading = self.canon_diffuse_shading[:b].detach().cpu().numpy()
362 | canon_normal = self.canon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
363 | recon_normal = self.recon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
364 | if self.use_conf_map:
365 | conf_map_l1 = 1/(1+self.conf_sigma_l1[:b].detach().cpu().numpy()+EPS)
366 | conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b].detach().cpu().numpy()+EPS)
367 | conf_map_percl = 1/(1+self.conf_sigma_percl[:b].detach().cpu().numpy()+EPS)
368 | conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b].detach().cpu().numpy()+EPS)
369 | canon_light = torch.cat([self.canon_light_a, self.canon_light_b, self.canon_light_d], 1)[:b].detach().cpu().numpy()
370 | view = self.view[:b].detach().cpu().numpy()
371 |
372 | canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_im_rotate,1)] # [(C,H,W)]*T
373 | canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0).numpy() # (1,T,C,H,W)
374 | canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_normal_rotate,1)] # [(C,H,W)]*T
375 | canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0).numpy() # (1,T,C,H,W)
376 |
377 | sep_folder = True
378 | utils.save_images(save_dir, input_im, suffix='input_image', sep_folder=sep_folder)
379 | utils.save_images(save_dir, input_im_symline, suffix='input_image_symline', sep_folder=sep_folder)
380 | utils.save_images(save_dir, canon_albedo, suffix='canonical_albedo', sep_folder=sep_folder)
381 | utils.save_images(save_dir, canon_im, suffix='canonical_image', sep_folder=sep_folder)
382 | utils.save_images(save_dir, recon_im, suffix='recon_image', sep_folder=sep_folder)
383 | utils.save_images(save_dir, recon_im_flip, suffix='recon_image_flip', sep_folder=sep_folder)
384 | utils.save_images(save_dir, canon_depth, suffix='canonical_depth', sep_folder=sep_folder)
385 | utils.save_images(save_dir, recon_depth, suffix='recon_depth', sep_folder=sep_folder)
386 | utils.save_images(save_dir, canon_diffuse_shading, suffix='canonical_diffuse_shading', sep_folder=sep_folder)
387 | utils.save_images(save_dir, canon_normal, suffix='canonical_normal', sep_folder=sep_folder)
388 | utils.save_images(save_dir, recon_normal, suffix='recon_normal', sep_folder=sep_folder)
389 | if self.use_conf_map:
390 | utils.save_images(save_dir, conf_map_l1, suffix='conf_map_l1', sep_folder=sep_folder)
391 | utils.save_images(save_dir, conf_map_l1_flip, suffix='conf_map_l1_flip', sep_folder=sep_folder)
392 | utils.save_images(save_dir, conf_map_percl, suffix='conf_map_percl', sep_folder=sep_folder)
393 | utils.save_images(save_dir, conf_map_percl_flip, suffix='conf_map_percl_flip', sep_folder=sep_folder)
394 | utils.save_txt(save_dir, canon_light, suffix='canonical_light', sep_folder=sep_folder)
395 | utils.save_txt(save_dir, view, suffix='viewpoint', sep_folder=sep_folder)
396 |
397 | utils.save_videos(save_dir, canon_im_rotate_grid, suffix='image_video', sep_folder=sep_folder, cycle=True)
398 | utils.save_videos(save_dir, canon_normal_rotate_grid, suffix='normal_video', sep_folder=sep_folder, cycle=True)
399 |
400 | # save scores if gt is loaded
401 | if self.load_gt_depth:
402 | depth_gt = ((self.depth_gt[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
403 | normal_gt = self.normal_gt[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
404 | utils.save_images(save_dir, depth_gt, suffix='depth_gt', sep_folder=sep_folder)
405 | utils.save_images(save_dir, normal_gt, suffix='normal_gt', sep_folder=sep_folder)
406 |
407 | all_scores = torch.stack([
408 | self.acc_mae_masked.detach().cpu(),
409 | self.acc_mse_masked.detach().cpu(),
410 | self.acc_sie_masked.detach().cpu(),
411 | self.acc_normal_masked.detach().cpu()], 1)
412 | if not hasattr(self, 'all_scores'):
413 | self.all_scores = torch.FloatTensor()
414 | self.all_scores = torch.cat([self.all_scores, all_scores], 0)
415 |
416 | def save_scores(self, path):
417 | # save scores if gt is loaded
418 | if self.load_gt_depth:
419 | header = 'MAE_masked, ' \
420 |          'MSE_masked, ' \
421 |          'SIE_masked, ' \
422 |          'NorErr_masked'
423 | mean = self.all_scores.mean(0)
424 | std = self.all_scores.std(0)
425 | header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean])
426 | header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std])
427 | utils.save_scores(path, self.all_scores, header=header)
428 |
--------------------------------------------------------------------------------
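
The conf_map_* images written by save_results are simply 1/(1 + sigma) applied to the predicted uncertainty maps, so low sigma shows up as bright, confident pixels. A tiny numeric sketch of that mapping (illustrative sigma values, not model outputs):

# Illustrative: how predicted uncertainty sigma maps to the saved confidence images.
import numpy as np
sigma = np.array([0.0, 0.5, 1.0, 4.0])
conf = 1.0 / (1.0 + sigma + 1e-7)
print(np.round(conf, 3))   # [1.0, 0.667, 0.5, 0.2] -- smaller sigma, higher confidence
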
/unsup3d/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | EPS = 1e-7
7 |
8 |
9 | class Encoder(nn.Module):
10 | def __init__(self, cin, cout, nf=64, activation=nn.Tanh):
11 | super(Encoder, self).__init__()
12 | network = [
13 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(nf*8, cout, kernel_size=1, stride=1, padding=0, bias=False)]
24 | if activation is not None:
25 | network += [activation()]
26 | self.network = nn.Sequential(*network)
27 |
28 | def forward(self, input):
29 | return self.network(input).reshape(input.size(0),-1)
30 |
31 |
32 | class EDDeconv(nn.Module):
33 | def __init__(self, cin, cout, zdim=128, nf=64, activation=nn.Tanh):
34 | super(EDDeconv, self).__init__()
35 | ## downsampling
36 | network = [
37 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
38 | nn.GroupNorm(16, nf),
39 | nn.LeakyReLU(0.2, inplace=True),
40 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
41 | nn.GroupNorm(16*2, nf*2),
42 | nn.LeakyReLU(0.2, inplace=True),
43 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
44 | nn.GroupNorm(16*4, nf*4),
45 | nn.LeakyReLU(0.2, inplace=True),
46 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
47 | nn.LeakyReLU(0.2, inplace=True),
48 | nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
49 | nn.ReLU(inplace=True)]
50 | ## upsampling
51 | network += [
52 | nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 1x1 -> 4x4
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=False),
55 | nn.ReLU(inplace=True),
56 | nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8
57 | nn.GroupNorm(16*4, nf*4),
58 | nn.ReLU(inplace=True),
59 | nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=False),
60 | nn.GroupNorm(16*4, nf*4),
61 | nn.ReLU(inplace=True),
62 | nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16
63 | nn.GroupNorm(16*2, nf*2),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=False),
66 | nn.GroupNorm(16*2, nf*2),
67 | nn.ReLU(inplace=True),
68 | nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32
69 | nn.GroupNorm(16, nf),
70 | nn.ReLU(inplace=True),
71 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False),
72 | nn.GroupNorm(16, nf),
73 | nn.ReLU(inplace=True),
74 | nn.Upsample(scale_factor=2, mode='nearest'), # 32x32 -> 64x64
75 | nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=False),
76 | nn.GroupNorm(16, nf),
77 | nn.ReLU(inplace=True),
78 | nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=False),
79 | nn.GroupNorm(16, nf),
80 | nn.ReLU(inplace=True),
81 | nn.Conv2d(nf, cout, kernel_size=5, stride=1, padding=2, bias=False)]
82 | if activation is not None:
83 | network += [activation()]
84 | self.network = nn.Sequential(*network)
85 |
86 | def forward(self, input):
87 | return self.network(input)
88 |
89 |
90 | class ConfNet(nn.Module):
91 | def __init__(self, cin, cout, zdim=128, nf=64):
92 | super(ConfNet, self).__init__()
93 | ## downsampling
94 | network = [
95 | nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
96 | nn.GroupNorm(16, nf),
97 | nn.LeakyReLU(0.2, inplace=True),
98 | nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
99 | nn.GroupNorm(16*2, nf*2),
100 | nn.LeakyReLU(0.2, inplace=True),
101 | nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
102 | nn.GroupNorm(16*4, nf*4),
103 | nn.LeakyReLU(0.2, inplace=True),
104 | nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
105 | nn.LeakyReLU(0.2, inplace=True),
106 | nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
107 | nn.ReLU(inplace=True)]
108 | ## upsampling
109 | network += [
110 | nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, stride=1, padding=0, bias=False), # 1x1 -> 4x4
111 | nn.ReLU(inplace=True),
112 | nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8
113 | nn.GroupNorm(16*4, nf*4),
114 | nn.ReLU(inplace=True),
115 | nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16
116 | nn.GroupNorm(16*2, nf*2),
117 | nn.ReLU(inplace=True)]
118 | self.network = nn.Sequential(*network)
119 |
120 | out_net1 = [
121 | nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32
122 | nn.GroupNorm(16, nf),
123 | nn.ReLU(inplace=True),
124 | nn.ConvTranspose2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 64x64
125 | nn.GroupNorm(16, nf),
126 | nn.ReLU(inplace=True),
127 | nn.Conv2d(nf, 2, kernel_size=5, stride=1, padding=2, bias=False), # 64x64
128 | nn.Softplus()]
129 | self.out_net1 = nn.Sequential(*out_net1)
130 |
131 | out_net2 = [nn.Conv2d(nf*2, 2, kernel_size=3, stride=1, padding=1, bias=False), # 16x16
132 | nn.Softplus()]
133 | self.out_net2 = nn.Sequential(*out_net2)
134 |
135 | def forward(self, input):
136 | out = self.network(input)
137 | return self.out_net1(out), self.out_net2(out)
138 |
139 |
140 | class PerceptualLoss(nn.Module):
141 | def __init__(self, requires_grad=False):
142 | super(PerceptualLoss, self).__init__()
143 | mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
144 | std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
145 | self.register_buffer('mean_rgb', mean_rgb)
146 | self.register_buffer('std_rgb', std_rgb)
147 |
148 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features
149 | self.slice1 = nn.Sequential()
150 | self.slice2 = nn.Sequential()
151 | self.slice3 = nn.Sequential()
152 | self.slice4 = nn.Sequential()
153 | for x in range(4):
154 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
155 | for x in range(4, 9):
156 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
157 | for x in range(9, 16):
158 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
159 | for x in range(16, 23):
160 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
161 | if not requires_grad:
162 | for param in self.parameters():
163 | param.requires_grad = False
164 |
165 | def normalize(self, x):
166 | out = x/2 + 0.5
167 | out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1)
168 | return out
169 |
170 | def __call__(self, im1, im2, mask=None, conf_sigma=None):
171 | im = torch.cat([im1,im2], 0)
172 | im = self.normalize(im) # normalize input
173 |
174 | ## compute features
175 | feats = []
176 | f = self.slice1(im)
177 | feats += [torch.chunk(f, 2, dim=0)]
178 | f = self.slice2(f)
179 | feats += [torch.chunk(f, 2, dim=0)]
180 | f = self.slice3(f)
181 | feats += [torch.chunk(f, 2, dim=0)]
182 | f = self.slice4(f)
183 | feats += [torch.chunk(f, 2, dim=0)]
184 |
185 | losses = []
186 | for f1, f2 in feats[2:3]: # use relu3_3 features only
187 | loss = (f1-f2)**2
188 | if conf_sigma is not None:
189 | loss = loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log()
190 | if mask is not None:
191 | b, c, h, w = loss.shape
192 | _, _, hm, wm = mask.shape
193 | sh, sw = hm//h, wm//w
194 | mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss)
195 | loss = (loss * mask0).sum() / mask0.sum()
196 | else:
197 | loss = loss.mean()
198 | losses += [loss]
199 | return sum(losses)
200 |
--------------------------------------------------------------------------------
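
A minimal shape-check sketch for the modules in networks.py, assuming the unsup3d package (and whatever its __init__ pulls in) is importable from the repository root; the cin/cout values are illustrative choices, not values taken from the experiment configs:

# Illustrative shape check for Encoder / EDDeconv / ConfNet (constructor defaults, CPU).
import torch
from unsup3d.networks import Encoder, EDDeconv, ConfNet

x = torch.randn(2, 3, 64, 64)                   # batch of RGB images in [-1, 1]
pose = Encoder(cin=3, cout=6)(x)                # (2, 6) flat code (Tanh output)
depth = EDDeconv(cin=3, cout=1)(x)              # (2, 1, 64, 64) map at input resolution
conf_im, conf_feat = ConfNet(cin=3, cout=2)(x)  # (2, 2, 64, 64) and (2, 2, 16, 16)
print(pose.shape, depth.shape, conf_im.shape, conf_feat.shape)

Note that ConfNet's cout argument is accepted but unused: both output heads are hard-coded to two channels.
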
/unsup3d/renderer/__init__.py:
--------------------------------------------------------------------------------
1 | from .renderer import Renderer
2 |
--------------------------------------------------------------------------------
/unsup3d/renderer/renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import neural_renderer as nr
4 | from .utils import *
5 |
6 |
7 | EPS = 1e-7
8 |
9 |
10 | class Renderer():
11 | def __init__(self, cfgs):
12 | self.device = cfgs.get('device', 'cpu')
13 | self.image_size = cfgs.get('image_size', 64)
14 | self.min_depth = cfgs.get('min_depth', 0.9)
15 | self.max_depth = cfgs.get('max_depth', 1.1)
16 | self.rot_center_depth = cfgs.get('rot_center_depth', (self.min_depth+self.max_depth)/2)
17 | self.fov = cfgs.get('fov', 10)
18 | self.tex_cube_size = cfgs.get('tex_cube_size', 2)
19 | self.renderer_min_depth = cfgs.get('renderer_min_depth', 0.1)
20 | self.renderer_max_depth = cfgs.get('renderer_max_depth', 10.)
21 |
22 | #### camera intrinsics
23 | # (u) (x)
24 | # d * K^-1 (v) = (y)
25 | # (1) (z)
26 |
27 | ## renderer for visualization
28 | R = [[[1.,0.,0.],
29 | [0.,1.,0.],
30 | [0.,0.,1.]]]
31 | R = torch.FloatTensor(R).to(self.device)
32 | t = torch.zeros(1,3, dtype=torch.float32).to(self.device)
33 | fx = (self.image_size-1)/2/(math.tan(self.fov/2 *math.pi/180))
34 | fy = (self.image_size-1)/2/(math.tan(self.fov/2 *math.pi/180))
35 | cx = (self.image_size-1)/2
36 | cy = (self.image_size-1)/2
37 | K = [[fx, 0., cx],
38 | [0., fy, cy],
39 | [0., 0., 1.]]
40 | K = torch.FloatTensor(K).to(self.device)
41 | self.inv_K = torch.inverse(K).unsqueeze(0)
42 | self.K = K.unsqueeze(0)
43 | self.renderer = nr.Renderer(camera_mode='projection',
44 | light_intensity_ambient=1.0,
45 | light_intensity_directional=0.,
46 | K=self.K, R=R, t=t,
47 | near=self.renderer_min_depth, far=self.renderer_max_depth,
48 | image_size=self.image_size, orig_size=self.image_size,
49 | fill_back=True,
50 | background_color=[1,1,1])
51 |
52 | def set_transform_matrices(self, view):
53 | self.rot_mat, self.trans_xyz = get_transform_matrices(view)
54 |
55 | def rotate_pts(self, pts, rot_mat):
56 | centroid = torch.FloatTensor([0.,0.,self.rot_center_depth]).to(pts.device).view(1,1,3)
57 | pts = pts - centroid # move to centroid
58 | pts = pts.matmul(rot_mat.transpose(2,1)) # rotate
59 | pts = pts + centroid # move back
60 | return pts
61 |
62 | def translate_pts(self, pts, trans_xyz):
63 | return pts + trans_xyz
64 |
65 | def depth_to_3d_grid(self, depth):
66 | b, h, w = depth.shape
67 | grid_2d = get_grid(b, h, w, normalize=False).to(depth.device) # Nxhxwx2
68 | depth = depth.unsqueeze(-1)
69 | grid_3d = torch.cat((grid_2d, torch.ones_like(depth)), dim=3)
70 | grid_3d = grid_3d.matmul(self.inv_K.to(depth.device).transpose(2,1)) * depth
71 | return grid_3d
72 |
73 | def grid_3d_to_2d(self, grid_3d):
74 | b, h, w, _ = grid_3d.shape
75 | grid_2d = grid_3d / grid_3d[...,2:]
76 | grid_2d = grid_2d.matmul(self.K.to(grid_3d.device).transpose(2,1))[:,:,:,:2]
77 | WH = torch.FloatTensor([w-1, h-1]).to(grid_3d.device).view(1,1,1,2)
78 | grid_2d = grid_2d / WH *2.-1. # normalize to -1~1
79 | return grid_2d
80 |
81 | def get_warped_3d_grid(self, depth):
82 | b, h, w = depth.shape
83 | grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3)
84 | grid_3d = self.rotate_pts(grid_3d, self.rot_mat)
85 | grid_3d = self.translate_pts(grid_3d, self.trans_xyz)
86 | return grid_3d.reshape(b,h,w,3) # return 3d vertices
87 |
88 | def get_inv_warped_3d_grid(self, depth):
89 | b, h, w = depth.shape
90 | grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3)
91 | grid_3d = self.translate_pts(grid_3d, -self.trans_xyz)
92 | grid_3d = self.rotate_pts(grid_3d, self.rot_mat.transpose(2,1))
93 | return grid_3d.reshape(b,h,w,3) # return 3d vertices
94 |
95 | def get_warped_2d_grid(self, depth):
96 | b, h, w = depth.shape
97 | grid_3d = self.get_warped_3d_grid(depth)
98 | grid_2d = self.grid_3d_to_2d(grid_3d)
99 | return grid_2d
100 |
101 | def get_inv_warped_2d_grid(self, depth):
102 | b, h, w = depth.shape
103 | grid_3d = self.get_inv_warped_3d_grid(depth)
104 | grid_2d = self.grid_3d_to_2d(grid_3d)
105 | return grid_2d
106 |
107 | def warp_canon_depth(self, canon_depth):
108 | b, h, w = canon_depth.shape
109 | grid_3d = self.get_warped_3d_grid(canon_depth).reshape(b,-1,3)
110 | faces = get_face_idx(b, h, w).to(canon_depth.device)
111 | warped_depth = self.renderer.render_depth(grid_3d, faces)
112 |
113 | # allow some margin out of valid range
114 | margin = (self.max_depth - self.min_depth) /2
115 | warped_depth = warped_depth.clamp(min=self.min_depth-margin, max=self.max_depth+margin)
116 | return warped_depth
117 |
118 | def get_normal_from_depth(self, depth):
119 | b, h, w = depth.shape
120 | grid_3d = self.depth_to_3d_grid(depth)
121 |
122 | tu = grid_3d[:,1:-1,2:] - grid_3d[:,1:-1,:-2]
123 | tv = grid_3d[:,2:,1:-1] - grid_3d[:,:-2,1:-1]
124 | normal = tu.cross(tv, dim=3)
125 |
126 | zero = torch.FloatTensor([0,0,1]).to(depth.device)  # +z unit normal used to pad the one-pixel border
127 | normal = torch.cat([zero.repeat(b,h-2,1,1), normal, zero.repeat(b,h-2,1,1)], 2)
128 | normal = torch.cat([zero.repeat(b,1,w,1), normal, zero.repeat(b,1,w,1)], 1)
129 | normal = normal / (((normal**2).sum(3, keepdim=True))**0.5 + EPS)
130 | return normal
131 |
132 | def render_yaw(self, im, depth, v_before=None, v_after=None, rotations=None, maxr=90, nsample=9, crop_mesh=None):
133 | b, c, h, w = im.shape
134 | grid_3d = self.depth_to_3d_grid(depth)
135 |
136 | if crop_mesh is not None:
137 | top, bottom, left, right = crop_mesh # pixels from border to be cropped
138 | if top > 0:
139 | grid_3d[:,:top,:,1] = grid_3d[:,top:top+1,:,1].repeat(1,top,1)
140 | grid_3d[:,:top,:,2] = grid_3d[:,top:top+1,:,2].repeat(1,top,1)
141 | if bottom > 0:
142 | grid_3d[:,-bottom:,:,1] = grid_3d[:,-bottom-1:-bottom,:,1].repeat(1,bottom,1)
143 | grid_3d[:,-bottom:,:,2] = grid_3d[:,-bottom-1:-bottom,:,2].repeat(1,bottom,1)
144 | if left > 0:
145 | grid_3d[:,:,:left,0] = grid_3d[:,:,left:left+1,0].repeat(1,1,left)
146 | grid_3d[:,:,:left,2] = grid_3d[:,:,left:left+1,2].repeat(1,1,left)
147 | if right > 0:
148 | grid_3d[:,:,-right:,0] = grid_3d[:,:,-right-1:-right,0].repeat(1,1,right)
149 | grid_3d[:,:,-right:,2] = grid_3d[:,:,-right-1:-right,2].repeat(1,1,right)
150 |
151 | grid_3d = grid_3d.reshape(b,-1,3)
152 | im_trans = []
153 |
154 | # inverse warp
155 | if v_before is not None:
156 | rot_mat, trans_xyz = get_transform_matrices(v_before)
157 | grid_3d = self.translate_pts(grid_3d, -trans_xyz)
158 | grid_3d = self.rotate_pts(grid_3d, rot_mat.transpose(2,1))
159 |
160 | if rotations is None:
161 | rotations = torch.linspace(-math.pi/180*maxr, math.pi/180*maxr, nsample)
162 | for i, ri in enumerate(rotations):
163 | ri = torch.FloatTensor([0, ri, 0]).to(im.device).view(1,3)
164 | rot_mat_i, _ = get_transform_matrices(ri)
165 | grid_3d_i = self.rotate_pts(grid_3d, rot_mat_i.repeat(b,1,1))
166 |
167 | if v_after is not None:
168 | if len(v_after.shape) == 3:
169 | v_after_i = v_after[i]
170 | else:
171 | v_after_i = v_after
172 | rot_mat, trans_xyz = get_transform_matrices(v_after_i)
173 | grid_3d_i = self.rotate_pts(grid_3d_i, rot_mat)
174 | grid_3d_i = self.translate_pts(grid_3d_i, trans_xyz)
175 |
176 | faces = get_face_idx(b, h, w).to(im.device)
177 | textures = get_textures_from_im(im, tx_size=self.tex_cube_size)
178 | warped_images = self.renderer.render_rgb(grid_3d_i, faces, textures).clamp(min=-1., max=1.)
179 | im_trans += [warped_images]
180 | return torch.stack(im_trans, 1) # b x t x c x h x w
181 |
--------------------------------------------------------------------------------
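
A hedged sketch of driving the Renderer, assuming neural_renderer is installed and a CUDA device is available (the rasterizer generally requires one); the config values below mirror the defaults read in __init__ and are illustrative:

import torch
from unsup3d.renderer import Renderer

cfgs = {'device': 'cuda:0', 'image_size': 64, 'min_depth': 0.9, 'max_depth': 1.1, 'fov': 10}
renderer = Renderer(cfgs)

b, h, w = 2, 64, 64
depth = torch.rand(b, h, w, device=cfgs['device']) * 0.2 + 0.9   # values inside [min_depth, max_depth]
view = torch.zeros(b, 6, device=cfgs['device'])                  # (rx, ry, rz, tx, ty, tz)
view[:, 1] = 0.2                                                 # small yaw, in radians

renderer.set_transform_matrices(view)           # cache rotation/translation for this view
warped = renderer.warp_canon_depth(depth)       # (b, h, w): canonical depth re-rendered under the view
normal = renderer.get_normal_from_depth(depth)  # (b, h, w, 3): unit normals from finite differences
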
/unsup3d/renderer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mm_normalize(x, min=0, max=1):
5 | x_min = x.min()
6 | x_max = x.max()
7 | x_range = x_max - x_min
8 | x_z = (x - x_min) / x_range
9 | x_out = x_z * (max - min) + min
10 | return x_out
11 |
12 |
13 | def rand_range(size, min, max):
14 | return torch.rand(size)*(max-min)+min
15 |
16 |
17 | def rand_posneg_range(size, min, max):
18 | i = (torch.rand(size) > 0.5).type(torch.float)*2.-1.
19 | return i*rand_range(size, min, max)
20 |
21 |
22 | def get_grid(b, H, W, normalize=True):
23 | if normalize:
24 | h_range = torch.linspace(-1,1,H)
25 | w_range = torch.linspace(-1,1,W)
26 | else:
27 | h_range = torch.arange(0,H)
28 | w_range = torch.arange(0,W)
29 | grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b,1,1,1).flip(3).float() # flip h,w to x,y
30 | return grid
31 |
32 |
33 | def get_rotation_matrix(tx, ty, tz):
34 | m_x = torch.zeros((len(tx), 3, 3)).to(tx.device)
35 | m_y = torch.zeros((len(tx), 3, 3)).to(tx.device)
36 | m_z = torch.zeros((len(tx), 3, 3)).to(tx.device)
37 |
38 | m_x[:, 1, 1], m_x[:, 1, 2] = tx.cos(), -tx.sin()
39 | m_x[:, 2, 1], m_x[:, 2, 2] = tx.sin(), tx.cos()
40 | m_x[:, 0, 0] = 1
41 |
42 | m_y[:, 0, 0], m_y[:, 0, 2] = ty.cos(), ty.sin()
43 | m_y[:, 2, 0], m_y[:, 2, 2] = -ty.sin(), ty.cos()
44 | m_y[:, 1, 1] = 1
45 |
46 | m_z[:, 0, 0], m_z[:, 0, 1] = tz.cos(), -tz.sin()
47 | m_z[:, 1, 0], m_z[:, 1, 1] = tz.sin(), tz.cos()
48 | m_z[:, 2, 2] = 1
49 | return torch.matmul(m_z, torch.matmul(m_y, m_x))
50 |
51 |
52 | def get_transform_matrices(view):
53 | b = view.size(0)
54 | if view.size(1) == 6:
55 | rx = view[:,0]
56 | ry = view[:,1]
57 | rz = view[:,2]
58 | trans_xyz = view[:,3:].reshape(b,1,3)
59 | elif view.size(1) == 5:
60 | rx = view[:,0]
61 | ry = view[:,1]
62 | rz = view[:,2]
63 | delta_xy = view[:,3:].reshape(b,1,2)
64 | trans_xyz = torch.cat([delta_xy, torch.zeros(b,1,1).to(view.device)], 2)
65 | elif view.size(1) == 3:
66 | rx = view[:,0]
67 | ry = view[:,1]
68 | rz = view[:,2]
69 | trans_xyz = torch.zeros(b,1,3).to(view.device)
70 | rot_mat = get_rotation_matrix(rx, ry, rz)
71 | return rot_mat, trans_xyz
72 |
73 |
74 | def get_face_idx(b, h, w):
75 | idx_map = torch.arange(h*w).reshape(h,w)
76 | faces1 = torch.stack([idx_map[:h-1,:w-1], idx_map[1:,:w-1], idx_map[:h-1,1:]], -1).reshape(-1,3)
77 | faces2 = torch.stack([idx_map[:h-1,1:], idx_map[1:,:w-1], idx_map[1:,1:]], -1).reshape(-1,3)
78 | return torch.cat([faces1,faces2], 0).repeat(b,1,1).int()
79 |
80 |
81 | def vcolor_to_texture_cube(vcolors):
82 | # input bxcxnx3
83 | b, c, n, f = vcolors.shape
84 | coeffs = torch.FloatTensor(
85 | [[ 0.5, 0.5, 0.5],
86 | [ 0. , 0. , 1. ],
87 | [ 0. , 1. , 0. ],
88 | [-0.5, 0.5, 0.5],
89 | [ 1. , 0. , 0. ],
90 | [ 0.5, -0.5, 0.5],
91 | [ 0.5, 0.5, -0.5],
92 | [ 0. , 0. , 0. ]]).to(vcolors.device)
93 | return coeffs.matmul(vcolors.permute(0,2,3,1)).reshape(b,n,2,2,2,c)
94 |
95 |
96 | def get_textures_from_im(im, tx_size=1):
97 | b, c, h, w = im.shape
98 | if tx_size == 1:
99 | textures = torch.cat([im[:,:,:h-1,:w-1].reshape(b,c,-1), im[:,:,1:,1:].reshape(b,c,-1)], 2)
100 | textures = textures.transpose(2,1).reshape(b,-1,1,1,1,c)
101 | elif tx_size == 2:
102 | textures1 = torch.stack([im[:,:,:h-1,:w-1], im[:,:,:h-1,1:], im[:,:,1:,:w-1]], -1).reshape(b,c,-1,3)
103 | textures2 = torch.stack([im[:,:,1:,:w-1], im[:,:,:h-1,1:], im[:,:,1:,1:]], -1).reshape(b,c,-1,3)
104 | textures = vcolor_to_texture_cube(torch.cat([textures1, textures2], 2)) # bxnx2x2x2xc
105 | else:
106 | raise NotImplementedError("Only texture sizes 1 and 2 are currently supported.")
107 | return textures
108 |
--------------------------------------------------------------------------------
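
A small sanity sketch for the mesh helpers above; importing unsup3d.renderer.utils also runs the renderer package __init__, so neural_renderer must be importable:

import torch
from unsup3d.renderer.utils import get_grid, get_face_idx, get_transform_matrices

b, h, w = 1, 4, 5
grid = get_grid(b, h, w, normalize=False)   # (1, 4, 5, 2): per-pixel (x, y) coordinates
faces = get_face_idx(b, h, w)               # (1, 2*(h-1)*(w-1), 3) = (1, 24, 3): two triangles per pixel quad

view = torch.tensor([[0.1, 0.0, 0.0, 0.01, 0.0, 0.0]])   # (rx, ry, rz, tx, ty, tz)
rot_mat, trans_xyz = get_transform_matrices(view)         # (1, 3, 3) rotation, (1, 1, 3) translation
print(grid.shape, faces.shape, rot_mat.shape, trans_xyz.shape)
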
/unsup3d/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from datetime import datetime
4 | import numpy as np
5 | import torch
6 | from . import meters
7 | from . import utils
8 | from .dataloaders import get_data_loaders
9 |
10 |
11 | class Trainer():
12 | def __init__(self, cfgs, model):
13 | self.device = cfgs.get('device', 'cpu')
14 | self.num_epochs = cfgs.get('num_epochs', 30)
15 | self.batch_size = cfgs.get('batch_size', 64)
16 | self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results')
17 | self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1)
18 | self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2) # -1 for keeping all checkpoints
19 | self.resume = cfgs.get('resume', True)
20 | self.use_logger = cfgs.get('use_logger', True)
21 | self.log_freq = cfgs.get('log_freq', 1000)
22 | self.archive_code = cfgs.get('archive_code', True)
23 | self.checkpoint_name = cfgs.get('checkpoint_name', None)
24 | self.test_result_dir = cfgs.get('test_result_dir', None)
25 | self.cfgs = cfgs
26 |
27 | self.metrics_trace = meters.MetricsTrace()
28 | self.make_metrics = lambda m=None: meters.StandardMetrics(m)
29 | self.model = model(cfgs)
30 | self.model.trainer = self
31 | self.train_loader, self.val_loader, self.test_loader = get_data_loaders(cfgs)
32 |
33 | def load_checkpoint(self, optim=True):
34 | """Search the specified/latest checkpoint in checkpoint_dir and load the model and optimizer."""
35 | if self.checkpoint_name is not None:
36 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
37 | else:
38 | checkpoints = sorted(glob.glob(os.path.join(self.checkpoint_dir, '*.pth')))
39 | if len(checkpoints) == 0:
40 | return 0
41 | checkpoint_path = checkpoints[-1]
42 | self.checkpoint_name = os.path.basename(checkpoint_path)
43 | print(f"Loading checkpoint from {checkpoint_path}")
44 | cp = torch.load(checkpoint_path, map_location=self.device)
45 | self.model.load_model_state(cp)
46 | if optim:
47 | self.model.load_optimizer_state(cp)
48 | self.metrics_trace = cp['metrics_trace']
49 | epoch = cp['epoch']
50 | return epoch
51 |
52 | def save_checkpoint(self, epoch, optim=True):
53 | """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch."""
54 | utils.xmkdir(self.checkpoint_dir)
55 | checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth')
56 | state_dict = self.model.get_model_state()
57 | if optim:
58 | optimizer_state = self.model.get_optimizer_state()
59 | state_dict = {**state_dict, **optimizer_state}
60 | state_dict['metrics_trace'] = self.metrics_trace
61 | state_dict['epoch'] = epoch
62 | print(f"Saving checkpoint to {checkpoint_path}")
63 | torch.save(state_dict, checkpoint_path)
64 | if self.keep_num_checkpoint > 0:
65 | utils.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint)
66 |
67 | def save_clean_checkpoint(self, path):
68 | """Save model state only to specified path."""
69 | torch.save(self.model.get_model_state(), path)
70 |
71 | def test(self):
72 | """Perform testing."""
73 | self.model.to_device(self.device)
74 | self.current_epoch = self.load_checkpoint(optim=False)
75 | if self.test_result_dir is None:
76 | self.test_result_dir = os.path.join(self.checkpoint_dir, f'test_results_{self.checkpoint_name}'.replace('.pth',''))
77 | print(f"Saving testing results to {self.test_result_dir}")
78 |
79 | with torch.no_grad():
80 | m = self.run_epoch(self.test_loader, epoch=self.current_epoch, is_test=True)
81 |
82 | score_path = os.path.join(self.test_result_dir, 'eval_scores.txt')
83 | self.model.save_scores(score_path)
84 |
85 | def train(self):
86 | """Perform training."""
87 | ## archive code and configs
88 | if self.archive_code:
89 | utils.archive_code(os.path.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py', '.yml'])
90 | utils.dump_yaml(os.path.join(self.checkpoint_dir, 'configs.yml'), self.cfgs)
91 |
92 | ## initialize
93 | start_epoch = 0
94 | self.metrics_trace.reset()
95 | self.train_iter_per_epoch = len(self.train_loader)
96 | self.model.to_device(self.device)
97 | self.model.init_optimizers()
98 |
99 | ## resume from checkpoint
100 | if self.resume:
101 | start_epoch = self.load_checkpoint(optim=True)
102 |
103 | ## initialize tensorboardX logger
104 | if self.use_logger:
105 | from tensorboardX import SummaryWriter
106 | self.logger = SummaryWriter(os.path.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S")))
107 |
108 | ## cache one batch for visualization
109 | self.viz_input = next(iter(self.val_loader))
110 |
111 | ## run epochs
112 | print(f"{self.model.model_name}: optimizing to {self.num_epochs} epochs")
113 | for epoch in range(start_epoch, self.num_epochs):
114 | self.current_epoch = epoch
115 | metrics = self.run_epoch(self.train_loader, epoch)
116 | self.metrics_trace.append("train", metrics)
117 |
118 | with torch.no_grad():
119 | metrics = self.run_epoch(self.val_loader, epoch, is_validation=True)
120 | self.metrics_trace.append("val", metrics)
121 |
122 | if (epoch+1) % self.save_checkpoint_freq == 0:
123 | self.save_checkpoint(epoch+1, optim=True)
124 | self.metrics_trace.plot(pdf_path=os.path.join(self.checkpoint_dir, 'metrics.pdf'))
125 | self.metrics_trace.save(os.path.join(self.checkpoint_dir, 'metrics.json'))
126 |
127 | print(f"Training completed after {epoch+1} epochs.")
128 |
129 | def run_epoch(self, loader, epoch=0, is_validation=False, is_test=False):
130 | """Run one epoch."""
131 | is_train = not is_validation and not is_test
132 | metrics = self.make_metrics()
133 |
134 | if is_train:
135 | print(f"Starting training epoch {epoch}")
136 | self.model.set_train()
137 | else:
138 | print(f"Starting validation epoch {epoch}")
139 | self.model.set_eval()
140 |
141 | for iter, input in enumerate(loader):
142 | m = self.model.forward(input)
143 | if is_train:
144 | self.model.backward()
145 | elif is_test:
146 | self.model.save_results(self.test_result_dir)
147 |
148 | metrics.update(m, self.batch_size)
149 | print(f"{'T' if is_train else 'V'}{epoch:02}/{iter:05}/{metrics}")
150 |
151 | if self.use_logger and is_train:
152 | total_iter = iter + epoch*self.train_iter_per_epoch
153 | if total_iter % self.log_freq == 0:
154 | self.model.forward(self.viz_input)
155 | self.model.visualize(self.logger, total_iter=total_iter, max_bs=25)
156 | return metrics
157 |
--------------------------------------------------------------------------------
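
A hypothetical entry-point sketch showing how the Trainer is wired together (run.py presumably does something similar). The model class name Unsup3D is an assumption standing in for whatever unsup3d/model.py defines, and the sketch presumes the dataset referenced by the config has been downloaded:

import argparse
from unsup3d.utils import setup_runtime
from unsup3d.trainer import Trainer
from unsup3d.model import Unsup3D   # hypothetical class name, for illustration only

parser = argparse.ArgumentParser()
parser.add_argument('--config', default='experiments/train_synface.yml', type=str)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--num_workers', default=4, type=int)
args = parser.parse_args()

cfgs = setup_runtime(args)         # load YAML config, seed RNGs, pick the device
trainer = Trainer(cfgs, Unsup3D)   # Trainer instantiates the model class with cfgs
trainer.train()                    # or trainer.test() with a test_*.yml config
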
/unsup3d/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import yaml
5 | import random
6 | import numpy as np
7 | import cv2
8 | import torch
9 | import zipfile
10 |
11 |
12 | def setup_runtime(args):
13 | """Load configs, initialize CUDA, CuDNN and the random seeds."""
14 |
15 | # Setup CUDA
16 | cuda_device_id = args.gpu
17 | if cuda_device_id is not None:
18 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
19 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device_id)
20 | if torch.cuda.is_available():
21 | torch.backends.cudnn.enabled = True
22 | torch.backends.cudnn.benchmark = True
23 | torch.backends.cudnn.deterministic = True
24 |
25 | # Setup random seeds for reproducibility
26 | random.seed(args.seed)
27 | np.random.seed(args.seed)
28 | torch.manual_seed(args.seed)
29 | if torch.cuda.is_available():
30 | torch.cuda.manual_seed_all(args.seed)
31 |
32 | ## Load config
33 | cfgs = {}
34 | if args.config is not None and os.path.isfile(args.config):
35 | cfgs = load_yaml(args.config)
36 |
37 | cfgs['config'] = args.config
38 | cfgs['seed'] = args.seed
39 | cfgs['num_workers'] = args.num_workers
40 | cfgs['device'] = 'cuda:0' if torch.cuda.is_available() and cuda_device_id is not None else 'cpu'
41 |
42 | print(f"Environment: GPU {cuda_device_id} seed {args.seed} number of workers {args.num_workers}")
43 | return cfgs
44 |
45 |
46 | def load_yaml(path):
47 | print(f"Loading configs from {path}")
48 | with open(path, 'r') as f:
49 | return yaml.safe_load(f)
50 |
51 |
52 | def dump_yaml(path, cfgs):
53 | print(f"Saving configs to {path}")
54 | xmkdir(os.path.dirname(path))
55 | with open(path, 'w') as f:
56 | return yaml.safe_dump(cfgs, f)
57 |
58 |
59 | def xmkdir(path):
60 | """Create directory PATH recursively if it does not exist."""
61 | os.makedirs(path, exist_ok=True)
62 |
63 |
64 | def clean_checkpoint(checkpoint_dir, keep_num=2):
65 | if keep_num > 0:
66 | names = list(sorted(
67 | glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.pth'))
68 | ))
69 | if len(names) > keep_num:
70 | for name in names[:-keep_num]:
71 | print(f"Deleting obslete checkpoint file {name}")
72 | os.remove(name)
73 |
74 |
75 | def archive_code(arc_path, filetypes=['.py', '.yml']):
76 | print(f"Archiving code to {arc_path}")
77 | xmkdir(os.path.dirname(arc_path))
78 | zipf = zipfile.ZipFile(arc_path, 'w', zipfile.ZIP_DEFLATED)
79 | cur_dir = os.getcwd()
80 | flist = []
81 | for ftype in filetypes:
82 | flist.extend(glob.glob(os.path.join(cur_dir, '**', '*'+ftype), recursive=True))
83 | [zipf.write(f, arcname=f.replace(cur_dir,'archived_code', 1)) for f in flist]
84 | zipf.close()
85 |
86 |
87 | def get_model_device(model):
88 | return next(model.parameters()).device
89 |
90 |
91 | def set_requires_grad(nets, requires_grad=False):
92 | if not isinstance(nets, list):
93 | nets = [nets]
94 | for net in nets:
95 | if net is not None:
96 | for param in net.parameters():
97 | param.requires_grad = requires_grad
98 |
99 |
100 | def draw_bbox(im, size):
101 | b, c, h, w = im.shape
102 | h2, w2 = (h-size)//2, (w-size)//2
103 | marker = np.tile(np.array([[1.],[0.],[0.]]), (1,size))
104 | marker = torch.FloatTensor(marker)
105 | im[:, :, h2, w2:w2+size] = marker
106 | im[:, :, h2+size, w2:w2+size] = marker
107 | im[:, :, h2:h2+size, w2] = marker
108 | im[:, :, h2:h2+size, w2+size] = marker
109 | return im
110 |
111 |
112 | def save_videos(out_fold, imgs, prefix='', suffix='', sep_folder=True, ext='.mp4', cycle=False):
113 | if sep_folder:
114 | out_fold = os.path.join(out_fold, suffix)
115 | xmkdir(out_fold)
116 | prefix = prefix + '_' if prefix else ''
117 | suffix = '_' + suffix if suffix else ''
118 | offset = len(glob.glob(os.path.join(out_fold, prefix+'*'+suffix+ext))) +1
119 |
120 | imgs = imgs.transpose(0,1,3,4,2) # BxTxCxHxW -> BxTxHxWxC
121 | for i, fs in enumerate(imgs):
122 | if cycle:
123 | fs = np.concatenate([fs, fs[::-1]], 0)
124 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
125 | # fourcc = cv2.VideoWriter_fourcc(*'avc1')
126 | vid = cv2.VideoWriter(os.path.join(out_fold, prefix+'%05d'%(i+offset)+suffix+ext), fourcc, 5, (fs.shape[2], fs.shape[1]))
127 | [vid.write(np.uint8(f[...,::-1]*255.)) for f in fs]
128 | vid.release()
129 |
130 |
131 | def save_images(out_fold, imgs, prefix='', suffix='', sep_folder=True, ext='.png'):
132 | if sep_folder:
133 | out_fold = os.path.join(out_fold, suffix)
134 | xmkdir(out_fold)
135 | prefix = prefix + '_' if prefix else ''
136 | suffix = '_' + suffix if suffix else ''
137 | offset = len(glob.glob(os.path.join(out_fold, prefix+'*'+suffix+ext))) +1
138 |
139 | imgs = imgs.transpose(0,2,3,1)
140 | for i, img in enumerate(imgs):
141 | if 'depth' in suffix:
142 | im_out = np.uint16(img[...,::-1]*65535.)
143 | else:
144 | im_out = np.uint8(img[...,::-1]*255.)
145 | cv2.imwrite(os.path.join(out_fold, prefix+'%05d'%(i+offset)+suffix+ext), im_out)
146 |
147 |
148 | def save_txt(out_fold, data, prefix='', suffix='', sep_folder=True, ext='.txt'):
149 | if sep_folder:
150 | out_fold = os.path.join(out_fold, suffix)
151 | xmkdir(out_fold)
152 | prefix = prefix + '_' if prefix else ''
153 | suffix = '_' + suffix if suffix else ''
154 | offset = len(glob.glob(os.path.join(out_fold, prefix+'*'+suffix+ext))) +1
155 |
156 | [np.savetxt(os.path.join(out_fold, prefix+'%05d'%(i+offset)+suffix+ext), d, fmt='%.6f', delimiter=', ') for i,d in enumerate(data)]
157 |
158 |
159 | def compute_sc_inv_err(d_pred, d_gt, mask=None):
160 | b = d_pred.size(0)
161 | diff = d_pred - d_gt
162 | if mask is not None:
163 | diff = diff * mask
164 | avg = diff.view(b, -1).sum(1) / (mask.view(b, -1).sum(1))
165 | score = (diff - avg.view(b,1,1))**2 * mask
166 | else:
167 | avg = diff.view(b, -1).mean(1)
168 | score = (diff - avg.view(b,1,1))**2
169 | return score # per-pixel squared error map (masked if a mask is given)
170 |
171 |
172 | def compute_angular_distance(n1, n2, mask=None):
173 | dist = (n1*n2).sum(3).clamp(-1,1).acos() /np.pi*180
174 | return dist*mask if mask is not None else dist
175 |
176 |
177 | def save_scores(out_path, scores, header=''):
178 | print(f"Saving scores to {out_path}")
179 | np.savetxt(out_path, scores, fmt='%.8f', delimiter=',\t', header=header)
180 |
--------------------------------------------------------------------------------
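
A tiny numeric check of the evaluation helpers above, assuming the unsup3d package and OpenCV (cv2, imported by unsup3d.utils) are installed: compute_sc_inv_err subtracts the mean of the depth difference before squaring, so a global offset between prediction and ground truth yields zero error, and compute_angular_distance of a normal map with itself is zero degrees.

import torch
from unsup3d.utils import compute_sc_inv_err, compute_angular_distance

d_gt = torch.rand(1, 8, 8)
d_pred = d_gt + 0.3                        # globally shifted prediction
err_map = compute_sc_inv_err(d_pred, d_gt) # (1, 8, 8) per-pixel squared, mean-removed error
print(err_map.mean())                      # ~0: the constant shift is removed

n = torch.zeros(1, 8, 8, 3)
n[..., 2] = 1                               # normals pointing along +z
print(compute_angular_distance(n, n).max()) # 0 degrees
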