├── .gitignore ├── README.md ├── config.yml ├── config_example.py ├── dataset ├── __init__.py ├── eval_dataset.py └── uv_dataset.py ├── model ├── __init__.py ├── pipeline.py ├── texture.py └── unet.py ├── nni_train.py ├── render.py ├── render_texture.py ├── requirements.txt ├── search_space.json ├── train.py ├── train_texture.py ├── train_unet.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Texture 2 | 3 | This repository implements [Deferred Neural Rendering: Image Synthesis using Neural Textures](https://arxiv.org/abs/1904.12356) . 4 | 5 | 6 | 7 | ## Requirements 8 | 9 | + Python 3.6+ 10 | + argparse 11 | + nni 12 | + NumPy 13 | + Pillow 14 | + pytorch 15 | + tensorboardX 16 | + torchvision 17 | + tqdm 18 | 19 | 20 | 21 | ## File Organization 22 | 23 | The root directory contains several subdirectories and files: 24 | 25 | ``` 26 | dataset/ --- custom PyTorch Dataset classes for loading included data 27 | model/ --- custom PyTorch Module classes 28 | util.py --- useful procedures 29 | render.py --- render using texture and U-Net 30 | render_texture.py --- render from RGB texture or neural texture 31 | train.py --- optimize texture and U-Net jointly 32 | train_texture.py --- optimize only texture 33 | train_unet.py --- optimize U-Net using pretrained 3-channel texture 34 | ``` 35 | 36 | 37 | 38 | ## How to Use 39 | 40 | ### Set up Environment 41 | 42 | Install python >= 3.6 and create an environment. 
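For example, with the built-in `venv` module (any environment manager works; the environment name below is just a placeholder):

```powershell
python -m venv neural_texture_env
neural_texture_env\Scripts\activate
```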
43 |
44 | Install requirements:
45 |
46 | ```powershell
47 | pip install -r requirements.txt
48 | ```
49 |
50 | ### Prepare Data
51 |
52 | Training requires three folders of data:
53 |
54 | + `/data/frame/` with video frame `.png` files
55 | + `/data/uv/` with uv-map `.npy` files, each shaped (H, W, 2)
56 | + `/data/extrinsics/` with normalized camera extrinsics in `.npy` files, each shaped (3,)
57 |
58 | Each frame corresponds to one uv map and one set of camera extrinsics. Files are named sequentially, from `0000` to `xxxx`.
59 |
60 | We demonstrate two ways to prepare data: render synthetic training data with the code at https://github.com/A-Dying-Pig/OpenGL_NeuralTexture, or reconstruct from a real scene with the code at https://github.com/gerwang/InfiniTAM.
61 |
62 | ### Configuration
63 |
64 | Rename `config_example.py` to `config.py` and set the parameters for training and rendering.
65 |
66 | ### Train Jointly
67 |
68 | ```powershell
69 | python train.py [--args]
70 | ```
71 |
72 | ### Train Texture
73 |
74 | ```powershell
75 | python train_texture.py [--args]
76 | ```
77 |
78 | ### Train U-Net
79 |
80 | ```powershell
81 | python train_unet.py [--args]
82 | ```
83 |
84 | ### Render by Texture
85 |
86 | ```powershell
87 | python render_texture.py [--args]
88 | ```
89 |
90 | ### Render by Texture and U-Net Jointly
91 |
92 | ```powershell
93 | python render.py [--args]
94 | ```
95 |
96 |
-------------------------------------------------------------------------------- /config.yml: --------------------------------------------------------------------------------
1 | authorName: default
2 | experimentName: neural_texture
3 | trialConcurrency: 2
4 | maxExecDuration: 24h
5 | maxTrialNum: 6
6 | #choice: local, remote, pai
7 | trainingServicePlatform: local
8 | searchSpacePath: search_space.json
9 | #choice: true, false
10 | useAnnotation: false
11 | tuner:
12 |   #choice: TPE, Random, Anneal, Evolution, BatchTuner
13 |   #SMAC (SMAC should be installed through nnictl)
14 |   builtinTunerName: TPE
15 |   classArgs:
16 |     #choice: maximize, minimize (the trial reports an L1 loss, so minimize)
17 |     optimize_mode: minimize
18 | trial:
19 |   command: python nni_train.py
20 |   codeDir: .
21 |   gpuNum: 1
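`config.yml` and `search_space.json` together drive the hyperparameter search with NNI. Assuming NNI is installed (the pinned `nni==1.0`), an experiment is typically launched with the `nnictl` CLI:

```powershell
nnictl create --config config.yml
```

Each trial then runs `python nni_train.py` on one GPU, and the TPE tuner picks `lr` and `l2` values from `search_space.json`.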
-------------------------------------------------------------------------------- /config_example.py: --------------------------------------------------------------------------------
1 | # =============== Basic Configurations ===========
2 | TEXTURE_W = 1024
3 | TEXTURE_H = 1024
4 | TEXTURE_DIM = 16
5 | USE_PYRAMID = True
6 | VIEW_DIRECTION = True
7 |
8 |
9 | # =============== Train Configurations ===========
10 | DATA_DIR = ''
11 | CHECKPOINT_DIR = ''
12 | LOG_DIR = ''
13 | TRAIN_SET = ['{:04d}'.format(i) for i in range(899)]
14 | EPOCH = 50
15 | BATCH_SIZE = 12
16 | CROP_W = 256
17 | CROP_H = 256
18 | LEARNING_RATE = 1e-3
19 | BETAS = '0.9, 0.999'
20 | L2_WEIGHT_DECAY = '0.01, 0.001, 0.0001, 0'
21 | EPS = 1e-8
22 | LOAD = None
23 | LOAD_STEP = 0
24 | EPOCH_PER_CHECKPOINT = 50
25 |
26 |
27 | # =============== Test Configurations ============
28 | TEST_LOAD = ''
29 | TEST_DATA_DIR = ''
30 | TEST_SET = ['{:04d}'.format(i) for i in range(10)]
31 | SAVE_DIR = ''
32 | OUT_MODE = 'image'  # 'video' or 'image'; referenced by render.py and render_texture.py
33 | FPS = 30  # frame rate when OUT_MODE is 'video'
-------------------------------------------------------------------------------- /dataset/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/SSRSGJYD/NeuralTexture/d23f5e5ebb2c721525926c4e3d338b7687fcc1d3/dataset/__init__.py
-------------------------------------------------------------------------------- /dataset/eval_dataset.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 | from util import map_transform  # note: view2sh is deprecated in util.py and no longer exported
7 |
8 |
9 | class EvalDataset(Dataset):
10 |
11 |     def __init__(self, dir, idx_list, view_direction=False):
12 |         self.idx_list = idx_list
13 |         self.dir = dir
14 |         self.view_direction = view_direction
15 |         uv_map = np.load(os.path.join(self.dir, 'uv/'+self.idx_list[0]+'.npy'))
16 |         self.height, self.width, _ = uv_map.shape
17 |
18 |     def __len__(self):
19 |         return len(self.idx_list)
20 |
21 |     def __getitem__(self, idx):
22 |         uv_map = np.load(os.path.join(self.dir, 'uv/'+self.idx_list[idx]+'.npy'))
23 |         nan_pos = np.isnan(uv_map)
24 |         uv_map[nan_pos] = 0
25 |         if np.any(np.isnan(uv_map)):
26 |             print('nan in dataset')
27 |
28 |         # final transform
29 |         uv_map = map_transform(uv_map)
30 |         # mask for invalid uv positions
31 |         mask = torch.max(uv_map, dim=2)[0].le(-1.0 + 1e-6)
32 |         mask = mask.repeat((3, 1, 1))
33 |
34 |         if self.view_direction:
35 |             extrinsics = np.load(os.path.join(self.dir, 'extrinsics/'+self.idx_list[idx]+'.npy'))
36 |             return uv_map, extrinsics, mask, self.idx_list[idx]
37 |         else:
38 |             return uv_map, mask, self.idx_list[idx]
39 |
40 |     @staticmethod
41 |     def _collect_fn(data, view_direction=False):
42 |         if view_direction:
43 |             uv_maps, extrinsics, masks, idxs = zip(*data)
44 |             uv_maps = torch.stack(tuple(uv_maps), dim=0)
45 |             extrinsics = torch.FloatTensor(extrinsics)
46 |             masks = torch.stack(tuple(masks), dim=0)
47 |             return uv_maps, extrinsics, masks, idxs
48 |         else:
49 |             uv_maps, masks, idxs = zip(*data)
50 |             uv_maps = torch.stack(tuple(uv_maps), dim=0)
51 |             masks = torch.stack(tuple(masks), dim=0)
52 |             return uv_maps, masks, idxs
53 |
54 |     @staticmethod
55 |     def get_collect_fn(view_direction=False):
56 |         collect_fn = lambda x: EvalDataset._collect_fn(x, view_direction)
57 |         return collect_fn
58 |
--------------------------------------------------------------------------------
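A minimal sketch of how `EvalDataset` and its custom collate function wire into a `DataLoader` (mirroring `render.py`; the data path and index list are placeholders):

```python
from torch.utils.data import DataLoader
from dataset.eval_dataset import EvalDataset

dataset = EvalDataset('/data/', ['0000', '0001'], view_direction=True)
loader = DataLoader(dataset, batch_size=2, shuffle=False,
                    collate_fn=EvalDataset.get_collect_fn(view_direction=True))
# each batch: stacked uv maps, per-view extrinsics, masks, and the string indices
uv_maps, extrinsics, masks, idxs = next(iter(loader))
```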
/dataset/uv_dataset.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 | from util import augment
7 |
8 |
9 | class UVDataset(Dataset):
10 |
11 |     def __init__(self, dir, idx_list, H, W, view_direction=False):
12 |         self.idx_list = idx_list
13 |         self.dir = dir
14 |         self.crop_size = (H, W)
15 |         self.view_direction = view_direction
16 |
17 |     def __len__(self):
18 |         return len(self.idx_list)
19 |
20 |     def __getitem__(self, idx):
21 |         img = Image.open(os.path.join(self.dir, 'frame/'+self.idx_list[idx]+'.png'), 'r')
22 |         uv_map = np.load(os.path.join(self.dir, 'uv/'+self.idx_list[idx]+'.npy'))
23 |         nan_pos = np.isnan(uv_map)
24 |         uv_map[nan_pos] = 0
25 |         if np.any(np.isnan(uv_map)):
26 |             print('nan in dataset')
27 |         if np.any(np.isinf(uv_map)):
28 |             print('inf in dataset')
29 |         img, uv_map, mask = augment(img, uv_map, self.crop_size)
30 |         if self.view_direction:
31 |             # view_map = np.load(os.path.join(self.dir, 'view_normal/'+self.idx_list[idx]+'.npy'))
32 |             extrinsics = np.load(os.path.join(self.dir, 'extrinsics/'+self.idx_list[idx]+'.npy'))
33 |             return img, uv_map, extrinsics, mask
34 |         else:
35 |             return img, uv_map, mask
36 |
-------------------------------------------------------------------------------- /model/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/SSRSGJYD/NeuralTexture/d23f5e5ebb2c721525926c4e3d338b7687fcc1d3/model/__init__.py
-------------------------------------------------------------------------------- /model/pipeline.py: --------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 |
7 | sys.path.append('..')
8 | from model.texture import Texture
9 | from model.unet import UNet
10 |
11 |
12 | class PipeLine(nn.Module):
13 |     def __init__(self, W, H, feature_num, use_pyramid=True, view_direction=True):
14 |         super(PipeLine, self).__init__()
15 |         self.feature_num = feature_num
16 |         self.use_pyramid = use_pyramid
17 |         self.view_direction = view_direction
18 |         self.texture = Texture(W, H, feature_num, use_pyramid)
19 |         self.unet = UNet(feature_num, 3)
20 |
21 |     def _spherical_harmonics_basis(self, extrinsics):
22 |         '''
23 |         extrinsics: normalized view directions, a tensor shaped (N, 3)
24 |         output: real spherical harmonics basis values up to l=2, a tensor shaped (N, 9)
25 |         '''
26 |         batch = extrinsics.shape[0]
27 |         sh_bands = torch.ones((batch, 9), dtype=torch.float)
28 |         coff_0 = 1 / (2.0*math.sqrt(np.pi))
29 |         coff_1 = math.sqrt(3.0) * coff_0
30 |         coff_2 = math.sqrt(15.0) * coff_0
31 |         coff_3 = math.sqrt(1.25) * coff_0
32 |         # l=0
33 |         sh_bands[:, 0] = coff_0
34 |         # l=1
35 |         sh_bands[:, 1] = extrinsics[:, 1] * coff_1
36 |         sh_bands[:, 2] = extrinsics[:, 2] * coff_1
37 |         sh_bands[:, 3] = extrinsics[:, 0] * coff_1
38 |         # l=2
39 |         sh_bands[:, 4] = extrinsics[:, 0] * extrinsics[:, 1] * coff_2
40 |         sh_bands[:, 5] = extrinsics[:, 1] * extrinsics[:, 2] * coff_2
41 |         sh_bands[:, 6] = (3.0 * extrinsics[:, 2] * extrinsics[:, 2] - 1.0) * coff_3
42 |         sh_bands[:, 7] = extrinsics[:, 2] * extrinsics[:, 0] * coff_2
43 |         sh_bands[:, 8] = (extrinsics[:, 0] * extrinsics[:, 0] - extrinsics[:, 2] * extrinsics[:, 2]) * coff_2
44 |         return sh_bands
45 |
46 |     def forward(self, *args):
47 |         if self.view_direction:
48 |             uv_map, extrinsics = args
49 |             x = self.texture(uv_map)
50 |             assert x.shape[1] >= 12  # 3 RGB channels + 9 view-dependent channels
51 |             basis = self._spherical_harmonics_basis(extrinsics).cuda()
52 |             basis = basis.view(basis.shape[0], basis.shape[1], 1, 1)
53 |             # modulate channels 3-11 with the per-view SH basis; channels 0-2 stay as the RGB texture
54 |             x[:, 3:12, :, :] = x[:, 3:12, :, :] * basis
55 |         else:
56 |             uv_map = args[0]
57 |             x = self.texture(uv_map)
58 |         y = self.unet(x)
59 |         return x[:, 0:3, :, :], y
--------------------------------------------------------------------------------
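For reference, `_spherical_harmonics_basis` evaluates the real spherical harmonics up to degree 2 at the unit view direction (x, y, z). In standard form:

```latex
Y_0^0 = \tfrac{1}{2\sqrt{\pi}}, \qquad
Y_1^{-1} = \sqrt{\tfrac{3}{4\pi}}\, y, \quad
Y_1^{0} = \sqrt{\tfrac{3}{4\pi}}\, z, \quad
Y_1^{1} = \sqrt{\tfrac{3}{4\pi}}\, x,
\qquad
Y_2^{-2} = \tfrac{1}{2}\sqrt{\tfrac{15}{\pi}}\, xy, \quad
Y_2^{-1} = \tfrac{1}{2}\sqrt{\tfrac{15}{\pi}}\, yz, \quad
Y_2^{0} = \tfrac{1}{4}\sqrt{\tfrac{5}{\pi}}\,(3z^2 - 1), \quad
Y_2^{1} = \tfrac{1}{2}\sqrt{\tfrac{15}{\pi}}\, zx, \quad
Y_2^{2} = \tfrac{1}{4}\sqrt{\tfrac{15}{\pi}}\,(x^2 - y^2)
```

The code matches these for the first eight bands; the last band uses (x² − z²) with the full ½√(15/π) coefficient rather than the conventional ¼√(15/π)(x² − y²), which is worth double-checking against the data-preparation convention.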
/model/texture.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SingleLayerTexture(nn.Module):
7 |     def __init__(self, W, H):
8 |         super(SingleLayerTexture, self).__init__()
9 |         # small random init; a bare torch.FloatTensor(...) would leave the texture uninitialized
10 |         self.layer1 = nn.Parameter(0.01 * torch.randn(1, 1, W, H))
11 |
12 |     def forward(self, x):
13 |         batch = x.shape[0]
14 |         x = x * 2.0 - 1.0  # uv in [0, 1] -> grid_sample's [-1, 1] coordinate range
15 |         y = F.grid_sample(self.layer1.repeat(batch,1,1,1), x)
16 |         return y
17 |
18 |
19 | class LaplacianPyramid(nn.Module):
20 |     def __init__(self, W, H):
21 |         super(LaplacianPyramid, self).__init__()
22 |         # four resolution levels, all sampled at the same uv coordinates and summed
23 |         self.layer1 = nn.Parameter(0.01 * torch.randn(1, 1, W, H))
24 |         self.layer2 = nn.Parameter(0.01 * torch.randn(1, 1, W // 2, H // 2))
25 |         self.layer3 = nn.Parameter(0.01 * torch.randn(1, 1, W // 4, H // 4))
26 |         self.layer4 = nn.Parameter(0.01 * torch.randn(1, 1, W // 8, H // 8))
27 |
28 |     def forward(self, x):
29 |         batch = x.shape[0]
30 |         x = x * 2.0 - 1.0
31 |         y1 = F.grid_sample(self.layer1.repeat(batch,1,1,1), x)
32 |         y2 = F.grid_sample(self.layer2.repeat(batch,1,1,1), x)
33 |         y3 = F.grid_sample(self.layer3.repeat(batch,1,1,1), x)
34 |         y4 = F.grid_sample(self.layer4.repeat(batch,1,1,1), x)
35 |         y = y1 + y2 + y3 + y4
36 |         return y
37 |
38 |
39 | class Texture(nn.Module):
40 |     def __init__(self, W, H, feature_num, use_pyramid=True):
41 |         super(Texture, self).__init__()
42 |         self.feature_num = feature_num
43 |         self.use_pyramid = use_pyramid
44 |         self.layer1 = nn.ParameterList()
45 |         self.layer2 = nn.ParameterList()
46 |         self.layer3 = nn.ParameterList()
47 |         self.layer4 = nn.ParameterList()
48 |         if self.use_pyramid:
49 |             self.textures = nn.ModuleList([LaplacianPyramid(W, H) for i in range(feature_num)])
50 |             for i in range(self.feature_num):
51 |                 self.layer1.append(self.textures[i].layer1)
52 |                 self.layer2.append(self.textures[i].layer2)
53 |                 self.layer3.append(self.textures[i].layer3)
54 |                 self.layer4.append(self.textures[i].layer4)
55 |         else:
56 |             self.textures = nn.ModuleList([SingleLayerTexture(W, H) for i in range(feature_num)])
57 |             for i in range(self.feature_num):
58 |                 self.layer1.append(self.textures[i].layer1)
59 |
60 |     def forward(self, x):
61 |         y_i = []
62 |         for i in range(self.feature_num):
63 |             y = self.textures[i](x)
64 |             y_i.append(y)
65 |         y = torch.cat(tuple(y_i), dim=1)
66 |         return y
--------------------------------------------------------------------------------
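A quick sanity check of the texture sampler's shapes (hypothetical snippet, not part of the repo):

```python
import torch
from model.texture import Texture

tex = Texture(512, 512, feature_num=16, use_pyramid=True)
uv = torch.rand(2, 256, 256, 2)  # uv coordinates in [0, 1], as the datasets provide
feat = tex(uv)                   # each pyramid level is sampled at uv and summed
print(feat.shape)                # torch.Size([2, 16, 256, 256])
```

Since all pyramid levels are sampled at the same uv coordinates, coarser levels cover larger effective texels, which regularizes the learned texture across scales.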
/model/unet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class down(nn.Module):
7 |     def __init__(self, in_ch, out_ch):
8 |         super(down, self).__init__()
9 |         self.conv = nn.Sequential(
10 |             nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
11 |             nn.InstanceNorm2d(out_ch),
12 |             nn.LeakyReLU(0.2, inplace=True)
13 |         )
14 |
15 |     def forward(self, x):
16 |         x = self.conv(x)
17 |         return x
18 |
19 |
20 | class up(nn.Module):
21 |     def __init__(self, in_ch, out_ch, output_pad=0, concat=True, final=False):
22 |         super(up, self).__init__()
23 |         self.concat = concat
24 |         self.final = final
25 |         if self.final:
26 |             self.conv = nn.Sequential(
27 |                 nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1, output_padding=output_pad),
28 |                 nn.InstanceNorm2d(out_ch),
29 |                 nn.Tanh()
30 |             )
31 |         else:
32 |             self.conv = nn.Sequential(
33 |                 nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1, output_padding=output_pad),
34 |                 nn.InstanceNorm2d(out_ch),
35 |                 nn.LeakyReLU(0.2, inplace=True)
36 |             )
37 |
38 |     def forward(self, x1, x2):
39 |         if self.concat:
40 |             diffY = x2.size()[2] - x1.size()[2]
41 |             diffX = x2.size()[3] - x1.size()[3]
42 |             # pad (or crop, when the diffs are negative) x1 to x2's spatial size before concatenating
43 |             x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
44 |                             diffY // 2, diffY - diffY // 2])
45 |             x1 = torch.cat((x2, x1), dim=1)
46 |         x1 = self.conv(x1)
47 |         return x1
48 |
49 |
50 | class UNet(nn.Module):
51 |     def __init__(self, input_channels, output_channels):
52 |         super(UNet, self).__init__()
53 |         self.down1 = down(input_channels, 64)
54 |         self.down2 = down(64, 128)
55 |         self.down3 = down(128, 256)
56 |         self.down4 = down(256, 512)
57 |         self.down5 = down(512, 512)
58 |         self.up1 = up(512, 512, output_pad=1, concat=False)  # bottleneck: no skip connection
59 |         self.up2 = up(1024, 512)
60 |         self.up3 = up(768, 256)
61 |         self.up4 = up(384, 128)
62 |         self.up5 = up(192, output_channels, final=True)
63 |
64 |     def forward(self, x):
65 |         x1 = self.down1(x)
66 |         x2 = self.down2(x1)
67 |         x3 = self.down3(x2)
68 |         x4 = self.down4(x3)
69 |         x5 = self.down5(x4)
70 |         x = self.up1(x5, None)
71 |         x = self.up2(x, x4)
72 |         x = self.up3(x, x3)
73 |         x = self.up4(x, x2)
74 |         x = self.up5(x, x1)
75 |         return x
--------------------------------------------------------------------------------
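A shape walk-through of the five-level encoder-decoder (hypothetical check, not part of the repo). The decoder's channel counts (1024, 768, 384, 192) come from concatenating each upsampled tensor with the matching encoder skip:

```python
import torch
from model.unet import UNet

net = UNet(16, 3)                 # 16 neural-texture channels in, RGB out
x = torch.randn(1, 16, 256, 256)
# encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8; the decoder mirrors it back to 256
print(net(x).shape)               # torch.Size([1, 3, 256, 256])
```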
/nni_train.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import nni
4 | import os
5 | import random
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.optim import Adam
11 | from torch.utils.data import DataLoader
12 |
13 | import config
14 | from dataset.uv_dataset import UVDataset
15 | from model.pipeline import PipeLine
16 |
17 | logger = logging.getLogger('neural_texture_AutoML')
18 |
19 |
20 | def get_params():
21 |     # Training settings
22 |     parser = argparse.ArgumentParser()
23 |     parser.add_argument('--texturew', type=int, default=config.TEXTURE_W)
24 |     parser.add_argument('--textureh', type=int, default=config.TEXTURE_H)
25 |     parser.add_argument('--texture_dim', type=int, default=config.TEXTURE_DIM)
26 |     # note: argparse's type=bool treats any non-empty string as True; rely on the config defaults
27 |     parser.add_argument('--use_pyramid', type=bool, default=config.USE_PYRAMID)
28 |     parser.add_argument('--view_direction', type=bool, default=config.VIEW_DIRECTION)
29 |     parser.add_argument('--data', type=str, default=config.DATA_DIR, help='directory to data')
30 |     parser.add_argument('--checkpoint', type=str, default=config.CHECKPOINT_DIR, help='directory to save checkpoint')
31 |     parser.add_argument('--logdir', type=str, default=config.LOG_DIR, help='directory to save logs')
32 |     parser.add_argument('--train', default=config.TRAIN_SET)
33 |     parser.add_argument('--epoch', type=int, default=config.EPOCH)
34 |     parser.add_argument('--cropw', type=int, default=config.CROP_W)
35 |     parser.add_argument('--croph', type=int, default=config.CROP_H)
36 |     parser.add_argument('--batch', type=int, default=config.BATCH_SIZE)
37 |     parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
38 |     parser.add_argument('--betas', default=config.BETAS)
39 |     parser.add_argument('--l2', default=config.L2_WEIGHT_DECAY)
40 |     parser.add_argument('--eps', type=float, default=config.EPS)
41 |     parser.add_argument('--load', default=config.LOAD)
42 |     parser.add_argument('--load_step', type=int, default=config.LOAD_STEP)
43 |     args = parser.parse_args()
44 |     return args
45 |
46 | def adjust_learning_rate(optimizer, epoch, original_lr):
47 |     """Warm up linearly over the first 5 epochs, then decay the LR by 10x at epochs 50 and 100"""
48 |     if epoch <= 5:
49 |         lr = original_lr * 0.2 * epoch
50 |     elif epoch < 50:
51 |         lr = original_lr
52 |     elif epoch < 100:
53 |         lr = 0.1 * original_lr
54 |     else:
55 |         lr = 0.01 * original_lr
56 |     for param_group in optimizer.param_groups:
57 |         param_group['lr'] = lr
58 |
59 | def main(args):
60 |
61 |     # named_tuple = time.localtime()
62 |     # time_string = time.strftime("%m_%d_%Y_%H_%M", named_tuple)
63 |     # log_dir = os.path.join(args.logdir, time_string)
64 |     # if not os.path.exists(log_dir):
65 |     #     os.makedirs(log_dir)
66 |     # writer = tensorboardX.SummaryWriter(logdir=log_dir)
67 |
68 |     # checkpoint_dir = args.checkpoint + time_string
69 |     # if not os.path.exists(checkpoint_dir):
70 |     #     os.makedirs(checkpoint_dir)
71 |
72 |     dataset = UVDataset(args.data, args.train, args.croph, args.cropw, args.view_direction)
73 |     dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True, num_workers=4)
74 |
75 |     if args.load:
76 |         print('Loading Saved Model')
77 |         model = torch.load(os.path.join(args.checkpoint, args.load))
78 |         step = args.load_step
79 |     else:
80 |         model = PipeLine(args.texturew, args.textureh, args.texture_dim, args.use_pyramid, args.view_direction)
81 |         step = 0
82 |
83 |     # the NNI search space supplies l2 as a list already; the config default is a comma-separated string
84 |     if isinstance(args.l2, str):
85 |         l2 = [float(x) for x in args.l2.split(',')]
86 |     else:
87 |         l2 = list(args.l2)
88 |     betas = args.betas.split(',')
89 |     betas = [float(x) for x in betas]
90 |     betas = tuple(betas)
91 |     optimizer = Adam([
92 |         {'params': model.texture.layer1, 'weight_decay': l2[0], 'lr': args.lr},
93 |         {'params': model.texture.layer2, 'weight_decay': l2[1], 'lr': args.lr},
94 |         {'params': model.texture.layer3, 'weight_decay': l2[2], 'lr': args.lr},
95 |         {'params': model.texture.layer4, 'weight_decay': l2[3], 'lr': args.lr},
96 |         {'params': model.unet.parameters(), 'lr': 0.1 * args.lr}],
97 |         betas=betas, eps=args.eps)
98 |     model = model.to('cuda')
99 |     model.train()
100 |     torch.set_grad_enabled(True)
101 |     criterion = nn.L1Loss()
102 |
103 |     print('Training started')
104 |     for i in range(1, 1 + args.epoch):  # 1-based epochs for the warm-up schedule
105 |         print('Epoch {}'.format(i))
106 |         adjust_learning_rate(optimizer, i, args.lr)
107 |         for samples in dataloader:
108 |             if args.view_direction:
109 |                 images, uv_maps, extrinsics, masks = samples
110 |                 # random scale
111 |                 scale = 2 ** random.randint(-1,1)
112 |                 images = F.interpolate(images, scale_factor=scale, mode='bilinear')
113 |
114 |                 uv_maps = uv_maps.permute(0, 3, 1, 2)
115 |                 uv_maps = F.interpolate(uv_maps, scale_factor=scale, mode='bilinear')
116 |                 uv_maps = uv_maps.permute(0, 2, 3, 1)
117 |
118 |                 step += images.shape[0]
119 |                 optimizer.zero_grad()
120 |                 RGB_texture, preds = model(uv_maps.cuda(), extrinsics.cuda())
121 |             else:
122 |                 images, uv_maps, masks = samples
123 |                 # random scale
124 |                 scale = 2 ** random.randint(-1,1)
125 |                 images = F.interpolate(images, scale_factor=scale, mode='bilinear')
126 |                 uv_maps = uv_maps.permute(0, 3, 1, 2)
127 |                 uv_maps = F.interpolate(uv_maps, scale_factor=scale, mode='bilinear')
128 |                 uv_maps = uv_maps.permute(0, 2, 3, 1)
129 |
130 |                 step += images.shape[0]
131 |                 optimizer.zero_grad()
132 |                 RGB_texture, preds = model(uv_maps.cuda())
133 |
134 |             loss1 = criterion(RGB_texture.cpu(), images)
135 |             loss2 = criterion(preds.cpu(), images)
136 |             loss = loss1 + loss2
137 |             loss.backward()
138 |             optimizer.step()
139 |             nni.report_intermediate_result(loss.item())
140 |             # writer.add_scalar('train/loss', loss.item(), step)
141 |             print('loss at step {}: {}'.format(step, loss.item()))
142 |
143 |         # save checkpoint
144 |         # print('Saving checkpoint')
145 |         # torch.save(model, args.checkpoint+time_string+'/epoch_{}.pt'.format(i))
146 |
147 |     # report the final metric of this trial back to the tuner
148 |     nni.report_final_result(loss.item())
149 |
150 |
151 | if __name__ == '__main__':
152 |     try:
153 |         # get parameters from tuner
154 |         tuner_params = nni.get_next_parameter()
155 |         logger.debug(tuner_params)
156 |         params = vars(get_params())
157 |         params.update(tuner_params)
158 |         # main() expects attribute access, so wrap the merged dict back into a namespace
159 |         main(argparse.Namespace(**params))
160 |     except Exception as exception:
161 |         logger.exception(exception)
162 |         raise
--------------------------------------------------------------------------------
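As a worked example of the schedule: with the default `LEARNING_RATE = 1e-3`, epochs 1-5 warm up through 2e-4, 4e-4, 6e-4, 8e-4, 1e-3; the rate holds at 1e-3 through epoch 49, drops to 1e-4 for epochs 50-99, and to 1e-5 from epoch 100 on. Note that `adjust_learning_rate` writes the same value into every parameter group, so the U-Net's initial 0.1 × lr setting only holds until the first schedule update.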
/render.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | from skimage import img_as_ubyte
6 | import sys
7 | import tqdm
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torchvision.transforms as transforms
11 |
12 | import config
13 | from dataset.eval_dataset import EvalDataset
14 | from model.pipeline import PipeLine
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--data', type=str, default=config.DATA_DIR, help='directory to data')
18 | parser.add_argument('--test', default=config.TEST_SET, help='index list of test uv_maps')
19 | parser.add_argument('--checkpoint', type=str, default=config.CHECKPOINT_DIR, help='directory to save checkpoint')
20 | parser.add_argument('--load', type=str, default=config.TEST_LOAD, help='checkpoint name')
21 | parser.add_argument('--batch', type=int, default=config.BATCH_SIZE)
22 | parser.add_argument('--save', type=str, default=config.SAVE_DIR, help='save directory')
23 | parser.add_argument('--out_mode', type=str, default=config.OUT_MODE, choices=('video', 'image'))
24 | parser.add_argument('--fps', type=int, default=config.FPS)
25 | parser.add_argument('--view_direction', type=bool, default=config.VIEW_DIRECTION)
26 | args = parser.parse_args()
27 |
28 |
29 | if __name__ == '__main__':
30 |
31 |     checkpoint_file = os.path.join(args.checkpoint, args.load)
32 |     if not os.path.exists(checkpoint_file):
33 |         print('checkpoint does not exist!')
34 |         sys.exit()
35 |
36 |     if not os.path.exists(args.save):
37 |         os.makedirs(args.save)
38 |
39 |     dataset = EvalDataset(args.data, args.test, args.view_direction)
40 |     dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=False, num_workers=4, collate_fn=EvalDataset.get_collect_fn(args.view_direction))
41 |
42 |     model = torch.load(checkpoint_file)
43 |     model = model.to('cuda')
44 |     model.eval()
45 |     torch.set_grad_enabled(False)
46 |
47 |     if args.out_mode == 'video':
48 |         fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
49 |         writer = cv2.VideoWriter(os.path.join(args.save, 'render.mp4'), fourcc, args.fps,
50 |                                  (dataset.width, dataset.height), True)
51 |     print('Evaluating started')
52 |     for samples in tqdm.tqdm(dataloader):
53 |         if args.view_direction:
54 |             uv_maps, extrinsics, masks, idxs = samples
55 |             RGB_texture, preds = model(uv_maps.cuda(), extrinsics.cuda())
56 |         else:
57 |             uv_maps, masks, idxs = samples
58 |             RGB_texture, preds = model(uv_maps.cuda())
59 |
60 |         preds = preds.cpu()
61 |         preds.masked_fill_(masks, 0) # fill invalid with 0
62 |
63 |         # save result
64 |         if args.out_mode == 'video':
65 |             preds = preds.numpy()
66 |             preds = np.clip(preds, -1.0, 1.0)
67 |             for i in range(len(idxs)):
68 |                 image = img_as_ubyte(preds[i])
69 |                 image = np.transpose(image, (1,2,0))
70 |                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # swap to BGR channel order for OpenCV
71 |                 writer.write(image)
72 |         else:
73 |             for i in range(len(idxs)):
74 |                 image = transforms.ToPILImage()(preds[i])
75 |                 image.save(os.path.join(args.save, '{}_render.png'.format(idxs[i])))
76 |
77 |     if args.out_mode == 'video':
78 |         writer.release()  # finalize the mp4 container
--------------------------------------------------------------------------------
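Typical invocations, assuming the paths in `config.py` are filled in (the checkpoint name and output directory below are placeholders):

```powershell
python render.py --load epoch_50.pt --out_mode image --save output/
python render.py --load epoch_50.pt --out_mode video --fps 30 --save output/
```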
/render_texture.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | from skimage import img_as_ubyte
6 | import sys
7 | import tqdm
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torchvision.transforms as transforms
11 |
12 | import config
13 | from dataset.eval_dataset import EvalDataset
14 | from model.pipeline import PipeLine
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--data', type=str, default=config.DATA_DIR, help='directory to data')
18 | parser.add_argument('--test', default=config.TEST_SET, help='index list of test uv_maps')
19 | parser.add_argument('--checkpoint', type=str, default=config.CHECKPOINT_DIR, help='directory to save checkpoint')
20 | parser.add_argument('--load', type=str, default=config.TEST_LOAD, help='checkpoint name')
21 | parser.add_argument('--batch', type=int, default=config.BATCH_SIZE)
22 | parser.add_argument('--save', type=str, default=config.SAVE_DIR, help='save directory')
23 | parser.add_argument('--out_mode', type=str, default=config.OUT_MODE, choices=('video', 'image'))
24 | parser.add_argument('--fps', type=int, default=config.FPS)
25 | args = parser.parse_args()
26 |
27 |
28 | if __name__ == '__main__':
29 |
30 |     checkpoint_file = os.path.join(args.checkpoint, args.load)
31 |     if not os.path.exists(checkpoint_file):
32 |         print('checkpoint does not exist!')
33 |         sys.exit()
34 |
35 |     if not os.path.exists(args.save):
36 |         os.makedirs(args.save)
37 |
38 |     dataset = EvalDataset(args.data, args.test, False)
39 |     dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=False, num_workers=4, collate_fn=EvalDataset.get_collect_fn(False))
40 |
41 |     model = torch.load(checkpoint_file)
42 |     model = model.to('cuda')
43 |     model.eval()
44 |     torch.set_grad_enabled(False)
45 |
46 |     if args.out_mode == 'video':
47 |         fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
48 |         writer = cv2.VideoWriter(os.path.join(args.save, 'render.mp4'), fourcc, args.fps,
49 |                                  (dataset.width, dataset.height), True)
50 |     print('Evaluating started')
51 |     for samples in tqdm.tqdm(dataloader):
52 |         uv_maps, masks, idxs = samples
53 |         preds = model(uv_maps.cuda()).cpu()
54 |
55 |         preds.masked_fill_(masks, 0) # fill invalid with 0
56 |
57 |         # save result
58 |         if args.out_mode == 'video':
59 |             preds = preds.numpy()
60 |             preds = np.clip(preds, -1.0, 1.0)
61 |             for i in range(len(idxs)):
62 |                 image = img_as_ubyte(preds[i])
63 |                 image = np.transpose(image, (1,2,0))
64 |                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
65 |                 writer.write(image)
66 |         else:
67 |             for i in range(len(idxs)):
68 |                 image = transforms.ToPILImage()(preds[i])
69 |                 image.save(os.path.join(args.save, '{}_render.png'.format(idxs[i])))
70 |
71 |     if args.out_mode == 'video':
72 |         writer.release()  # finalize the mp4 container
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torchvision==0.2.2.post3
2 | numpy==1.15.0
3 | torch==1.0.1
4 | nni==1.0
5 | Pillow>=6.2.2
6 | tensorboardX==1.9
7 | tqdm
8 | opencv-python
9 | scikit-image
--------------------------------------------------------------------------------
"l2":{"_type":"choice","_value":[[0,0,0,0], [0.01, 0.001, 0.0001, 0]]} 4 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import random 5 | import tensorboardX 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.optim import Adam 11 | from torch.utils.data import DataLoader 12 | 13 | import config 14 | from dataset.uv_dataset import UVDataset 15 | from model.pipeline import PipeLine 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--texturew', type=int, default=config.TEXTURE_W) 19 | parser.add_argument('--textureh', type=int, default=config.TEXTURE_H) 20 | parser.add_argument('--texture_dim', type=int, default=config.TEXTURE_DIM) 21 | parser.add_argument('--use_pyramid', type=bool, default=config.USE_PYRAMID) 22 | parser.add_argument('--view_direction', type=bool, default=config.VIEW_DIRECTION) 23 | parser.add_argument('--data', type=str, default=config.DATA_DIR, help='directory to data') 24 | parser.add_argument('--checkpoint', type=str, default=config.CHECKPOINT_DIR, help='directory to save checkpoint') 25 | parser.add_argument('--logdir', type=str, default=config.LOG_DIR, help='directory to save checkpoint') 26 | parser.add_argument('--train', default=config.TRAIN_SET) 27 | parser.add_argument('--epoch', type=int, default=config.EPOCH) 28 | parser.add_argument('--cropw', type=int, default=config.CROP_W) 29 | parser.add_argument('--croph', type=int, default=config.CROP_H) 30 | parser.add_argument('--batch', type=int, default=config.BATCH_SIZE) 31 | parser.add_argument('--lr', type=float, default=config.LEARNING_RATE) 32 | parser.add_argument('--betas', type=str, default=config.BETAS) 33 | parser.add_argument('--l2', type=str, default=config.L2_WEIGHT_DECAY) 34 | parser.add_argument('--eps', type=float, default=config.EPS) 35 | parser.add_argument('--load', type=str, default=config.LOAD) 36 | parser.add_argument('--load_step', type=int, default=config.LOAD_STEP) 37 | parser.add_argument('--epoch_per_checkpoint', type=int, default=config.EPOCH_PER_CHECKPOINT) 38 | args = parser.parse_args() 39 | 40 | 41 | def adjust_learning_rate(optimizer, epoch, original_lr): 42 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 43 | if epoch <= 5: 44 | lr = original_lr * 0.2 * epoch 45 | elif epoch < 50: 46 | lr = original_lr 47 | elif epoch < 100: 48 | lr = 0.1 * original_lr 49 | else: 50 | lr = 0.01 * original_lr 51 | for param_group in optimizer.param_groups: 52 | param_group['lr'] = lr 53 | 54 | 55 | def main(): 56 | named_tuple = time.localtime() 57 | time_string = time.strftime("%m_%d_%Y_%H_%M", named_tuple) 58 | log_dir = os.path.join(args.logdir, time_string) 59 | if not os.path.exists(log_dir): 60 | os.makedirs(log_dir) 61 | writer = tensorboardX.SummaryWriter(logdir=log_dir) 62 | 63 | checkpoint_dir = args.checkpoint + time_string 64 | if not os.path.exists(checkpoint_dir): 65 | os.makedirs(checkpoint_dir) 66 | 67 | dataset = UVDataset(args.data, args.train, args.croph, args.cropw, args.view_direction) 68 | dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True, num_workers=4) 69 | 70 | if args.load: 71 | print('Loading Saved Model') 72 | model = torch.load(os.path.join(args.checkpoint, args.load)) 73 | step = args.load_step 74 | else: 75 | model = PipeLine(args.texturew, args.textureh, 
args.texture_dim, args.use_pyramid, args.view_direction) 76 | step = 0 77 | 78 | l2 = args.l2.split(',') 79 | l2 = [float(x) for x in l2] 80 | betas = args.betas.split(',') 81 | betas = [float(x) for x in betas] 82 | betas = tuple(betas) 83 | optimizer = Adam([ 84 | {'params': model.texture.layer1, 'weight_decay': l2[0], 'lr': args.lr}, 85 | {'params': model.texture.layer2, 'weight_decay': l2[1], 'lr': args.lr}, 86 | {'params': model.texture.layer3, 'weight_decay': l2[2], 'lr': args.lr}, 87 | {'params': model.texture.layer4, 'weight_decay': l2[3], 'lr': args.lr}, 88 | {'params': model.unet.parameters(), 'lr': 0.1 * args.lr}], 89 | betas=betas, eps=args.eps) 90 | model = model.to('cuda') 91 | model.train() 92 | torch.set_grad_enabled(True) 93 | criterion = nn.L1Loss() 94 | 95 | print('Training started') 96 | for i in range(1, 1+args.epoch): 97 | print('Epoch {}'.format(i)) 98 | adjust_learning_rate(optimizer, i, args.lr) 99 | for samples in dataloader: 100 | if args.view_direction: 101 | images, uv_maps, extrinsics, masks = samples 102 | # random scale 103 | scale = 2 ** random.randint(-1,1) 104 | images = F.interpolate(images, scale_factor=scale, mode='bilinear') 105 | 106 | uv_maps = uv_maps.permute(0, 3, 1, 2) 107 | uv_maps = F.interpolate(uv_maps, scale_factor=scale, mode='bilinear') 108 | uv_maps = uv_maps.permute(0, 2, 3, 1) 109 | 110 | step += images.shape[0] 111 | optimizer.zero_grad() 112 | RGB_texture, preds = model(uv_maps.cuda(), extrinsics.cuda()) 113 | else: 114 | images, uv_maps, masks = samples 115 | # random scale 116 | scale = 2 ** random.randint(-1,1) 117 | images = F.interpolate(images, scale_factor=scale, mode='bilinear') 118 | uv_maps = uv_maps.permute(0, 3, 1, 2) 119 | uv_maps = F.interpolate(uv_maps, scale_factor=scale, mode='bilinear') 120 | uv_maps = uv_maps.permute(0, 2, 3, 1) 121 | 122 | step += images.shape[0] 123 | optimizer.zero_grad() 124 | RGB_texture, preds = model(uv_maps.cuda()) 125 | 126 | loss1 = criterion(RGB_texture.cpu(), images) 127 | loss2 = criterion(preds.cpu(), images) 128 | loss = loss1 + loss2 129 | loss.backward() 130 | optimizer.step() 131 | writer.add_scalar('train/loss', loss.item(), step) 132 | print('loss at step {}: {}'.format(step, loss.item())) 133 | 134 | # save checkpoint 135 | if i % args.epoch_per_checkpoint == 0: 136 | print('Saving checkpoint') 137 | torch.save(model, args.checkpoint+time_string+'/epoch_{}.pt'.format(i)) 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /train_texture.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import tensorboardX 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import Adam 10 | from torch.utils.data import DataLoader 11 | 12 | import config 13 | from dataset.uv_dataset import UVDataset 14 | from model.texture import Texture 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--texturew', type=int, default=config.TEXTURE_W) 18 | parser.add_argument('--textureh', type=int, default=config.TEXTURE_H) 19 | parser.add_argument('--texture_dim', type=int, default=config.TEXTURE_DIM) 20 | parser.add_argument('--use_pyramid', type=bool, default=config.USE_PYRAMID) 21 | parser.add_argument('--data', type=str, default=config.DATA_DIR, help='directory to data') 22 | parser.add_argument('--checkpoint', type=str, 
default=config.CHECKPOINT_DIR, help='directory to save checkpoint')
23 | parser.add_argument('--logdir', type=str, default=config.LOG_DIR, help='directory to save logs')
24 | parser.add_argument('--train', default=config.TRAIN_SET)
25 | parser.add_argument('--epoch', type=int, default=config.EPOCH)
26 | parser.add_argument('--cropw', type=int, default=config.CROP_W)
27 | parser.add_argument('--croph', type=int, default=config.CROP_H)
28 | parser.add_argument('--batch', type=int, default=config.BATCH_SIZE)
29 | parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
30 | parser.add_argument('--betas', type=str, default=config.BETAS)
31 | parser.add_argument('--l2', type=str, default=config.L2_WEIGHT_DECAY)
32 | parser.add_argument('--eps', type=float, default=config.EPS)
33 | parser.add_argument('--load', type=str, default=config.LOAD)
34 | parser.add_argument('--load_step', type=int, default=config.LOAD_STEP)
35 | parser.add_argument('--epoch_per_checkpoint', type=int, default=config.EPOCH_PER_CHECKPOINT)
36 | args = parser.parse_args()
37 |
38 |
39 | def adjust_learning_rate(optimizer, epoch, original_lr):
40 |     """Warm up over the first 3 epochs, then decay the LR by 10x at epochs 5 and 10"""
41 |     if epoch <= 3:
42 |         lr = original_lr * 0.33 * epoch
43 |     elif epoch < 5:
44 |         lr = original_lr
45 |     elif epoch < 10:
46 |         lr = 0.1 * original_lr
47 |     else:
48 |         lr = 0.01 * original_lr
49 |     for param_group in optimizer.param_groups:
50 |         param_group['lr'] = lr
51 |
52 |
53 | def main():
54 |     named_tuple = time.localtime()
55 |     time_string = time.strftime("%m_%d_%Y_%H_%M", named_tuple)
56 |     log_dir = os.path.join(args.logdir, time_string)
57 |     if not os.path.exists(log_dir):
58 |         os.makedirs(log_dir)
59 |     writer = tensorboardX.SummaryWriter(logdir=log_dir)
60 |
61 |     checkpoint_dir = args.checkpoint + time_string
62 |     if not os.path.exists(checkpoint_dir):
63 |         os.makedirs(checkpoint_dir)
64 |
65 |     dataset = UVDataset(args.data, args.train, args.croph, args.cropw, False)
66 |     dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True, num_workers=4)
67 |
68 |     if args.load:
69 |         print('Loading Saved Model')
70 |         model = torch.load(os.path.join(args.checkpoint, args.load))
71 |         step = args.load_step
72 |     else:
73 |         model = Texture(args.texturew, args.textureh, 3, use_pyramid=args.use_pyramid)
74 |         step = 0
75 |
76 |     l2 = args.l2.split(',')
77 |     l2 = [float(x) for x in l2]
78 |     betas = args.betas.split(',')
79 |     betas = [float(x) for x in betas]
80 |     betas = tuple(betas)
81 |     optimizer = Adam([
82 |         {'params': model.layer1, 'weight_decay': l2[0]},
83 |         {'params': model.layer2, 'weight_decay': l2[1]},
84 |         {'params': model.layer3, 'weight_decay': l2[2]},
85 |         {'params': model.layer4, 'weight_decay': l2[3]}],
86 |         lr=args.lr, betas=betas, eps=args.eps)
87 |     model = model.to('cuda')
88 |     model.train()
89 |     torch.set_grad_enabled(True)
90 |     criterion = nn.L1Loss()
91 |
92 |     print('Training started')
93 |     for i in range(1, 1+args.epoch):
94 |         print('Epoch {}'.format(i))
95 |         adjust_learning_rate(optimizer, i, args.lr)
96 |         for samples in dataloader:
97 |             images, uv_maps, masks = samples
98 |             step += images.shape[0]
99 |             optimizer.zero_grad()
100 |             preds = model(uv_maps.cuda()).cpu()
101 |
102 |             preds = torch.masked_select(preds, masks)  # L1 over valid uv pixels only
103 |             images = torch.masked_select(images, masks)
104 |             loss = criterion(preds, images)
105 |             loss.backward()
106 |             optimizer.step()
107 |             writer.add_scalar('train/loss', loss.item(), step)
108 |             print('loss at step {}: {}'.format(step, loss.item()))
109 |
110 |         # save checkpoint
111 |         if i % args.epoch_per_checkpoint == 0:
112 |             print('Saving checkpoint')
113 |             torch.save(model, os.path.join(checkpoint_dir, 'epoch_{}.pt'.format(i)))
114 |
115 | if __name__ == '__main__':
116 |     main()
--------------------------------------------------------------------------------
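In other words, `train_texture.py` fits the 3-channel RGB texture by a masked L1 objective; with M the set of valid uv pixels in a crop, T the sampled texture, and I the ground-truth frame:

```latex
\mathcal{L} = \frac{1}{|M|} \sum_{p \in M} \left| T(uv_p) - I_p \right|
```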
/train_unet.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorboardX
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.optim import Adam
11 | from torch.utils.data import DataLoader
12 |
13 | import config
14 | from dataset.uv_dataset import UVDataset
15 | from model.pipeline import PipeLine
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--texturew', type=int, default=config.TEXTURE_W)
19 | parser.add_argument('--textureh', type=int, default=config.TEXTURE_H)
20 | parser.add_argument('--texture_dim', type=int, default=config.TEXTURE_DIM)
21 | parser.add_argument('--use_pyramid', type=bool, default=config.USE_PYRAMID)
22 | parser.add_argument('--view_direction', type=bool, default=config.VIEW_DIRECTION)
23 | parser.add_argument('--data', type=str, default=config.DATA_DIR, help='directory to data')
24 | parser.add_argument('--checkpoint', type=str, default=config.CHECKPOINT_DIR, help='directory to save checkpoint')
25 | parser.add_argument('--logdir', type=str, default=config.LOG_DIR, help='directory to save logs')
26 | parser.add_argument('--train', default=config.TRAIN_SET)
27 | parser.add_argument('--epoch', type=int, default=config.EPOCH)
28 | parser.add_argument('--cropw', type=int, default=config.CROP_W)
29 | parser.add_argument('--croph', type=int, default=config.CROP_H)
30 | parser.add_argument('--batch', type=int, default=config.BATCH_SIZE)
31 | parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
32 | parser.add_argument('--betas', type=str, default=config.BETAS)
33 | parser.add_argument('--l2', type=str, default=config.L2_WEIGHT_DECAY)
34 | parser.add_argument('--eps', type=float, default=config.EPS)
35 | parser.add_argument('--load', type=str, default=config.LOAD)  # must name a 3-channel texture checkpoint saved by train_texture.py
36 | parser.add_argument('--epoch_per_checkpoint', type=int, default=config.EPOCH_PER_CHECKPOINT)
37 | args = parser.parse_args()
38 |
39 |
40 | def adjust_learning_rate(optimizer, epoch, original_lr):
41 |     """Warm up over the first 3 epochs, hold until epoch 10, then decay to 0.3x and (after epoch 50) 0.1x of the initial LR"""
42 |     if epoch <= 3:
43 |         lr = original_lr * 0.33 * epoch
44 |     elif epoch < 10:
45 |         lr = original_lr
46 |     elif epoch < 50:
47 |         lr = 0.3 * original_lr
48 |     else:
49 |         lr = 0.1 * original_lr
50 |     for param_group in optimizer.param_groups:
51 |         param_group['lr'] = lr
52 |
53 |
54 | def main():
55 |     named_tuple = time.localtime()
56 |     time_string = time.strftime("%m_%d_%Y_%H_%M", named_tuple)
57 |     log_dir = os.path.join(args.logdir, time_string)
58 |     if not os.path.exists(log_dir):
59 |         os.makedirs(log_dir)
60 |     writer = tensorboardX.SummaryWriter(logdir=log_dir)
61 |
62 |     checkpoint_dir = args.checkpoint + time_string
63 |     if not os.path.exists(checkpoint_dir):
64 |         os.makedirs(checkpoint_dir)
65 |
66 |     dataset = UVDataset(args.data, args.train, args.croph, args.cropw, args.view_direction)
67 |     dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True, num_workers=4)
68 |
69 |     model = PipeLine(args.texturew, args.textureh, args.texture_dim, args.use_pyramid, args.view_direction)
70 |     print('Loading Saved Model')
71 |     texture = 
torch.load(os.path.join(args.checkpoint, args.load))
72 |     # transplant the pretrained 3-channel RGB texture into the first three neural-texture channels
73 |     model.texture.textures[0] = texture.textures[0]
74 |     model.texture.textures[1] = texture.textures[1]
75 |     model.texture.textures[2] = texture.textures[2]
76 |     step = 0
77 |
78 |     l2 = args.l2.split(',')
79 |     l2 = [float(x) for x in l2]
80 |     betas = args.betas.split(',')
81 |     betas = [float(x) for x in betas]
82 |     betas = tuple(betas)
83 |     optimizer = Adam([
84 |         {'params': model.unet.parameters()}],
85 |         lr=args.lr, betas=betas, eps=args.eps)
86 |     model = model.to('cuda')
87 |     model.train()
88 |     # eval() alone does not freeze weights; the texture stays fixed because only U-Net parameters are optimized
89 |     model.texture.textures[0].eval()
90 |     model.texture.textures[1].eval()
91 |     model.texture.textures[2].eval()
92 |     criterion = nn.L1Loss()
93 |
94 |     print('Training started')
95 |     for i in range(1, 1+args.epoch):
96 |         print('Epoch {}'.format(i))
97 |         adjust_learning_rate(optimizer, i, args.lr)
98 |         for samples in dataloader:
99 |             if args.view_direction:
100 |                 images, uv_maps, extrinsics, masks = samples
101 |                 # random scale
102 |                 scale = 2 ** random.randint(-1,1)
103 |                 images = F.interpolate(images, scale_factor=scale, mode='bilinear')
104 |                 uv_maps = uv_maps.permute(0, 3, 1, 2)
105 |                 uv_maps = F.interpolate(uv_maps, scale_factor=scale, mode='bilinear')
106 |                 uv_maps = uv_maps.permute(0, 2, 3, 1)
107 |
108 |                 step += images.shape[0]
109 |                 optimizer.zero_grad()
110 |                 preds = model(uv_maps.cuda(), extrinsics.cuda())[1].cpu()  # PipeLine returns (RGB texture, U-Net output)
111 |             else:
112 |                 images, uv_maps, masks = samples
113 |                 # random scale
114 |                 scale = 2 ** random.randint(-1,1)
115 |                 images = F.interpolate(images, scale_factor=scale, mode='bilinear')
116 |                 uv_maps = uv_maps.permute(0, 3, 1, 2)
117 |                 uv_maps = F.interpolate(uv_maps, scale_factor=scale, mode='bilinear')
118 |                 uv_maps = uv_maps.permute(0, 2, 3, 1)
119 |
120 |                 step += images.shape[0]
121 |                 optimizer.zero_grad()
122 |                 preds = model(uv_maps.cuda())[1].cpu()  # train against the U-Net output only
123 |
124 |             loss = criterion(preds, images)
125 |             loss.backward()
126 |             optimizer.step()
127 |             writer.add_scalar('train/loss', loss.item(), step)
128 |             print('loss at step {}: {}'.format(step, loss.item()))
129 |
130 |         # save checkpoint
131 |         if i % args.epoch_per_checkpoint == 0:
132 |             print('Saving checkpoint')
133 |             torch.save(model, args.checkpoint+time_string+'/epoch_{}.pt'.format(i))
134 |
135 | if __name__ == '__main__':
136 |     main()
-------------------------------------------------------------------------------- /util.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import random
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision.transforms as transforms
7 |
8 |
9 | def img_transform(image):
10 |     image_transforms = transforms.Compose([
11 |         transforms.ToTensor(),
12 |     ])
13 |     image = image_transforms(image)
14 |     return image
15 |
16 |
17 | def map_transform(map):
18 |     map = torch.from_numpy(map)
19 |     return map
20 |
21 |
22 | def augment(img, map, crop_size):
23 |     '''
24 |     :param img: PIL input image
25 |     :param map: numpy input map
26 |     :param crop_size: a tuple (h, w)
27 |     :return: image, map and mask
28 |     '''
29 |     # random mirror
30 |     # if random.random() < 0.5:
31 |     #     img = img.transpose(Image.FLIP_LEFT_RIGHT)
32 |     #     map = np.fliplr(map)
33 |
34 |     # random crop
35 |     w, h = img.size
36 |     crop_h, crop_w = crop_size
37 |     w1 = random.randint(0, w - crop_w)
38 |     h1 = random.randint(0, h - crop_h)
39 |     img = img.crop((w1, h1, w1 + crop_w, h1 + crop_h))
40 |     map = map[h1:h1 + crop_h, w1:w1 + crop_w, :]
41 |
42 |     # final transform
43 |     img, map = 
img_transform(img), map_transform(map) 44 | 45 | # mask for valid uv positions 46 | mask = torch.max(map, dim=2)[0].ge(-1.0+1e-6) 47 | mask = mask.repeat((3,1,1)) 48 | 49 | return img, map, mask 50 | 51 | 52 | # deprecated 53 | 54 | # sh = np.zeros(9) 55 | # sh[0] = 1 / np.sqrt(4 * np.pi) 56 | # sh[1:4] = 2 * np.pi / 3 * np.sqrt(3 / (4 * np.pi)) 57 | # sh[4] = np.pi / 8 * np.sqrt(5 / (4 * np.pi)) 58 | # sh[5:8] = 3 * np.pi / 4 * np.sqrt(5 / (12 * np.pi)) 59 | # sh[8] = 3 * np.pi / 8 * np.sqrt(5 / (12 * np.pi)) 60 | 61 | # def view2sh(view_map, h, crop_h, w, crop_w): 62 | # ''' 63 | # :param view_map: ndarray of (H, W, 3) 64 | # :param h: start position at height 65 | # :param crop_h: 66 | # :param w: start position at weight 67 | # :param crop_w: 68 | # :return: image, map and mask 69 | # ''' 70 | # map = view_map[h:h+crop_h, w:w+crop_w, :] 71 | # sh_map = np.zeros((9, crop_h, crop_w), dtype=np.float32) 72 | # sh_map[0] = sh[0] 73 | # sh_map[1] = sh[1] * map[:, :, 2] 74 | # sh_map[2] = sh[2] * map[:, :, 1] 75 | # sh_map[3] = sh[3] * map[:, :, 0] 76 | # sh_map[4] = sh[4] * (2*map[:, :, 2]*map[:, :, 2]-map[:, :, 0]*map[:, :, 0]-map[:, :, 1]*map[:, :, 1]) 77 | # sh_map[5] = sh[5] * map[:, :, 1] * map[:, :, 2] 78 | # sh_map[6] = sh[6] * map[:, :, 0] * map[:, :, 2] 79 | # sh_map[7] = sh[7] * map[:, :, 0] * map[:, :, 1] 80 | # sh_map[8] = sh[8] * (map[:, :, 0] * map[:, :, 0] - map[:, :, 1] * map[:, :, 1]) 81 | # return sh_map 82 | 83 | 84 | # def augment_view(img, map, view_map, crop_size): 85 | # ''' 86 | # :param img: PIL input image 87 | # :param map: numpy input map 88 | # :param view_map: numpy input map 89 | # :param crop_size: a tuple (h, w) 90 | # :return: image, map and mask 91 | # ''' 92 | # # random mirror 93 | # # if random.random() < 0.5: 94 | # # img = img.transpose(Image.FLIP_LEFT_RIGHT) 95 | # # map = np.fliplr(map) 96 | 97 | # # random crop 98 | # w, h = img.size 99 | # crop_h, crop_w = crop_size 100 | # w1 = random.randint(0, w - crop_w-1) 101 | # h1 = random.randint(0, h - crop_h-1) 102 | # img = img.crop((w1, h1, w1 + crop_w, h1 + crop_h)) 103 | # map = map[h1:h1 + crop_h, w1:w1 + crop_w, :] 104 | # sh_map = view2sh(view_map, h1, crop_h, w1, crop_w) 105 | 106 | # # final transform 107 | # img, map, sh_map = img_transform(img), map_transform(map), map_transform(sh_map) 108 | 109 | # # mask for valid uv positions 110 | # mask = torch.max(map, dim=2)[0].ge(-1.0+1e-6) 111 | # mask = mask.repeat((3,1,1)) 112 | 113 | # return img, map, sh_map, mask 114 | --------------------------------------------------------------------------------
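A minimal smoke test for `augment` (hypothetical snippet, not part of the repo), confirming the shapes the training scripts rely on:

```python
import numpy as np
from PIL import Image

from util import augment

img = Image.new('RGB', (512, 512))
uv = np.random.rand(512, 512, 2).astype(np.float32)
img_t, uv_t, mask = augment(img, uv, (256, 256))
print(img_t.shape)  # torch.Size([3, 256, 256])
print(uv_t.shape)   # torch.Size([256, 256, 2])
print(mask.shape)   # torch.Size([3, 256, 256])
```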