├── .gitignore ├── LICENSE.md ├── configs ├── GAN │ ├── 000.yaml │ └── 000_eval_fix.yaml ├── VAE │ ├── 000.yaml │ └── 000_eval_fix.yaml ├── default.yaml └── singleview │ ├── NVS │ └── car.yaml │ └── texfields │ ├── car.yaml │ ├── car_demo.yaml │ ├── car_eval_fix.yaml │ └── car_eval_rnd.yaml ├── demo ├── cc067578ad92517bbe25370c898e25a5 │ ├── input_image │ │ └── 000.jpg │ ├── pointcloud.npz │ └── visualize │ │ ├── depth │ │ ├── 000.exr │ │ ├── 001.exr │ │ ├── 002.exr │ │ ├── 003.exr │ │ ├── 004.exr │ │ └── cameras.npz │ │ └── image │ │ ├── 000.png │ │ ├── 001.png │ │ ├── 002.png │ │ ├── 003.png │ │ └── 004.png └── test.lst ├── environment.yaml ├── evaluate.py ├── generate.py ├── gfx └── clips │ ├── blackroof.gif │ └── header.png ├── mesh2tex ├── .gitignore ├── __init__.py ├── checkpoints.py ├── common.py ├── config.py ├── data │ ├── __init__.py │ ├── core.py │ ├── fields.py │ └── transforms.py ├── eval.py ├── geometry │ ├── __init__.py │ └── pointnet.py ├── layers.py ├── nvs │ ├── __init__.py │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── discriminator.py │ │ └── encoder.py │ └── training.py ├── texnet │ ├── __init__.py │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── discriminator.py │ │ ├── image_encoder.py │ │ └── vae_encoder.py │ └── training.py ├── training.py └── utils │ ├── FID │ ├── feature_l1.py │ ├── fid_score.py │ └── inception.py │ ├── SSIM_L1 │ └── ssim_l1_score.py │ ├── __init__.py │ └── io.py ├── readme.md ├── scripts └── download_data.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/synthetic_combined 2 | data/shapenet -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Michael Oechsle, Lars Mescheder, Michael Niemeyer, Thilo Strauss, Andreas Geiger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/GAN/000.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 128 6 | training: 7 | out_dir: 'out/GAN/car' 8 | batch_size: 32 9 | model_selection_metric: null 10 | model_selection_mode: maximize 11 | print_every: 10 12 | visualize_every: 100 13 | checkpoint_every: 1000 14 | validate_every: 1000 15 | backup_every: 10000 16 | moving_average_beta: 0.99 17 | gradient_penalties_reg: 10. 18 | lr_g: 0.0001 19 | lr_d: 0.0001 20 | multi_gpu: false 21 | vis_fixviews: True 22 | weight_pixelloss: 0. 23 | weight_ganloss: 1. 24 | weight_vaeloss: 0. 25 | experiment: 'generative' 26 | model: 27 | decoder: each_layer_c 28 | encoder: 29 | geometry_encoder: simple 30 | discriminator: resnet_conditional 31 | vae_encoder: 32 | encoder_kwargs: 33 | vae_encoder_kwargs: 34 | decoder_kwargs: 35 | leaky: True 36 | geometry_encoder_kwargs: 37 | leaky: True 38 | generator_bg_kwargs: 39 | leaky: True 40 | discriminator_kwargs: 41 | leaky: True 42 | z_dim: 64 43 | c_dim: 512 44 | white_bg: True 45 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/gan_car-360b7ce7.pt' 46 | generation: 47 | batch_size: 10 48 | test: 49 | model_file: model.pt 50 | vis_dir: 'out/GAN/car/eval_fix' 51 | dataset_split: 'test_vis' 52 | with_occnet: False 53 | 54 | -------------------------------------------------------------------------------- /configs/GAN/000_eval_fix.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/GAN/000.yaml 3 | data: 4 | with_shuffle: False 5 | training: 6 | vis_fixviews: True 7 | generation: 8 | batch_size: 1 9 | test: 10 | model_file: model.pt 11 | vis_dir: 'out/GAN/car/eval_fix' 12 | dataset_split: 'test_vis' 13 | with_occnet: False 14 | generation_mode: 'gan' -------------------------------------------------------------------------------- /configs/VAE/000.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 128 6 | training: 7 | out_dir: 'out/VAE/car' 8 | batch_size: 3 9 | model_selection_metric: null 10 | model_selection_mode: maximize 11 | print_every: 100 12 | visualize_every: 1000 13 | checkpoint_every: 1000 14 | validate_every: 1000 15 | backup_every: 1000000 16 | moving_average_beta: 0 17 | pc_subsampling: 2048 18 | vis_fixviews: True 19 | weight_pixelloss: 1. 20 | weight_ganloss: 0. 21 | weight_vaeloss: 10. 22 | experiment: 'generative' 23 | gradient_penalties_reg: 0. 
24 | model: 25 | decoder: each_layer_c 26 | encoder: 27 | vae_encoder: resnet 28 | geometry_encoder: simple 29 | decoder_kwargs: 30 | leaky: True 31 | resnet_leaky: True 32 | encoder_kwargs: {} 33 | vae_encoder_kwargs: 34 | leaky: True 35 | geometry_encoder_kwargs: 36 | leaky: True 37 | z_dim: 512 38 | c_dim: 512 39 | white_bg: True 40 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/vae_car-f141e128.pt' 41 | 42 | generation: 43 | batch_size: 3 44 | test: 45 | model_file: model.pt 46 | vis_dir: 'out/VAE/car/eval_vis' 47 | dataset_split: 'test_vis' 48 | with_occnet: False 49 | -------------------------------------------------------------------------------- /configs/VAE/000_eval_fix.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/VAE/000.yaml 3 | data: 4 | img_size: 128 5 | training: 6 | vis_fixviews: True 7 | generation: 8 | batch_size: 3 9 | test: 10 | model_file: model.pt 11 | vis_dir: 'out/VAE/car/eval_fix' 12 | dataset_split: 'test_vis' 13 | with_occnet: False 14 | generation_mode: 'vae' -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_imgs: null 4 | path_shapes: null 5 | dataset_imgs_type: image_folder 6 | shapes_multiclass: false 7 | classes_shapes: null 8 | classes_imgs: null 9 | img_size: 64 10 | pcl_knn: null 11 | with_shuffle: True 12 | training: 13 | out_dir: 'out' 14 | batch_size: 64 15 | model_selection_metric: none 16 | model_selection_mode: maximize 17 | print_every: 100 18 | visualize_every: 1000 19 | checkpoint_every: 1000 20 | validate_every: 10000 21 | backup_every: 10000 22 | moving_average_beta: null 23 | lr_g: 0.0001 24 | lr_d: 0.0001 25 | gradient_penalties_reg: 10. 26 | multi_gpu: false 27 | pc_subsampling: 2048 28 | vis_fixviews: 29 | weight_pixelloss: 0.0 30 | weight_ganloss: 0.0 31 | weight_vaeloss: 0.0 32 | model: 33 | decoder: simple 34 | geometry_encoder: simple 35 | generator_bg: resnet 36 | discriminator: resnet_conditional 37 | vae_encoder: 38 | decoder_kwargs: 39 | resnet_leaky: True 40 | geometry_encoder_kwargs: {} 41 | generator_bg_kwargs: {} 42 | discriminator_kwargs: {} 43 | vae_encoder_kwargs: {} 44 | z_dim: 128 45 | c_dim: 128 46 | white_bg: False 47 | gan_setting: conditional 48 | model_url: 49 | test: 50 | model_file: model_best.pt 51 | vis_dir: 52 | for_eval: False 53 | dataset_split: 'test' 54 | for_vis: False 55 | with_occnet: False 56 | interpol: False 57 | generate_grid: False 58 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/NVS/car.yaml: -------------------------------------------------------------------------------- 1 | method: nvs 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 256 6 | training: 7 | out_dir: 'out/nvs/car' 8 | batch_size: 32 9 | model_selection_metric: loss_val 10 | model_selection_mode: minimize 11 | print_every: 100 12 | visualize_every: 10000 13 | checkpoint_every: 10000 14 | validate_every: 10000 15 | backup_every: 100000 16 | moving_average_beta: 0 17 | pc_subsampling: 2048 18 | vis_fixviews: True 19 | gradient_penalties_reg: 0. 20 | weight_pixelloss: 1. 21 | weight_ganloss: 0. 
22 | experiment: 'conditional' 23 | model: 24 | decoder: each_layer_c 25 | encoder: resnet18 26 | geometry_encoder: null 27 | generator_bg: null 28 | discriminator: resnet 29 | decoder_kwargs: {} 30 | encoder_kwargs: {} 31 | geometry_encoder_kwargs: {} 32 | generator_bg_kwargs: {} 33 | discriminator_kwargs: {} 34 | z_dim: 0 35 | c_dim: 512 36 | white_bg: True 37 | generation: 38 | batch_size: 32 39 | test: 40 | model_file: model_best.pt 41 | vis_dir: 'out/nvs/car/eval_fix' 42 | dataset_split: 'test_vis' 43 | with_occnet: False 44 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 128 6 | training: 7 | out_dir: 'out/singleview/car' 8 | batch_size: 18 9 | model_selection_metric: loss_val 10 | model_selection_mode: minimize 11 | print_every: 100 12 | visualize_every: 1000 13 | checkpoint_every: 1000 14 | validate_every: 1000 15 | backup_every: 100000 16 | moving_average_beta: 0 17 | pc_subsampling: 2048 18 | vis_fixviews: True 19 | weight_pixelloss: 1. 20 | weight_ganloss: 0. 21 | experiment: 'conditional' 22 | gradient_penalties_reg: 0. 23 | model: 24 | decoder: each_layer_c_larger 25 | encoder: resnet18 26 | geometry_encoder: resnet 27 | decoder_kwargs: 28 | leaky: True 29 | resnet_leaky: False 30 | encoder_kwargs: {} 31 | geometry_encoder_kwargs: {} 32 | generator_bg_kwargs: {} 33 | discriminator_kwargs: {} 34 | z_dim: 512 35 | c_dim: 512 36 | white_bg: True 37 | model_url: 38 | generation: 39 | batch_size: 1 40 | test: 41 | model_file: model_best.pt 42 | vis_dir: 'out/singleview/car/eval_fix/' 43 | dataset_split: 'test_vis' 44 | with_occnet: False 45 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car_demo.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/singleview/texfields/car.yaml 3 | data: 4 | path_shapes: data/demo 5 | dataset_imgs_type: image_folder 6 | img_size: 256 7 | model: 8 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/car-b3b2a506.pt' 9 | generation: 10 | batch_size: 1 11 | test: 12 | model_file: model_best.pt 13 | vis_dir: 'out/demo' 14 | dataset_split: 'test_vis' 15 | with_occnet: False 16 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car_eval_fix.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/singleview/texfields/car.yaml 3 | data: 4 | # path_shapes: data/synthetic_combined/02958343 5 | dataset_imgs_type: image_folder 6 | img_size: 256 7 | model: 8 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/car-b3b2a506.pt' 9 | generation: 10 | batch_size: 1 11 | test: 12 | vis_dir: 'out/singleview/car/eval_fix' 13 | dataset_split: 'test_vis' 14 | with_occnet: False 15 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car_eval_rnd.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: 
configs/singleview/texfields/car.yaml 3 | data: 4 | # path_shapes: data//02958343/ 5 | dataset_imgs_type: image_folder 6 | img_size: 256 7 | model: 8 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/car-b3b2a506.pt' 9 | generation: 10 | batch_size: 100 11 | test: 12 | model_file: model_best.pt 13 | vis_dir: 'out/singleview/car/eval_rnd' 14 | dataset_split: 'test_eval' 15 | with_occnet: False 16 | generation_mode: 'HD' -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/input_image/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/input_image/000.jpg -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/pointcloud.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/pointcloud.npz -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/000.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/000.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/001.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/001.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/002.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/002.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/003.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/003.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/004.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/004.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/cameras.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/cameras.npz -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/000.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/001.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/002.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/003.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/004.png -------------------------------------------------------------------------------- /demo/test.lst: -------------------------------------------------------------------------------- 1 | cc067578ad92517bbe25370c898e25a5 -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: texturefields 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - imageio=2.4.1 8 | - numpy=1.15.4 9 | - numpy-base=1.15.4 10 | - matplotlib=3.0.3 11 | - matplotlib-base=3.0.3 12 | - pandas=0.23.4 13 | - pillow=5.3.0 14 | - pyembree=0.1.4 15 | - pytest=4.0.2 16 | - python=3.6.7 17 | - pytorch=1.0.0 18 | - pyyaml=3.13 19 | - scikit-image=0.14.1 20 | - scipy=1.1.0 21 | - tensorboardx=1.4 22 | - torchvision=0.2.1 23 | - tqdm=4.28.1 24 | - trimesh=2.37.7 25 | - pip: 26 | - h5py==2.9.0 27 | - plyfile==0.7 28 | - lmdb -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import os 4 | import glob 5 | 6 | from mesh2tex import config 7 | from mesh2tex.eval import evaluate_generated_images 8 | 9 | categories = {'02958343': 'cars', '03001627': 'chairs', 10 | '02691156': 'airplanes', '04379243': 
'tables', 11 | '02828884': 'benches', '02933112': 'cabinets', 12 | '04256520': 'sofa', '03636649': 'lamps', 13 | '04530566': 'vessels'} 14 | 15 | parser = argparse.ArgumentParser( 16 | description='Generate Color for given mesh.' 17 | ) 18 | 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | 21 | args = parser.parse_args() 22 | cfg = config.load_config(args.config, 'configs/default.yaml') 23 | base_path = cfg['test']['vis_dir'] 24 | 25 | 26 | if cfg['data']['shapes_multiclass']: 27 | category_paths = glob.glob(os.path.join(base_path, '*')) 28 | else: 29 | category_paths = [base_path] 30 | 31 | for category_path in category_paths: 32 | cat_id = os.path.basename(category_path) 33 | category = categories.get(cat_id, cat_id) 34 | path1 = os.path.join(category_path, 'fake/') 35 | path2 = os.path.join(category_path, 'real/') 36 | print('Evaluating %s (%s)' % (category, category_path)) 37 | 38 | evaluation = evaluate_generated_images('all', path1, path2) 39 | name = base_path 40 | 41 | df = pd.DataFrame(evaluation, index=[category]) 42 | df.to_pickle(os.path.join(category_path, 'eval.pkl')) 43 | df.to_csv(os.path.join(category_path, 'eval.csv')) 44 | 45 | print('Evaluation finished') 46 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | from mesh2tex import data 6 | from mesh2tex import config 7 | from mesh2tex.checkpoints import CheckpointIO 8 | 9 | # Get arguments and Config 10 | parser = argparse.ArgumentParser( 11 | description='Generate Color for given mesh.') 12 | parser.add_argument('config', type=str, help='Path to config file.') 13 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 14 | args = parser.parse_args() 15 | cfg = config.load_config(args.config, 'configs/default.yaml') 16 | 17 | # Define device 18 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 19 | device = torch.device("cuda" if is_cuda else "cpu") 20 | 21 | # Read config 22 | out_dir = cfg['training']['out_dir'] 23 | vis_dir = cfg['test']['vis_dir'] 24 | split = cfg['test']['dataset_split'] 25 | if split != 'test_vis' and split != 'test_eval': 26 | print('Are you sure not using test data?') 27 | batch_size = cfg['generation']['batch_size'] 28 | gen_mode = cfg['test']['generation_mode'] 29 | model_url = cfg['model']['model_url'] 30 | 31 | # Dataset 32 | dataset = config.get_dataset(split, cfg, input_sampling=False) 33 | if cfg['data']['shapes_multiclass']: 34 | datasets = dataset.datasets_classes 35 | else: 36 | datasets = [dataset] 37 | 38 | # Load Model 39 | models = config.get_models(cfg, device=device, dataset=dataset) 40 | model_g = models['generator'] 41 | checkpoint_io = CheckpointIO(out_dir, model_g=model_g) 42 | if model_url is None: 43 | checkpoint_io.load(cfg['test']['model_file']) 44 | else: 45 | checkpoint_io.load(cfg['model']['model_url']) 46 | 47 | # Assign Generator 48 | generator = config.get_generator(model_g, cfg, device) 49 | 50 | # data iteration loop 51 | for i_ds, ds in enumerate(datasets): 52 | ds_id = ds.metadata.get('id', str(i_ds)) 53 | ds_name = ds.metadata.get('name', 'n/a') 54 | 55 | if cfg['data']['shapes_multiclass']: 56 | out_dir = os.path.join(vis_dir, ds_id) 57 | else: 58 | out_dir = vis_dir 59 | 60 | test_loader = torch.utils.data.DataLoader( 61 | ds, batch_size=batch_size, num_workers=12, shuffle=False, 62 | 
collate_fn=data.collate_remove_none) 63 | 64 | batch_counter = 0 65 | 66 | def get_batch_size(batch): 67 | batch_size = next(iter(batch.values())).shape[0] 68 | return batch_size 69 | 70 | for batch in tqdm(test_loader): 71 | offset_batch = batch_size * batch_counter 72 | 73 | model_names = [ 74 | ds.get_model(i) for i in batch['idx'] 75 | ] 76 | 77 | if gen_mode == 'interpolate': 78 | out = generator.generate_images_4eval_vae_interpol(batch, 79 | out_dir, 80 | model_names) 81 | elif gen_mode == 'vae': 82 | out = generator.generate_images_4eval_vae(batch, 83 | out_dir, 84 | model_names) 85 | elif gen_mode == 'gan': 86 | out = generator.generate_images_4eval_gan(batch, 87 | out_dir, 88 | model_names) 89 | elif gen_mode == 'interpolate_rotation': 90 | out = generator.generate_images_4eval_vae_inter_rot(batch, 91 | out_dir, 92 | model_names) 93 | elif gen_mode == 'HD': 94 | generator.generate_images_4eval_condi_hd(batch, 95 | out_dir, 96 | model_names) 97 | 98 | elif gen_mode == 'SD': 99 | generator.generate_images_4eval_condi(batch, 100 | out_dir, 101 | model_names) 102 | 103 | elif gen_mode == 'grid': 104 | out = generator.generate_grid(batch, 105 | out_dir, 106 | model_names) 107 | 108 | elif gen_mode == 'test': 109 | out = generator.generate_images_occnet(batch, 110 | out_dir, 111 | model_names) 112 | else: 113 | print('Modes: HD, grid, interpolate, interpolate_rotation, test') 114 | 115 | batch_counter += 1 116 | -------------------------------------------------------------------------------- /gfx/clips/blackroof.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/gfx/clips/blackroof.gif -------------------------------------------------------------------------------- /gfx/clips/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/gfx/clips/header.png -------------------------------------------------------------------------------- /mesh2tex/.gitignore: -------------------------------------------------------------------------------- 1 | .pyc 2 | *.pyc -------------------------------------------------------------------------------- /mesh2tex/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex import geometry 2 | 3 | 4 | __all__ = [ 5 | geometry 6 | ] -------------------------------------------------------------------------------- /mesh2tex/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | 6 | 7 | class CheckpointIO(object): 8 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 9 | self.module_dict = kwargs 10 | self.checkpoint_dir = checkpoint_dir 11 | 12 | if not os.path.exists(checkpoint_dir): 13 | os.makedirs(checkpoint_dir) 14 | 15 | def register_modules(self, **kwargs): 16 | self.module_dict.update(kwargs) 17 | 18 | def save(self, filename, **kwargs): 19 | filename = os.path.join(self.checkpoint_dir, filename) 20 | 21 | outdict = kwargs 22 | for k, v in self.module_dict.items(): 23 | outdict[k] = v.state_dict() 24 | torch.save(outdict, filename) 25 | 26 | def load(self, filename): 27 | '''Loads a module dictionary from local file or url. 
28 | 29 | Args: 30 | filename (str): name of saved module dictionary 31 | ''' 32 | print(filename) 33 | if is_url(filename): 34 | return self.load_url(filename) 35 | else: 36 | return self.load_file(filename) 37 | 38 | def load_url(self, url): 39 | '''Load a module dictionary from url. 40 | 41 | Args: 42 | url (str): url to saved model 43 | ''' 44 | print(url) 45 | print('=> Loading checkpoint from url...') 46 | out_dict = model_zoo.load_url(url, progress=True) 47 | # scalars = self.parse_state_dict(state_dict) 48 | for k, v in self.module_dict.items(): 49 | print("Start loading: %s" % k) 50 | if k in out_dict: 51 | # print(out_dict[k]) 52 | v.load_state_dict(out_dict[k]) 53 | print("Finished: %s" % k) 54 | else: 55 | print('Warning: Could not find %s in checkpoint!' % k) 56 | scalars = {k: v for k, v in out_dict.items() 57 | if k not in self.module_dict} 58 | return scalars 59 | 60 | def load_file(self, filename): 61 | filename = os.path.join(self.checkpoint_dir, filename) 62 | 63 | if os.path.exists(filename): 64 | 65 | print('=> Loading checkpoint...') 66 | out_dict = torch.load(filename) 67 | for k, v in self.module_dict.items(): 68 | print("Start loading: %s" % k) 69 | if k in out_dict: 70 | # print(out_dict[k]) 71 | v.load_state_dict(out_dict[k]) 72 | print("Finished: %s" % k) 73 | else: 74 | print('Warning: Could not find %s in checkpoint!' % k) 75 | scalars = {k: v for k, v in out_dict.items() 76 | if k not in self.module_dict} 77 | return scalars 78 | else: 79 | raise FileExistsError 80 | 81 | def is_url(url): 82 | scheme = urllib.parse.urlparse(url).scheme 83 | return scheme in ('http', 'https') 84 | 85 | -------------------------------------------------------------------------------- /mesh2tex/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial import cKDTree as KDTree 3 | 4 | 5 | def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1): 6 | indices = [] 7 | distances = [] 8 | 9 | for (p1, p2) in zip(points_src, points_tgt): 10 | p1 = p1.detach().cpu().numpy().T 11 | p2 = p2.detach().cpu().numpy().T 12 | 13 | kdtree = KDTree(p2) 14 | dist, idx = kdtree.query(p1, k=k, n_jobs=-1) 15 | indices.append(idx) 16 | distances.append(dist) 17 | 18 | indices = torch.LongTensor(indices) 19 | distances = torch.FloatTensor(distances) 20 | 21 | return indices, distances 22 | 23 | 24 | def normalize_imagenet(x): 25 | x = x.clone() 26 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 27 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 28 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 29 | return x 30 | -------------------------------------------------------------------------------- /mesh2tex/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from mesh2tex import texnet 3 | from mesh2tex import nvs 4 | method_dict = { 5 | 'texnet': texnet, 6 | 'nvs': nvs, 7 | } 8 | 9 | 10 | # General config 11 | def load_config(path, default_path=None): 12 | # Load configuration from file itself 13 | with open(path, 'r') as f: 14 | cfg_special = yaml.load(f) 15 | 16 | # Check if we should inherit from a config 17 | inherit_from = cfg_special.get('inherit_from') 18 | 19 | # If yes, load this config first as default 20 | # If no, use the default_path 21 | if inherit_from is not None: 22 | cfg = load_config(inherit_from, default_path) 23 | elif default_path is not None: 24 | with open(default_path, 'r') as f: 25 | cfg = yaml.load(f) 26 | else: 27 | cfg = dict() 28 | 29 | # Include main 
configuration 30 | update_recursive(cfg, cfg_special) 31 | 32 | return cfg 33 | 34 | 35 | def update_recursive(dict1, dict2): 36 | for k, v in dict2.items(): 37 | if k not in dict1: 38 | dict1[k] = dict() 39 | if isinstance(v, dict): 40 | update_recursive(dict1[k], v) 41 | else: 42 | dict1[k] = v 43 | 44 | 45 | # Individual configs 46 | def get_models(cfg, dataset=None, device=None): 47 | method = cfg['method'] 48 | models = method_dict[method].config.get_models(cfg, 49 | dataset=dataset, 50 | device=device) 51 | return models 52 | 53 | 54 | def get_optimizers(models, cfg): 55 | method = cfg['method'] 56 | optimizers = method_dict[method].config.get_optimizers(models, cfg) 57 | return optimizers 58 | 59 | 60 | def get_dataset(split, cfg, input_sampling=True): 61 | method = cfg['method'] 62 | dataset = method_dict[method].config.get_dataset(split, cfg, 63 | input_sampling) 64 | return dataset 65 | 66 | 67 | def get_dataloader(split, cfg): 68 | method = cfg['method'] 69 | dataloader = method_dict[method].config.get_dataloader(split, cfg) 70 | return dataloader 71 | 72 | 73 | def get_meshloader(split, cfg): 74 | method = cfg['method'] 75 | dataloader = method_dict[method].config.get_meshloader(split, cfg) 76 | return dataloader 77 | 78 | 79 | def get_generator(model, cfg, device): 80 | method = cfg['method'] 81 | generator = method_dict[method].config.get_generator(model, cfg, device) 82 | return generator 83 | 84 | 85 | def get_trainer(models, optimizers, cfg, device=None): 86 | method = cfg['method'] 87 | print("method: " + method) 88 | trainer = method_dict[method].config.get_trainer( 89 | models, optimizers, cfg, device=device) 90 | return trainer 91 | -------------------------------------------------------------------------------- /mesh2tex/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from mesh2tex.data.core import ( 3 | Shapes3dDataset, Shapes3dClassDataset, 4 | CombinedDataset, 5 | collate_remove_none, worker_init_fn 6 | ) 7 | from mesh2tex.data.fields import ( 8 | ImagesField, PointCloudField, 9 | DepthImageField, MeshField, 10 | DepthImageVisualizeField, IndexField, 11 | ) 12 | 13 | from mesh2tex.data.transforms import ( 14 | PointcloudNoise, SubsamplePointcloud, 15 | ComputeKNNPointcloud, 16 | ImageToGrayscale, ResizeImage, 17 | ImageToDepthValue 18 | ) 19 | 20 | 21 | __all__ = [ 22 | # Core 23 | Shapes3dDataset, 24 | Shapes3dClassDataset, 25 | CombinedDataset, 26 | collate_remove_none, 27 | worker_init_fn, 28 | # Fields 29 | ImagesField, 30 | PointCloudField, 31 | DepthImageField, 32 | MeshField, 33 | DepthImageVisualizeField, 34 | IndexField, 35 | # Transforms 36 | PointcloudNoise, 37 | SubsamplePointcloud, 38 | ComputeKNNPointcloud, 39 | ImageToGrayscale, 40 | ImageToDepthValue, 41 | ResizeImage, 42 | ] 43 | -------------------------------------------------------------------------------- /mesh2tex/data/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from torch.utils import data 4 | import numpy as np 5 | import yaml 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | # Fields 12 | class Field(object): 13 | def load(self, data_path, idx): 14 | raise NotImplementedError 15 | 16 | def check_complete(self, files): 17 | raise NotImplementedError 18 | 19 | 20 | class Shapes3dDataset(data.Dataset): 21 | def __init__(self, dataset_folder, fields, split=None, 22 | classes=None, no_except=True, transform=None): 23 | # Read metadata file 24 
| metadata_file = os.path.join(dataset_folder, 'metadata.yaml') 25 | if os.path.exists(metadata_file): 26 | with open(metadata_file, 'r') as f: 27 | metadata = yaml.load(f) 28 | else: 29 | metadata = {} 30 | 31 | # If classes is None, use all subfolders 32 | if classes is None: 33 | classes = os.listdir(dataset_folder) 34 | classes = [c for c in classes 35 | if os.path.isdir(os.path.join(dataset_folder, c))] 36 | 37 | # Get all sub-datasets 38 | self.datasets_classes = [] 39 | for c in classes: 40 | subpath = os.path.join(dataset_folder, c) 41 | if not os.path.isdir(subpath): 42 | logger.warning('Class %s does not exist in dataset.' % c) 43 | 44 | metadata_c = metadata.get(c, {'id': c, 'name': 'n/a'}) 45 | dataset = Shapes3dClassDataset(subpath, fields, split, 46 | metadata_c, no_except, 47 | transform=transform) 48 | self.datasets_classes.append(dataset) 49 | 50 | self._concat_dataset = data.ConcatDataset(self.datasets_classes) 51 | 52 | def __len__(self): 53 | return len(self._concat_dataset) 54 | 55 | def __getitem__(self, idx): 56 | return self._concat_dataset[idx] 57 | 58 | 59 | class Shapes3dClassDataset(data.Dataset): 60 | def __init__(self, dataset_folder, fields, split=None, 61 | metadata=dict(), no_except=True, transform=None): 62 | self.dataset_folder = dataset_folder 63 | self.fields = fields 64 | self.metadata = metadata 65 | self.no_except = no_except 66 | self.transform = transform 67 | # Get (filtered) model list 68 | if split is None: 69 | models = [ 70 | f for f in os.listdir(dataset_folder) 71 | if os.path.isdir(os.path.join(dataset_folder, f)) 72 | ] 73 | else: 74 | split_file = os.path.join(dataset_folder, split + '.lst') 75 | with open(split_file, 'r') as f: 76 | models = f.read().split('\n') 77 | 78 | # self.models = list(filter(self.test_model_complete, models)) 79 | self.models = models 80 | 81 | def test_model_complete(self, model): 82 | model_path = os.path.join(self.dataset_folder, model) 83 | files = os.listdir(model_path) 84 | for field_name, field in self.fields.items(): 85 | if not field.check_complete(files): 86 | logger.warn('Field "%s" is incomplete: %s' 87 | % (field_name, model_path)) 88 | return False 89 | else: 90 | return True 91 | 92 | def __len__(self): 93 | return len(self.models) 94 | 95 | def __getitem__(self, idx): 96 | model = self.models[idx] 97 | model_path = os.path.join(self.dataset_folder, model) 98 | data = {} 99 | for field_name, field in self.fields.items(): 100 | try: 101 | field_data = field.load(model_path, idx) 102 | except Exception: 103 | if self.no_except: 104 | logger.warn( 105 | 'Error occured when loading field %s of model %s' 106 | % (field_name, model) 107 | ) 108 | return None 109 | else: 110 | raise 111 | if isinstance(field_data, dict): 112 | 113 | for k, v in field_data.items(): 114 | if k is None: 115 | data[field_name] = v 116 | else: 117 | data['%s.%s' % (field_name, k)] = v 118 | else: 119 | data[field_name] = field_data 120 | if self.transform is not None: 121 | data = self.transform(data) 122 | return data 123 | 124 | def get_model(self, idx): 125 | return self.models[idx] 126 | 127 | 128 | class CombinedDataset(data.Dataset): 129 | def __init__(self, datasets, idx_main=0): 130 | self.datasets = datasets 131 | self.idx_main = idx_main 132 | 133 | def __len__(self): 134 | return len(self.datasets[self.idx_main]) 135 | 136 | def __getitem__(self, idx): 137 | out = [] 138 | for it, ds in enumerate(self.datasets): 139 | if it != self.idx_main: 140 | x_idx = np.random.randint(0, len(ds)) 141 | else: 142 | x_idx = idx 
143 | out.append(ds[x_idx]) 144 | return out 145 | 146 | 147 | # Collater 148 | def collate_remove_none(batch): 149 | "Puts each data field into a tensor with outer dimension batch size" 150 | batch = list(filter(check_element_valid, batch)) 151 | return data.dataloader.default_collate(batch) 152 | 153 | 154 | def check_element_valid(batch): 155 | if batch is None: 156 | return False 157 | elif isinstance(batch, list): 158 | for b in batch: 159 | if not check_element_valid(b): 160 | return False 161 | elif isinstance(batch, dict): 162 | for b in batch.values(): 163 | if not check_element_valid(b): 164 | return False 165 | return True 166 | 167 | 168 | # Worker initialization to ensure true randomeness 169 | def worker_init_fn(worker_id): 170 | random_data = os.urandom(4) 171 | base_seed = int.from_bytes(random_data, byteorder="big") 172 | np.random.seed(base_seed + worker_id) 173 | -------------------------------------------------------------------------------- /mesh2tex/data/fields.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import numpy as np 5 | import trimesh 6 | import imageio 7 | from mesh2tex.data.core import Field 8 | 9 | 10 | # Make sure loading xlr works 11 | imageio.plugins.freeimage.download() 12 | 13 | 14 | # Basic index field 15 | class IndexField(Field): 16 | def load(self, model_path, idx): 17 | return idx 18 | 19 | def check_complete(self, files): 20 | return True 21 | 22 | 23 | class MeshField(Field): 24 | def __init__(self, folder_name, transform=None): 25 | self.folder_name = folder_name 26 | self.transform = transform 27 | 28 | def load(self, model_path, idx): 29 | folder_path = os.path.join(model_path, self.folder_name) 30 | file_path = os.path.join(folder_path, 'model.off') 31 | mesh = trimesh.load(file_path, process=False) 32 | if self.transform is not None: 33 | mesh = self.transform(mesh) 34 | 35 | data = { 36 | 'vertices': np.array(mesh.vertices), 37 | 'faces': np.array(mesh.faces), 38 | } 39 | 40 | return data 41 | 42 | def check_complete(self, files): 43 | complete = (self.folder_name in files) 44 | return complete 45 | 46 | # Image field 47 | class ImagesField(Field): 48 | def __init__(self, folder_name, transform=None, 49 | extension='jpg', random_view=True, 50 | with_camera=False, 51 | imageio_kwargs=dict()): 52 | self.folder_name = folder_name 53 | self.transform = transform 54 | self.extension = extension 55 | self.random_view = random_view 56 | self.with_camera = with_camera 57 | self.imageio_kwargs = dict() 58 | 59 | def load(self, model_path, idx): 60 | folder = os.path.join(model_path, self.folder_name) 61 | files = glob.glob(os.path.join(folder, '*.%s' % self.extension)) 62 | files.sort() 63 | if self.random_view: 64 | idx_img = random.randint(0, len(files)-1) 65 | else: 66 | idx_img = 0 67 | filename = files[idx_img] 68 | 69 | image = imageio.imread(filename, **self.imageio_kwargs) 70 | image = np.asarray(image) 71 | 72 | if len(image.shape) == 2: 73 | image = image.reshape(image.shape[0], image.shape[1], 1) 74 | image = np.concatenate([image, image, image], axis=2) 75 | 76 | if image.shape[2] == 4: 77 | image = image[:, :, :3] 78 | 79 | if image.dtype == np.uint8: 80 | image = image.astype(np.float32) / 255 81 | else: 82 | image = image.astype(np.float32) 83 | 84 | if self.transform is not None: 85 | image = self.transform(image) 86 | image = image.transpose(2, 0, 1) 87 | data = { 88 | None: image 89 | } 90 | 91 | if self.with_camera: 92 | camera_file = 
os.path.join(folder, 'cameras.npz') 93 | camera_dict = np.load(camera_file) 94 | Rt = camera_dict['world_mat_%d' % idx_img].astype(np.float32) 95 | K = camera_dict['camera_mat_%d' % idx_img].astype(np.float32) 96 | data['world_mat'] = Rt 97 | data['camera_mat'] = K 98 | 99 | return data 100 | 101 | def check_complete(self, files): 102 | complete = (self.folder_name in files) 103 | # TODO: check camera 104 | return complete 105 | 106 | 107 | # 3D Fields 108 | class PointCloudField(Field): 109 | def __init__(self, file_name, transform=None, with_transforms=False): 110 | self.file_name = file_name 111 | self.transform = transform 112 | self.with_transforms = with_transforms 113 | 114 | def load(self, model_path, idx): 115 | file_path = os.path.join(model_path, self.file_name) 116 | 117 | pointcloud_dict = np.load(file_path) 118 | 119 | points = pointcloud_dict['points'].astype(np.float32) 120 | normals = pointcloud_dict['normals'].astype(np.float32) 121 | 122 | data = { 123 | None: points.T, 124 | 'normals': normals.T, 125 | } 126 | 127 | if self.with_transforms: 128 | data['loc'] = pointcloud_dict['loc'].astype(np.float32) 129 | data['scale'] = pointcloud_dict['scale'].astype(np.float32) 130 | 131 | if self.transform is not None: 132 | data = self.transform(data) 133 | 134 | return data 135 | 136 | def check_complete(self, files): 137 | complete = (self.file_name in files) 138 | return complete 139 | 140 | 141 | class DepthImageVisualizeField(Field): 142 | def __init__(self, folder_name_img, folder_name_depth, transform_img=None, transform_depth=None, 143 | extension_img='jpg', extension_depth='exr', random_view=True, 144 | with_camera=False, 145 | imageio_kwargs=dict()): 146 | self.folder_name_img = folder_name_img 147 | self.folder_name_depth = folder_name_depth 148 | self.transform_depth = transform_depth 149 | self.transform_img = transform_img 150 | self.extension_img = extension_img 151 | self.extension_depth = extension_depth 152 | self.random_view = random_view 153 | self.with_camera = with_camera 154 | self.imageio_kwargs = dict() 155 | 156 | def load(self, model_path, idx): 157 | folder_img = os.path.join(model_path, self.folder_name_img) 158 | files_img = glob.glob(os.path.join(folder_img, '*.%s' % self.extension_img)) 159 | files_img.sort() 160 | folder_depth = os.path.join(model_path, self.folder_name_depth) 161 | files_depth = glob.glob(os.path.join(folder_depth, '*.%s' % self.extension_depth)) 162 | files_depth.sort() 163 | if self.random_view: 164 | idx_img = random.randint(0, len(files_img)-1) 165 | else: 166 | idx_img = 0 167 | 168 | image_all = [] 169 | depth_all = [] 170 | Rt = [] 171 | K = [] 172 | camera_file = os.path.join(folder_depth, 'cameras.npz') 173 | camera_dict = np.load(camera_file) 174 | 175 | for i in range(len(files_img)): 176 | filename_img = files_img[i] 177 | filename_depth = files_depth[i] 178 | 179 | image = imageio.imread(filename_img, **self.imageio_kwargs) 180 | image = np.asarray(image) 181 | 182 | if image.shape[2] == 4: 183 | image = image[:,:,:3] 184 | 185 | depth = imageio.imread(filename_depth, **self.imageio_kwargs) 186 | 187 | depth = np.asarray(depth) 188 | 189 | if image.dtype == np.uint8: 190 | image = image.astype(np.float32) / 255 191 | else: 192 | image = image.astype(np.float32) 193 | 194 | if self.transform_img is not None: 195 | image = self.transform_img(image) 196 | 197 | if self.transform_depth is not None: 198 | depth = self.transform_depth(depth) 199 | 200 | image = image.transpose(2, 0, 1) 201 | depth = depth.transpose(2, 0, 
1) 202 | image_all.append(image) 203 | depth_all.append(depth) 204 | 205 | camera_file = os.path.join(folder_depth, 'cameras.npz') 206 | camera_dict = np.load(camera_file) 207 | Rt.append(camera_dict['world_mat_%d' % i].astype(np.float32)) 208 | K.append(camera_dict['camera_mat_%d' % i].astype(np.float32)) 209 | 210 | data = { 211 | 'img': np.stack(image_all), 212 | 'depth': np.stack(depth_all) 213 | } 214 | 215 | data['world_mat'] = np.stack(Rt) 216 | data['camera_mat'] = np.stack(K) 217 | 218 | return data 219 | 220 | 221 | # Image field 222 | class DepthImageField(Field): 223 | def __init__(self, folder_name_img, folder_name_depth, transform_img=None, transform_depth=None, 224 | extension_img='jpg', extension_depth='exr', random_view=True, 225 | with_camera=False, 226 | imageio_kwargs=dict()): 227 | self.folder_name_img = folder_name_img 228 | self.folder_name_depth = folder_name_depth 229 | self.transform_depth = transform_depth 230 | self.transform_img = transform_img 231 | self.extension_img = extension_img 232 | self.extension_depth = extension_depth 233 | self.random_view = random_view 234 | self.with_camera = with_camera 235 | self.imageio_kwargs = dict() 236 | 237 | def load(self, model_path, idx): 238 | folder_img = os.path.join(model_path, self.folder_name_img) 239 | files_img = glob.glob(os.path.join(folder_img, '*.%s' % self.extension_img)) 240 | files_img.sort() 241 | folder_depth = os.path.join(model_path, self.folder_name_depth) 242 | files_depth = glob.glob(os.path.join(folder_depth, '*.%s' % self.extension_depth)) 243 | files_depth.sort() 244 | if self.random_view: 245 | idx_img = random.randint(0, len(files_img)-1) 246 | else: 247 | idx_img = 0 248 | 249 | filename_img = files_img[idx_img] 250 | filename_depth = files_depth[idx_img] 251 | 252 | image = imageio.imread(filename_img, **self.imageio_kwargs) 253 | image = np.asarray(image) 254 | if image.shape[2] == 4: 255 | image = image[:,:,:3] 256 | 257 | depth = imageio.imread(filename_depth, **self.imageio_kwargs) 258 | 259 | depth = np.asarray(depth) 260 | 261 | if image.dtype == np.uint8: 262 | image = image.astype(np.float32) / 255 263 | else: 264 | image = image.astype(np.float32) 265 | 266 | if self.transform_img is not None: 267 | image = self.transform_img(image) 268 | 269 | if self.transform_depth is not None: 270 | depth = self.transform_depth(depth) 271 | 272 | image = image.transpose(2, 0, 1) 273 | #TODO adapt depth transpose 274 | depth = depth.transpose(2, 0, 1) 275 | 276 | data = { 277 | 'img': image, 278 | 'depth': depth 279 | } 280 | 281 | if self.with_camera: 282 | camera_file = os.path.join(folder_depth, 'cameras.npz') 283 | camera_dict = np.load(camera_file) 284 | Rt = camera_dict['world_mat_%d' % idx_img].astype(np.float32) 285 | K = camera_dict['camera_mat_%d' % idx_img].astype(np.float32) 286 | data['world_mat'] = Rt 287 | data['camera_mat'] = K 288 | 289 | return data 290 | 291 | def check_complete(self, files): 292 | complete = (self.folder_name_img in files) 293 | # TODO: check camera 294 | return complete -------------------------------------------------------------------------------- /mesh2tex/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import resize 3 | from scipy.spatial import cKDTree as KDTree 4 | 5 | 6 | class PointcloudNoise(object): 7 | def __init__(self, stddev): 8 | self.stddev = stddev 9 | 10 | def __call__(self, data): 11 | data_out = data.copy() 12 | points = data[None] 13 | 
noise = self.stddev * np.random.randn(*points.shape) 14 | noise = noise.astype(np.float32) 15 | data_out[None] = points + noise 16 | return data_out 17 | 18 | 19 | class SubsamplePointcloud(object): 20 | def __init__(self, N): 21 | self.N = N 22 | 23 | def __call__(self, data): 24 | data_out = data.copy() 25 | points = data[None] 26 | normals = data['normals'] 27 | 28 | indices = np.random.randint(points.shape[1], size=self.N) 29 | data_out[None] = points[:, indices] 30 | data_out['normals'] = normals[:, indices] 31 | 32 | return data_out 33 | 34 | 35 | class ComputeKNNPointcloud(object): 36 | def __init__(self, K): 37 | self.K = K 38 | 39 | def __call__(self, data): 40 | data_out = data.copy() 41 | points = data[None] 42 | kdtree = KDTree(points.T) 43 | knn_idx = kdtree.query(points.T, k=self.K)[1] 44 | knn_idx = knn_idx.T 45 | 46 | data_out['knn_idx'] = knn_idx 47 | 48 | return data_out 49 | 50 | 51 | class ImageToGrayscale(object): 52 | def __call__(self, img): 53 | r, g, b = img[..., 0:1], img[..., 1:2], img[..., 2:3] 54 | out = 0.2990 * r + 0.5870 * g + 0.1140 * b 55 | return out 56 | 57 | 58 | class ImageToDepthValue(object): 59 | def __call__(self, img): 60 | return img[..., :1] 61 | 62 | 63 | class ResizeImage(object): 64 | def __init__(self, size, order=1): 65 | self.size = size 66 | self.order = order 67 | 68 | def __call__(self, img): 69 | img_out = resize(img, self.size, order=self.order, 70 | clip=False, mode='constant', 71 | anti_aliasing=False) 72 | img_out = img_out.astype(img.dtype) 73 | return img_out 74 | -------------------------------------------------------------------------------- /mesh2tex/eval.py: -------------------------------------------------------------------------------- 1 | import mesh2tex.utils.FID.fid_score as FID 2 | import mesh2tex.utils.FID.feature_l1 as feature_l1 3 | import mesh2tex.utils.SSIM_L1.ssim_l1_score as SSIM 4 | 5 | 6 | def evaluate_generated_images(metric, path_fake, path_real, batch_size=64): 7 | """ 8 | Start requested evaluation functions 9 | 10 | args: 11 | metric 12 | path_fake (Path to fake images) 13 | path_real (Path to real images) 14 | batch_size 15 | 16 | return: 17 | val_dict: dict with all metrics 18 | """ 19 | if metric == 'FID': 20 | paths = (path_fake, path_real) 21 | value = FID.calculate_fid_given_paths(paths, 22 | batch_size, 23 | True, 24 | 2048) 25 | val_dict = {'FID': value} 26 | 27 | elif metric == 'SSIM_L1': 28 | paths = (path_fake, path_real) 29 | value = SSIM.calculate_ssim_l1_given_paths(paths) 30 | val_dict = {'SSIM': value[0], 31 | 'L1': value[1]} 32 | elif metric == 'FeatL1': 33 | paths = (path_fake, path_real) 34 | value = feature_l1.calculate_feature_l1_given_paths( 35 | paths, batch_size, True, 2048) 36 | val_dict = {'FeatL1': value} 37 | 38 | elif metric == 'all': 39 | paths = (path_fake, path_real) 40 | value = SSIM.calculate_ssim_l1_given_paths(paths) 41 | value_FID = FID.calculate_fid_given_paths( 42 | paths, batch_size, True, 2048) 43 | value_FeatL1 = feature_l1.calculate_feature_l1_given_paths( 44 | paths, batch_size, True, 2048) 45 | val_dict = {'FID': value_FID, 46 | 'SSIM': value[0], 47 | 'L1': value[1], 48 | 'FeatL1': value_FeatL1} 49 | 50 | return val_dict 51 | -------------------------------------------------------------------------------- /mesh2tex/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex.geometry import ( 2 | pointnet 3 | ) 4 | 5 | 6 | encoder_dict = { 7 | 'simple': pointnet.SimplePointnet, 8 | 'resnet': 
pointnet.ResnetPointnetConv, 9 | } 10 | 11 | 12 | def get_representation(batch, device=None): 13 | mesh_points = batch['pointcloud'].to(device) 14 | mesh_normals = batch['pointcloud.normals'].to(device) 15 | geom_repr = { 16 | 'points': mesh_points, 17 | 'normals': mesh_normals, 18 | } 19 | if 'pointcloud.knn_idx' in batch: 20 | knn_idx = batch['pointcloud.knn_idx'].to(device) 21 | geom_repr['knn_idx'] = knn_idx 22 | 23 | return geom_repr 24 | 25 | -------------------------------------------------------------------------------- /mesh2tex/geometry/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mesh2tex.layers import EqualizedLR 5 | from mesh2tex.layers import ResnetBlockFC, ResnetBlockConv1D 6 | 7 | 8 | class SimplePointnet(nn.Module): 9 | def __init__(self, c_dim=128, hidden_dim=128, 10 | leaky=False, eq_lr=False): 11 | super().__init__() 12 | # Attributes 13 | self.c_dim = c_dim 14 | self.eq_lr = eq_lr 15 | 16 | # Activation function 17 | if not leaky: 18 | self.actvn = F.relu 19 | self.pool = maxpool 20 | else: 21 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 22 | self.pool = avgpool 23 | 24 | # Submodules 25 | self.conv_p = nn.Conv1d(6, 2*hidden_dim, 1) 26 | self.conv_0 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 27 | self.conv_1 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 28 | self.conv_2 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 29 | self.conv_3 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 30 | self.fc_c = nn.Linear(hidden_dim, c_dim) 31 | 32 | if self.eq_lr: 33 | self.conv_p = EqualizedLR(self.conv_p) 34 | self.conv_0 = EqualizedLR(self.conv_0) 35 | self.conv_1 = EqualizedLR(self.conv_1) 36 | self.conv_2 = EqualizedLR(self.conv_2) 37 | self.conv_3 = EqualizedLR(self.conv_3) 38 | self.fc_c = EqualizedLR(self.fc_c) 39 | 40 | def forward(self, geometry): 41 | p = geometry['points'] 42 | n = geometry['normals'] 43 | 44 | # Encode position into batch_size x F x T 45 | pn = torch.cat([p, n], dim=1) 46 | net = self.conv_p(pn) 47 | 48 | # Always pool to batch_size x F x 1, 49 | # expand to batch_size x F x T 50 | # and concatenate to batch_size x 2F x T 51 | net = self.conv_0(self.actvn(net)) 52 | pooled = self.pool(net, dim=2, keepdim=True) 53 | pooled = pooled.expand(net.size()) 54 | net = torch.cat([net, pooled], dim=1) 55 | 56 | net = self.conv_1(self.actvn(net)) 57 | pooled = self.pool(net, dim=2, keepdim=True) 58 | pooled = pooled.expand(net.size()) 59 | net = torch.cat([net, pooled], dim=1) 60 | 61 | net = self.conv_2(self.actvn(net)) 62 | pooled = self.pool(net, dim=2, keepdim=True) 63 | pooled = pooled.expand(net.size()) 64 | net = torch.cat([net, pooled], dim=1) 65 | 66 | net = self.conv_3(self.actvn(net)) 67 | 68 | # Recude to batch_size x F 69 | net = self.pool(net, dim=2) 70 | 71 | c = self.fc_c(self.actvn(net)) 72 | 73 | geom_descr = { 74 | 'global': c, 75 | } 76 | 77 | return geom_descr 78 | 79 | 80 | class ResnetPointnet(nn.Module): 81 | def __init__(self, c_dim=128, dim=6, hidden_dim=128): 82 | super().__init__() 83 | self.c_dim = c_dim 84 | 85 | self.fc_pos = nn.Linear(dim, 2*hidden_dim) 86 | self.block_0 = ResnetBlockFC(2*hidden_dim, hidden_dim) 87 | self.block_1 = ResnetBlockFC(2*hidden_dim, hidden_dim) 88 | self.block_2 = ResnetBlockFC(2*hidden_dim, hidden_dim) 89 | self.block_3 = ResnetBlockFC(2*hidden_dim, hidden_dim) 90 | self.block_4 = ResnetBlockFC(2*hidden_dim, hidden_dim) 91 | self.fc_c = nn.Linear(hidden_dim, c_dim) 92 | 93 | self.actvn = 
nn.ReLU() 94 | self.pool = maxpool 95 | 96 | def forward(self, geometry): 97 | p = geometry['points'] 98 | n = geometry['normals'] 99 | batch_size, T, D = p.size() 100 | 101 | pn = torch.cat([p, n], dim=1) 102 | # output size: B x T X F 103 | net = self.fc_pos(pn) 104 | net = self.block_0(net) 105 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 106 | net = torch.cat([net, pooled], dim=2) 107 | 108 | net = self.block_1(net) 109 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 110 | net = torch.cat([net, pooled], dim=2) 111 | 112 | net = self.block_2(net) 113 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 114 | net = torch.cat([net, pooled], dim=2) 115 | 116 | net = self.block_3(net) 117 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 118 | net = torch.cat([net, pooled], dim=2) 119 | 120 | net = self.block_4(net) 121 | 122 | # Recude to B x F 123 | net = self.pool(net, dim=1) 124 | 125 | c = self.fc_c(self.actvn(net)) 126 | 127 | return c 128 | 129 | 130 | class ResnetPointnetConv(nn.Module): 131 | def __init__(self, c_dim=128, dim=6, hidden_dim=128): 132 | super().__init__() 133 | self.c_dim = c_dim 134 | 135 | self.fc_pos = nn.Conv1d(dim, 2*hidden_dim, 1) 136 | self.block_0 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 137 | self.block_1 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 138 | self.block_2 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 139 | self.block_3 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 140 | self.block_4 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 141 | self.fc_c = nn.Linear(hidden_dim, c_dim) 142 | 143 | self.actvn = nn.ReLU() 144 | self.pool = maxpool 145 | 146 | def forward(self, geometry): 147 | p = geometry['points'] 148 | n = geometry['normals'] 149 | batch_size, T, D = p.size() 150 | 151 | pn = torch.cat([p, n], dim=1) 152 | # output size: B x T X F 153 | net = self.fc_pos(pn) 154 | net = self.block_0(net) 155 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 156 | net = torch.cat([net, pooled], dim=1) 157 | 158 | net = self.block_1(net) 159 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 160 | net = torch.cat([net, pooled], dim=1) 161 | 162 | net = self.block_2(net) 163 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 164 | net = torch.cat([net, pooled], dim=1) 165 | 166 | net = self.block_3(net) 167 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 168 | net = torch.cat([net, pooled], dim=1) 169 | 170 | net = self.block_4(net) 171 | 172 | # Recude to B x F 173 | net = self.pool(net, dim=2) 174 | 175 | c = self.fc_c(self.actvn(net)) 176 | 177 | geom_descr = { 178 | 'global': c, 179 | } 180 | 181 | return geom_descr 182 | 183 | 184 | def maxpool(x, dim=-1, keepdim=False): 185 | out, _ = x.max(dim=dim, keepdim=keepdim) 186 | return out 187 | 188 | 189 | def avgpool(x, dim=-1, keepdim=False): 190 | out = x.mean(dim=dim, keepdim=keepdim) 191 | return out 192 | -------------------------------------------------------------------------------- /mesh2tex/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResnetBlockFC(nn.Module): 7 | def __init__(self, size_in, size_out=None, size_h=None): 8 | super().__init__() 9 | # Attributes 10 | if size_out is None: 11 | size_out = size_in 12 | 13 | if size_h is None: 14 | size_h = min(size_in, size_out) 15 | 16 | self.size_in = size_in 17 | 
self.size_h = size_h 18 | self.size_out = size_out 19 | # Submodules 20 | self.fc_0 = nn.Linear(size_in, size_h) 21 | self.fc_1 = nn.Linear(size_h, size_out) 22 | self.actvn = nn.ReLU() 23 | 24 | if size_in == size_out: 25 | self.shortcut = None 26 | else: 27 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 28 | # Initialization 29 | nn.init.zeros_(self.fc_1.weight) 30 | 31 | def forward(self, x): 32 | net = self.fc_0(self.actvn(x)) 33 | dx = self.fc_1(self.actvn(net)) 34 | 35 | if self.shortcut is not None: 36 | x_s = self.shortcut(x) 37 | else: 38 | x_s = x 39 | 40 | return x_s + dx 41 | 42 | 43 | # Resnet Blocks 44 | class ResnetBlockConv1D(nn.Module): 45 | def __init__(self, size_in, size_out=None, size_h=None): 46 | super().__init__() 47 | # Attributes 48 | if size_out is None: 49 | size_out = size_in 50 | 51 | if size_h is None: 52 | size_h = min(size_in, size_out) 53 | 54 | self.size_in = size_in 55 | self.size_h = size_h 56 | self.size_out = size_out 57 | # Submodules 58 | self.fc_0 = nn.Conv1d(size_in, size_h, 1) 59 | self.fc_1 = nn.Conv1d(size_h, size_out, 1) 60 | self.actvn = nn.ReLU() 61 | 62 | if size_in == size_out: 63 | self.shortcut = None 64 | else: 65 | self.shortcut = nn.Conv1d(size_in, size_out, 1, bias=False) 66 | # Initialization 67 | nn.init.zeros_(self.fc_1.weight) 68 | 69 | def forward(self, x): 70 | net = self.fc_0(self.actvn(x)) 71 | dx = self.fc_1(self.actvn(net)) 72 | 73 | if self.shortcut is not None: 74 | x_s = self.shortcut(x) 75 | else: 76 | x_s = x 77 | 78 | return x_s + dx 79 | 80 | 81 | class ResnetBlockPointwise(nn.Module): 82 | def __init__(self, f_in, f_out=None, f_hidden=None, 83 | is_bias=True, actvn=F.relu, factor=1., eq_lr=False): 84 | super().__init__() 85 | # Filter dimensions 86 | if f_out is None: 87 | f_out = f_in 88 | 89 | if f_hidden is None: 90 | f_hidden = min(f_in, f_out) 91 | 92 | self.f_in = f_in 93 | self.f_hidden = f_hidden 94 | self.f_out = f_out 95 | 96 | self.factor = factor 97 | self.eq_lr = eq_lr 98 | 99 | # Activation function 100 | self.actvn = actvn 101 | 102 | # Submodules 103 | self.conv_0 = nn.Conv1d(f_in, f_hidden, 1) 104 | self.conv_1 = nn.Conv1d(f_hidden, f_out, 1, bias=is_bias) 105 | 106 | if self.eq_lr: 107 | self.conv_0 = EqualizedLR(self.conv_0) 108 | self.conv_1 = EqualizedLR(self.conv_1) 109 | 110 | if f_in == f_out: 111 | self.shortcut = nn.Sequential() 112 | else: 113 | self.shortcut = nn.Conv1d(f_in, f_out, 1, bias=False) 114 | if self.eq_lr: 115 | self.shortcut = EqualizedLR(self.shortcut) 116 | 117 | # Initialization 118 | nn.init.zeros_(self.conv_1.weight) 119 | 120 | def forward(self, x): 121 | net = self.conv_0(self.actvn(x)) 122 | dx = self.conv_1(self.actvn(net)) 123 | x_s = self.shortcut(x) 124 | return x_s + self.factor * dx 125 | 126 | 127 | class ResnetBlockConv2d(nn.Module): 128 | def __init__(self, f_in, f_out=None, f_hidden=None, 129 | is_bias=True, actvn=F.relu, factor=1., 130 | eq_lr=False, pixel_norm=False): 131 | super().__init__() 132 | # Filter dimensions 133 | if f_out is None: 134 | f_out = f_in 135 | 136 | if f_hidden is None: 137 | f_hidden = min(f_in, f_out) 138 | 139 | self.f_in = f_in 140 | self.f_hidden = f_hidden 141 | self.f_out = f_out 142 | self.factor = factor 143 | self.eq_lr = eq_lr 144 | self.use_pixel_norm = pixel_norm 145 | 146 | # Activation 147 | self.actvn = actvn 148 | 149 | # Submodules 150 | self.conv_0 = nn.Conv2d(self.f_in, self.f_hidden, 3, 151 | stride=1, padding=1) 152 | self.conv_1 = nn.Conv2d(self.f_hidden, self.f_out, 3, 153 | stride=1, padding=1, 
bias=is_bias) 154 | 155 | if self.eq_lr: 156 | self.conv_0 = EqualizedLR(self.conv_0) 157 | self.conv_1 = EqualizedLR(self.conv_1) 158 | 159 | if f_in == f_out: 160 | self.shortcut = nn.Sequential() 161 | else: 162 | self.shortcut = nn.Conv2d(f_in, f_out, 1, bias=False) 163 | if self.eq_lr: 164 | self.shortcut = EqualizedLR(self.shortcut) 165 | 166 | # Initialization 167 | nn.init.zeros_(self.conv_1.weight) 168 | 169 | def forward(self, x): 170 | x_s = self.shortcut(x) 171 | 172 | if self.use_pixel_norm: 173 | x = pixel_norm(x) 174 | dx = self.conv_0(self.actvn(x)) 175 | 176 | if self.use_pixel_norm: 177 | dx = pixel_norm(dx) 178 | dx = self.conv_1(self.actvn(dx)) 179 | 180 | out = x_s + self.factor * dx 181 | 182 | return out 183 | 184 | def _shortcut(self, x): 185 | if self.learned_shortcut: 186 | x_s = self.conv_s(x) 187 | else: 188 | x_s = x 189 | return x_s 190 | 191 | 192 | class EqualizedLR(nn.Module): 193 | def __init__(self, module): 194 | super().__init__() 195 | self.module = module 196 | self._make_params() 197 | 198 | def _make_params(self): 199 | weight = self.module.weight 200 | 201 | height = weight.data.shape[0] 202 | width = weight.view(height, -1).data.shape[1] 203 | 204 | # Delete parameters in child 205 | del self.module._parameters['weight'] 206 | self.module.weight = None 207 | 208 | # Add parameters to myself 209 | self.weight = nn.Parameter(weight.data) 210 | 211 | # Inherit parameters 212 | self.factor = np.sqrt(2 / width) 213 | 214 | # Initialize 215 | nn.init.normal_(self.weight) 216 | 217 | # Inherit bias if available 218 | self.bias = self.module.bias 219 | self.module.bias = None 220 | 221 | if self.bias is not None: 222 | del self.module._parameters['bias'] 223 | nn.init.zeros_(self.bias) 224 | 225 | def forward(self, *args, **kwargs): 226 | self.module.weight = self.factor * self.weight 227 | if self.bias is not None: 228 | self.module.bias = 1. 
* self.bias 229 | out = self.module.forward(*args, **kwargs) 230 | self.module.weight = None 231 | self.module.bias = None 232 | return out 233 | 234 | 235 | def pixel_norm(x): 236 | sigma = x.norm(dim=1, keepdim=True) 237 | out = x / (sigma + 1e-5) 238 | return out 239 | -------------------------------------------------------------------------------- /mesh2tex/nvs/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex.nvs import ( 2 | training, generation, config, models 3 | ) 4 | 5 | __all__ = [ 6 | training, 7 | generation, 8 | config, 9 | models, 10 | ] 11 | -------------------------------------------------------------------------------- /mesh2tex/nvs/config.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.distributions as dist 4 | import torch.optim as optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from mesh2tex import data, geometry 8 | from mesh2tex.nvs import training, generation 9 | from mesh2tex.nvs import models 10 | 11 | 12 | def get_models(cfg, device=None, dataset=None): 13 | # Get configs 14 | encoder = cfg['model']['encoder'] 15 | decoder = cfg['model']['decoder'] 16 | discriminator = cfg['model']['discriminator'] 17 | 18 | encoder_kwargs = cfg['model']['encoder_kwargs'] 19 | decoder_kwargs = cfg['model']['decoder_kwargs'] 20 | discriminator_kwargs = cfg['model']['discriminator_kwargs'] 21 | 22 | img_size = cfg['data']['img_size'] 23 | z_dim = cfg['model']['z_dim'] 24 | c_dim = cfg['model']['c_dim'] 25 | white_bg = cfg['model']['white_bg'] 26 | # Create generator 27 | 28 | if encoder is not None: 29 | encoder = models.encoder_dict[encoder]( 30 | c_dim=c_dim, **encoder_kwargs 31 | ).to(device) 32 | 33 | decoder = models.decoder_dict[decoder]( 34 | c_dim=c_dim, white_bg=white_bg, **decoder_kwargs 35 | ).to(device) 36 | 37 | generator = models.NovelViewSynthesis( 38 | decoder, encoder 39 | ) 40 | 41 | p0_z = get_prior_z(cfg, device) 42 | 43 | # Create discriminator 44 | discriminator = models.discriminator_dict[discriminator]( 45 | img_size=img_size, **discriminator_kwargs 46 | ).to(device) 47 | 48 | # Output dict 49 | models_out = { 50 | 'generator': generator, 51 | 'discriminator': discriminator, 52 | } 53 | 54 | return models_out 55 | 56 | 57 | def get_optimizers(models, cfg): 58 | model_g = models['generator'] 59 | model_d = models['discriminator'] 60 | 61 | lr_g = cfg['training']['lr_g'] 62 | lr_d = cfg['training']['lr_d'] 63 | optimizer_g = optim.RMSprop(model_g.parameters(), lr=lr_g) 64 | optimizer_d = optim.RMSprop(model_d.parameters(), lr=lr_d) 65 | 66 | optimizers = { 67 | 'generator': optimizer_g, 68 | 'discriminator': optimizer_d, 69 | } 70 | return optimizers 71 | 72 | 73 | def get_dataset(mode, cfg, input_sampling=True): 74 | # Config 75 | path_shapes = cfg['data']['path_shapes'] 76 | img_size = cfg['data']['img_size'] 77 | 78 | # Transforms 79 | transform_img = transforms.Compose([ 80 | data.ResizeImage((img_size, img_size), order=0), 81 | ]) 82 | 83 | transform_img_input = transforms.Compose([ 84 | data.ResizeImage((224, 224), order=0), 85 | ]) 86 | 87 | transform_depth = torchvision.transforms.Compose([ 88 | data.ImageToDepthValue(), 89 | data.ResizeImage((img_size, img_size), order=0), 90 | ]) 91 | 92 | # Fields 93 | if mode == 'train': 94 | fields = { 95 | '2d': data.DepthImageField( 96 | 'image', 'depth', transform_img, transform_depth, 'png', 97 | 'exr', with_camera=True, random_view=True), 98 
| 'condition': data.ImagesField('input_image', 99 | transform_img_input, 'jpg'), 100 | } 101 | mode_ = 'train' 102 | 103 | elif mode == 'val_eval': 104 | fields = { 105 | '2d': data.DepthImageField( 106 | 'image', 'depth', transform_img, transform_depth, 'png', 107 | 'exr', with_camera=True, random_view=True), 108 | 'condition': data.ImagesField( 109 | 'input_image', 110 | transform_img_input, 'jpg' 111 | ), 112 | } 113 | mode_ = 'val' 114 | 115 | elif mode == 'val_vis': 116 | # elif for_vis is True or cfg['training']['vis_fixviews'] is True: 117 | fields = { 118 | '2d': data.DepthImageVisualizeField( 119 | 'visualize/image', 'visualize/depth', 120 | transform_img, transform_depth, 'png', 121 | 'exr', with_camera=True, random_view=True 122 | ), 123 | 'condition': data.ImagesField( 124 | 'input_image', 125 | transform_img_input, 'jpg' 126 | ), 127 | } 128 | mode_ = 'val' 129 | 130 | elif mode == 'test_eval': 131 | # elif for_eval is True: 132 | fields = { 133 | '2d': data.DepthImageVisualizeField( 134 | 'image', 'depth', 135 | transform_img, transform_depth, 'png', 136 | 'exr', with_camera=True, random_view=True 137 | ), 138 | 'condition': data.ImagesField('input_image', 139 | transform_img_input, 'jpg', 140 | random_view=input_sampling 141 | ), 142 | 'idx': data.IndexField(), 143 | } 144 | mode_ = 'test' 145 | 146 | elif mode == 'test_vis': 147 | # elif for_vis is True or cfg['training']['vis_fixviews'] is True: 148 | fields = { 149 | '2d': data.DepthImageVisualizeField( 150 | 'visualize/image', 'visualize/depth', 151 | transform_img, transform_depth, 'png', 152 | 'exr', with_camera=True, random_view=True 153 | ), 154 | 'condition': data.ImagesField('input_image', 155 | transform_img_input, 'jpg', 156 | random_view=input_sampling 157 | ), 158 | 'idx': data.IndexField(), 159 | } 160 | mode_ = 'test' 161 | 162 | else: 163 | fields = { 164 | '2d': data.DepthImageField( 165 | 'image', 'depth', transform_img, transform_depth, 'png', 166 | 'exr', with_camera=True, random_view=True), 167 | 'condition': data.ImagesField('input_image', 168 | transform_img_input, 'jpg'), 169 | } 170 | 171 | if cfg['data']['shapes_multiclass']: 172 | ds_shapes = data.Shapes3dDataset( 173 | path_shapes, fields, split=mode_, no_except=True, 174 | ) 175 | else: 176 | ds_shapes = data.Shapes3dClassDataset( 177 | path_shapes, fields, split=mode_, no_except=True, 178 | ) 179 | 180 | if mode_ == 'val' or mode_ == 'test': 181 | ds = ds_shapes 182 | else: 183 | ds = data.CombinedDataset([ds_shapes, ds_shapes]) 184 | 185 | return ds 186 | return ds_shapes 187 | 188 | 189 | def get_dataloader(mode, cfg): 190 | # Config 191 | batch_size = cfg['training']['batch_size'] 192 | if mode != 'train': 193 | batch_size = cfg['generation']['batch_size'] 194 | ds_shapes = get_dataset(mode, cfg) 195 | 196 | data_loader = torch.utils.data.DataLoader( 197 | ds_shapes, batch_size=batch_size, num_workers=4, shuffle=True, 198 | collate_fn=data.collate_remove_none) 199 | 200 | return data_loader 201 | 202 | 203 | def get_meshloader(mode, cfg): 204 | # Config 205 | 206 | path_meshes = cfg['data']['path_meshes'] 207 | 208 | batch_size = cfg['training']['batch_size'] 209 | 210 | fields = { 211 | 'meshes': data.MeshField('mesh'), 212 | } 213 | 214 | ds_shapes = data.Shapes3dClassDataset( 215 | path_meshes, fields, split=None, no_except=True, 216 | ) 217 | 218 | data_loader = torch.utils.data.DataLoader( 219 | ds_shapes, batch_size=batch_size, num_workers=4, shuffle=True, 220 | collate_fn=data.collate_remove_none) 221 | 222 | return data_loader 223 | 
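# Illustrative sketch (not the repository's actual train.py): the helpers in this
# module are typically composed roughly as follows, assuming `cfg` is the parsed
# YAML config and `device` a torch.device. Names outside this module
# (load_config, num_epochs) are placeholders, not guaranteed APIs.
#
#     cfg = load_config('configs/singleview/NVS/car.yaml')        # hypothetical loader
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
#     models = get_models(cfg, device=device)
#     optimizers = get_optimizers(models, cfg)
#     trainer = get_trainer(models, optimizers, cfg, device=device)
#     train_loader = get_dataloader('train', cfg)
#
#     it = 0
#     for epoch_it in range(num_epochs):                           # num_epochs: placeholder
#         for batch in train_loader:
#             it += 1
#             losses = trainer.train_step(batch, epoch_it, it)     # {'loss_g': ..., 'loss_d': ...}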
224 | 225 | def get_trainer(models, optimizers, cfg, device=None): 226 | out_dir = cfg['training']['out_dir'] 227 | 228 | print_every = cfg['training']['print_every'] 229 | visualize_every = cfg['training']['visualize_every'] 230 | checkpoint_every = cfg['training']['checkpoint_every'] 231 | validate_every = cfg['training']['validate_every'] 232 | backup_every = cfg['training']['backup_every'] 233 | 234 | model_selection_metric = cfg['training']['model_selection_metric'] 235 | model_selection_mode = cfg['training']['model_selection_mode'] 236 | 237 | ma_beta = cfg['training']['moving_average_beta'] 238 | multi_gpu = cfg['training']['multi_gpu'] 239 | gp_reg = cfg['training']['gradient_penalties_reg'] 240 | w_pix = cfg['training']['weight_pixelloss'] 241 | w_gan = cfg['training']['weight_ganloss'] 242 | w_vae = cfg['training']['weight_vaeloss'] 243 | 244 | trainer = training.Trainer( 245 | models['generator'], models['discriminator'], 246 | optimizers['generator'], optimizers['discriminator'], 247 | ma_beta=ma_beta, 248 | multi_gpu=multi_gpu, 249 | gp_reg=gp_reg, 250 | w_pix=w_pix, w_gan=w_gan, w_vae=w_vae, 251 | out_dir=out_dir, 252 | model_selection_metric=model_selection_metric, 253 | model_selection_mode=model_selection_mode, 254 | print_every=print_every, 255 | visualize_every=visualize_every, 256 | checkpoint_every=checkpoint_every, 257 | backup_every=backup_every, 258 | validate_every=validate_every, 259 | device=device, 260 | ) 261 | 262 | return trainer 263 | 264 | 265 | def get_generator(model, cfg, device, **kwargs): 266 | 267 | generator = generation.Generator3D( 268 | model, 269 | device=device, 270 | ) 271 | return generator 272 | 273 | def get_prior_z(cfg, device, **kwargs): 274 | z_dim = cfg['model']['z_dim'] 275 | p0_z = dist.Normal( 276 | torch.zeros(z_dim, device=device), 277 | torch.ones(z_dim, device=device) 278 | ) 279 | 280 | return p0_z -------------------------------------------------------------------------------- /mesh2tex/nvs/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import trimesh 4 | import os 5 | from trimesh.util import array_to_string 6 | from mesh2tex import geometry 7 | from torchvision.utils import save_image 8 | 9 | 10 | class Generator3D(object): 11 | def __init__(self, model, device=None): 12 | 13 | self.model = model 14 | self.device = device 15 | 16 | def generate_images_4eval(self, batch, out_dir, model_names): 17 | depth = batch['2d.depth'].to(self.device) 18 | img_real = batch['2d.img'].to(self.device) 19 | condition = batch['condition'].to(self.device) 20 | batch_size = depth.size(0) 21 | num_views = depth.size(1) 22 | # assert(num_views == 5) 23 | # Save real images 24 | 25 | out_dir_real = out_dir + "/real/" 26 | out_dir_fake = out_dir + "/fake/" 27 | out_dir_condition = out_dir + "/condition/" 28 | if not os.path.exists(out_dir_real): 29 | os.makedirs(out_dir_real) 30 | if not os.path.exists(out_dir_fake): 31 | os.makedirs(out_dir_fake) 32 | if not os.path.exists(out_dir_condition): 33 | os.makedirs(out_dir_condition) 34 | 35 | for j in range(batch_size): 36 | save_image( 37 | condition[j].cpu(), 38 | os.path.join(out_dir_condition, '%s.png' % model_names[j])) 39 | 40 | for v in range(num_views): 41 | depth_ = depth[:, v] 42 | img_real_ = img_real[:, v] 43 | 44 | self.model.eval() 45 | 46 | with torch.no_grad(): 47 | img_fake_ = self.model(depth_, condition) 48 | 49 | for j in range(batch_size): 50 | save_image( 51 | img_real_[j].cpu(), 52 | 
os.path.join( 53 | out_dir_real, '%s%03d.png' % (model_names[j], v) 54 | )) 55 | save_image( 56 | img_fake_[j].cpu(), 57 | os.path.join( 58 | out_dir_fake, '%s%03d.png' % (model_names[j], v) 59 | )) 60 | 61 | def generate_images_4eval_condi_hd(self, batch, out_dir, model_names): 62 | depth = batch['2d.depth'] #.to(self.device) 63 | img_real = batch['2d.img'] #.to(self.device) 64 | condition = batch['condition'].to(self.device) 65 | batch_size = depth.size(0) 66 | num_views = depth.size(1) 67 | # if depth.size(1) >= 10: 68 | # num_views = 10 69 | # assert(num_views == 5) 70 | # Save real images 71 | out_dir_real = out_dir + "/real/" 72 | out_dir_fake = out_dir + "/fake/" 73 | out_dir_condition = out_dir + "/condition/" 74 | if not os.path.exists(out_dir_real): 75 | os.makedirs(out_dir_real) 76 | if not os.path.exists(out_dir_fake): 77 | os.makedirs(out_dir_fake) 78 | if not os.path.exists(out_dir_condition): 79 | os.makedirs(out_dir_condition) 80 | viewbatchsize = 2 81 | viewbatchnum = int(num_views / viewbatchsize) 82 | # points_batches = points.split(10, dim=0) 83 | for j in range(batch_size): 84 | for vidx in range(viewbatchnum): 85 | lower = vidx * viewbatchsize 86 | upper = (vidx + 1) * viewbatchsize 87 | 88 | depth_ = depth[j][lower:upper] 89 | img_real_ = img_real[j][lower:upper] 90 | condition_ = condition[j][:4].expand( 91 | viewbatchsize, condition.size(1), 92 | condition.size(2), condition.size(3)) 93 | 94 | self.model.eval() 95 | with torch.no_grad(): 96 | img_fake = self.model(depth_.to(self.device), condition_.to(self.device)) 97 | for v in range(viewbatchsize): 98 | save_image( 99 | img_real_[v], 100 | os.path.join(out_dir_real, 101 | '%s%03d.png' % (model_names[j], vidx * viewbatchsize + v))) 102 | save_image( 103 | img_fake[v].cpu(), 104 | os.path.join(out_dir_fake, 105 | '%s%03d.png' % (model_names[j], vidx * viewbatchsize + v))) 106 | save_image( 107 | condition[j].cpu(), 108 | os.path.join(out_dir_condition, 109 | '%s.png' % (model_names[j]))) 110 | -------------------------------------------------------------------------------- /mesh2tex/nvs/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import distributions as dist 4 | import trimesh 5 | from mesh2tex.nvs.models import ( 6 | encoder, decoder, discriminator 7 | ) 8 | 9 | encoder_dict = { 10 | 'resnet18': encoder.Resnet18, 11 | } 12 | 13 | decoder_dict = { 14 | 'each_layer_c': decoder.UNetEachLayerC, 15 | } 16 | 17 | discriminator_dict = { 18 | 'resnet': discriminator.Resnet, 19 | } 20 | 21 | 22 | class NovelViewSynthesis(nn.Module): 23 | def __init__(self, decoder, encoder): 24 | super().__init__() 25 | 26 | self.decoder = decoder 27 | self.encoder = encoder 28 | 29 | def forward(self, depth, condition): 30 | """Generate an image . 31 | 32 | Args: 33 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 34 | representing depth of at pixels 35 | Returns: 36 | img (torch.FloatTensor): tensor of size B x 3 x N x M representing 37 | output image 38 | """ 39 | batch_size, _, N, M = depth.size() 40 | 41 | assert(depth.size(1) == 1) 42 | 43 | c = self.encode(condition) 44 | img = self.decode(depth, c) 45 | 46 | return img 47 | 48 | def encode(self, cond): 49 | """Encode mesh using sampled 3D location on the mesh. 
50 | 51 | Args: 52 | input_image (torch.FloatTensor): tensor of size B x 3 x N x M 53 | input image 54 | 55 | Returns: 56 | c (torch.FloatTensor): tensor of size B x C with encoding of 57 | the input image 58 | """ 59 | z = self.encoder(cond) 60 | return z 61 | 62 | def decode(self, depth, c): 63 | """Decode image from 3D locations, conditional encoding and latent 64 | encoding. 65 | 66 | Args: 67 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 68 | representing depth of at pixels 69 | c (torch.FloatTensor): tensor of size B x C with the encoding of 70 | the 3D meshes 71 | 72 | Returns: 73 | rgb (torch.FloatTensor): tensor of size B x 3 x N representing 74 | color at given 3d locations 75 | """ 76 | rgb = self.decoder(depth, c) 77 | return rgb 78 | -------------------------------------------------------------------------------- /mesh2tex/nvs/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class UNetEachLayerC(nn.Module): 7 | def __init__(self, c_dim, white_bg=True, resnet_leaky=None): 8 | super().__init__() 9 | # Attributes 10 | self.c_dim = c_dim 11 | self.white_bg = white_bg 12 | 13 | # Submodules 14 | self.conv_0 = nn.Conv2d(1, 64, 3, padding=1, stride=2) 15 | self.conv_1 = nn.Conv2d(64, 128, 3, padding=1, stride=2) 16 | self.conv_2 = nn.Conv2d(128, 256, 3, padding=1, stride=2) 17 | self.conv_3 = nn.Conv2d(256, 512, 3, padding=1, stride=2) 18 | self.conv_4 = nn.Conv2d(512, 1024, 3, padding=1, stride=2) 19 | 20 | self.conv_trp_0 = nn.ConvTranspose2d(1024, 512, 3, padding=1, stride=2, output_padding=1) 21 | self.conv_trp_1 = nn.ConvTranspose2d(1024, 256, 3, padding=1, stride=2, output_padding=1) 22 | self.conv_trp_2 = nn.ConvTranspose2d(512, 128, 3, padding=1, stride=2, output_padding=1) 23 | self.conv_trp_3 = nn.ConvTranspose2d(256, 64, 3, padding=1, stride=2, output_padding=1) 24 | self.conv_trp_4 = nn.ConvTranspose2d(128, 3, 3, padding=1, stride=2, output_padding=1) 25 | 26 | self.fc_0 = nn.Linear(c_dim, 64) 27 | self.fc_1 = nn.Linear(c_dim, 128) 28 | self.fc_2 = nn.Linear(c_dim, 256) 29 | self.fc_3 = nn.Linear(c_dim, 512) 30 | self.fc_4 = nn.Linear(c_dim, 1024) 31 | 32 | def forward(self, depth, c): 33 | assert(c.size(0) == depth.size(0)) 34 | 35 | batch_size = depth.size(0) 36 | c_dim = self.c_dim 37 | 38 | mask = (depth != float('Inf')) 39 | depth = depth.clone() 40 | depth[~mask] = 0. 
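        # Background pixels carry a depth of +Inf in these renderings; `mask` records
        # the valid (foreground) pixels, the Inf entries are zeroed so the convolutions
        # below only see finite values, and the same mask is reused at the end of
        # forward() to composite the prediction onto a white or black background.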
41 | 42 | net = depth 43 | 44 | # Downsample 45 | # 64 x 128 x 128 46 | net0 = self.conv_0(net) + self.fc_0(c).view(batch_size, 64, 1, 1) 47 | net0 = F.relu(net0) 48 | # 128 x 64 x 64 49 | net1 = self.conv_1(net0) + self.fc_1(c).view(batch_size, 128, 1, 1) 50 | net1 = F.relu(net1) 51 | # 256 x 32 x 32 52 | net2 = self.conv_2(net1) + self.fc_2(c).view(batch_size, 256, 1, 1) 53 | net2 = F.relu(net2) 54 | # 512 x 16 x 16 55 | net3 = self.conv_3(net2) + self.fc_3(c).view(batch_size, 512, 1, 1) 56 | net3 = F.relu(net3) 57 | # 1024 x 8 x 8 58 | net4 = self.conv_4(net3) + self.fc_4(c).view(batch_size, 1024, 1, 1) 59 | net4 = F.relu(net4) 60 | 61 | # Upsample 62 | # 512 x 16 x 16 63 | net = F.relu(self.conv_trp_0(net4)) 64 | # 256 x 32 x 32 65 | net = torch.cat([net, net3], dim=1) 66 | net = F.relu(self.conv_trp_1(net)) 67 | # 128 x 64 x 64 68 | net = torch.cat([net, net2], dim=1) 69 | net = F.relu(self.conv_trp_2(net)) 70 | # 64 x 128 x 128 71 | net = torch.cat([net, net1], dim=1) 72 | net = F.relu(self.conv_trp_3(net)) 73 | # 3 x 256 x 256 74 | net = torch.cat([net, net0], dim=1) 75 | net = self.conv_trp_4(net) 76 | net = torch.sigmoid(net) 77 | 78 | if self.white_bg: 79 | mask = mask.float() 80 | net = mask * net + (1 - mask) * torch.ones_like(net) 81 | else: 82 | mask = mask.float() 83 | net = mask * net + (1 - mask) * torch.zeros_like(net) 84 | 85 | return net -------------------------------------------------------------------------------- /mesh2tex/nvs/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from mesh2tex.layers import ResnetBlockConv2d 6 | 7 | 8 | class Resnet(nn.Module): 9 | def __init__(self, img_size, embed_size=256, 10 | nfilter=64, nfilter_max=1024, leaky=True): 11 | super().__init__() 12 | self.embed_size = embed_size 13 | s0 = self.s0 = 4 14 | nf = self.nf = nfilter 15 | nf_max = self.nf_max = nfilter_max 16 | 17 | # Activation function 18 | if not leaky: 19 | self.actvn = F.relu 20 | else: 21 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 22 | 23 | # Submodules 24 | nlayers = int(np.log2(img_size / s0)) 25 | self.nf0 = min(nf_max, nf * 2**nlayers) 26 | 27 | blocks = [ 28 | ResnetBlockConv2d(nf, nf, actvn=self.actvn) 29 | ] 30 | 31 | for i in range(nlayers): 32 | nf0 = min(nf * 2**i, nf_max) 33 | nf1 = min(nf * 2**(i+1), nf_max) 34 | blocks += [ 35 | nn.AvgPool2d(3, stride=2, padding=1), 36 | ResnetBlockConv2d(nf0, nf1), 37 | ] 38 | 39 | self.conv_img = nn.Conv2d(3 +1 , 1*nf, 3, padding=1) 40 | self.resnet = nn.Sequential(*blocks) 41 | self.fc = nn.Linear(self.nf0*s0*s0, 1) 42 | 43 | # Initialization 44 | nn.init.zeros_(self.fc.weight) 45 | 46 | def forward(self, x, depth): 47 | batch_size = x.size(0) 48 | 49 | depth = depth.clone() 50 | depth[depth == float("Inf")] = 0 51 | depth[depth == -1*float("Inf")] = 0 52 | 53 | x_and_depth = torch.cat([x, depth], dim=1) 54 | 55 | out = self.conv_img(x_and_depth) 56 | out = self.resnet(out) 57 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 58 | out = self.fc(self.actvn(out)) 59 | out = out.squeeze() 60 | 61 | return out 62 | -------------------------------------------------------------------------------- /mesh2tex/nvs/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class Resnet18(nn.Module): 6 | ''' ResNet-18 conditioning network. 
7 | ''' 8 | def __init__(self, c_dim=128, normalize=True, use_linear=True): 9 | ''' Initialisation. 10 | 11 | Args: 12 | c_dim (int): output dimension of the latent embedding 13 | normalize (bool): whether the input images should be normalized 14 | use_linear (bool): whether a final linear layer should be used 15 | ''' 16 | super().__init__() 17 | self.normalize = normalize 18 | self.use_linear = use_linear 19 | self.features = models.resnet18(pretrained=True) 20 | self.features.fc = nn.Sequential() 21 | if use_linear: 22 | self.fc = nn.Linear(512, c_dim) 23 | elif c_dim == 512: 24 | self.fc = nn.Sequential() 25 | else: 26 | raise ValueError('c_dim must be 512 if use_linear is False') 27 | 28 | def forward(self, x): 29 | if self.normalize: 30 | x = normalize_imagenet(x) 31 | net = self.features(x) 32 | out = self.fc(net) 33 | return out 34 | 35 | 36 | def normalize_imagenet(x): 37 | x = x.clone() 38 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 39 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 40 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 41 | return x 42 | -------------------------------------------------------------------------------- /mesh2tex/nvs/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.autograd as autograd 7 | from torchvision.utils import save_image 8 | from mesh2tex import geometry 9 | from mesh2tex.training import BaseTrainer 10 | from mesh2tex.utils.io import export_pointcloud 11 | import mesh2tex.utils.FID.feature_l1 as feature_l1 12 | import mesh2tex.utils.SSIM_L1.ssim_l1_score as SSIM 13 | 14 | 15 | class Trainer(BaseTrainer): 16 | def __init__(self, model_g, model_d, 17 | optimizer_g, optimizer_d, 18 | ma_beta=0.99, 19 | loss_type='L1', 20 | gp_reg=10., 21 | w_pix=0., w_gan=0., w_vae=0., 22 | gan_type='standard', 23 | multi_gpu=False, 24 | **kwargs): 25 | # Initialize base trainer 26 | super().__init__(**kwargs) 27 | 28 | # Models and optimizers 29 | self.model_g = model_g 30 | self.model_d = model_d 31 | 32 | self.model_g_ma = copy.deepcopy(model_g) 33 | 34 | for p in self.model_g_ma.parameters(): 35 | p.requires_grad = False 36 | self.model_g_ma.eval() 37 | 38 | self.optimizer_g = optimizer_g 39 | self.optimizer_d = optimizer_d 40 | self.loss_type = loss_type 41 | 42 | # Attributes 43 | self.gp_reg = gp_reg 44 | self.ma_beta = ma_beta 45 | self.gan_type = gan_type 46 | self.multi_gpu = multi_gpu 47 | self.w_pix = w_pix 48 | self.w_vae = w_vae 49 | self.w_gan = w_gan 50 | self.pix_loss = w_pix != 0 51 | self.vae_loss = w_vae != 0 52 | self.gan_loss = w_gan != 0 53 | if self.vae_loss and self.pix_loss: 54 | print('Not possible to combine pix and vae loss') 55 | # Checkpointer 56 | if self.gan_loss is True: 57 | self.checkpoint_io.register_modules( 58 | model_g=self.model_g, model_d=self.model_d, 59 | model_g_ma=self.model_g_ma, 60 | optimizer_g=self.optimizer_g, 61 | optimizer_d=self.optimizer_d, 62 | ) 63 | else: 64 | self.checkpoint_io.register_modules( 65 | model_g=self.model_g, model_d=self.model_d, 66 | model_g_ma=self.model_g_ma, 67 | optimizer_g=self.optimizer_g, 68 | ) 69 | 70 | def train_step(self, batch, epoch_it, it): 71 | # Seperate batch into batch and model 72 | # batch_model, (img, _) = batch 73 | batch_model0, batch_model1 = batch 74 | loss_g = self.train_step_g(batch_model0) 75 | if self.gan_loss is True: 76 | loss_d = self.train_step_d(batch_model1) 77 | else: 78 | loss_d = 0 79 | 80 | losses = { 81 | 
'loss_g': loss_g, 82 | 'loss_d': loss_d, 83 | } 84 | return losses 85 | 86 | def train_step_d(self, batch): 87 | ''' 88 | A single train step of the discriminator 89 | ''' 90 | model_d = self.model_d 91 | model_g = self.model_g 92 | 93 | model_d.train() 94 | model_g.train() 95 | 96 | if self.multi_gpu: 97 | model_d = nn.DataParallel(model_d) 98 | model_g = nn.DataParallel(model_g) 99 | 100 | self.optimizer_d.zero_grad() 101 | 102 | # Get data 103 | depth = batch['2d.depth'].to(self.device) 104 | img_real = batch['2d.img'].to(self.device) 105 | condition = batch['condition'].to(self.device) 106 | 107 | # Loss on real 108 | img_real.requires_grad_() 109 | d_real = model_d(img_real, depth) 110 | 111 | dloss_real = self.compute_gan_loss(d_real, 1) 112 | dloss_real.backward(retain_graph=True) 113 | 114 | # R1 Regularizer 115 | reg = self.gp_reg * compute_grad2(d_real, img_real).mean() 116 | reg.backward() 117 | 118 | # Loss on fake 119 | with torch.no_grad(): 120 | if self.gan_loss is True: 121 | img_fake = model_g(depth, condition) 122 | 123 | d_fake = model_d(img_fake, depth) 124 | 125 | dloss_fake = self.compute_gan_loss(d_fake, 0) 126 | dloss_fake.backward() 127 | 128 | # Gradient step 129 | self.optimizer_d.step() 130 | 131 | return self.w_gan * (dloss_fake.item() + dloss_real.item()) 132 | 133 | def train_step_g(self, batch): 134 | model_d = self.model_d 135 | model_g = self.model_g 136 | 137 | model_d.train() 138 | model_g.train() 139 | 140 | if self.multi_gpu: 141 | model_d = nn.DataParallel(model_d) 142 | model_g = nn.DataParallel(model_g) 143 | 144 | self.optimizer_g.zero_grad() 145 | 146 | # Get data 147 | depth = batch['2d.depth'].to(self.device) 148 | img_real = batch['2d.img'].to(self.device) 149 | condition = batch['condition'].to(self.device) 150 | 151 | # Loss on fake 152 | img_fake = model_g(depth, condition) 153 | loss_pix = 0 154 | loss_gan = 0 155 | 156 | if self.pix_loss is True: 157 | loss_pix = self.compute_loss(img_fake, img_real) 158 | if self.gan_loss is True: 159 | d_fake = model_d(img_fake, depth) 160 | loss_gan = self.compute_gan_loss(d_fake, 1) 161 | 162 | # weighting 163 | loss = self.w_pix * loss_pix + self.w_gan * loss_gan 164 | loss.backward() 165 | 166 | # Gradient step 167 | self.optimizer_g.step() 168 | 169 | # Update moving average 170 | #self.update_moving_average() 171 | 172 | return loss.item() 173 | 174 | def compute_loss(self, img_fake, img_real): 175 | 176 | if self.loss_type == 'L2': 177 | loss = F.mse_loss(img_fake, img_real) 178 | elif self.loss_type == 'L1': 179 | loss = F.l1_loss(img_fake, img_real) 180 | else: 181 | raise NotImplementedError 182 | 183 | return loss 184 | 185 | def compute_gan_loss(self, d_out, target): 186 | ''' 187 | Compute GAN loss (standart cross entropy or wasserstein distance) 188 | !!! 
Without Regularizer 189 | ''' 190 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 191 | 192 | if self.gan_type == 'standard': 193 | loss = F.binary_cross_entropy_with_logits(d_out, targets) 194 | elif self.gan_type == 'wgan': 195 | loss = (2*target - 1) * d_out.mean() 196 | else: 197 | raise NotImplementedError 198 | 199 | return loss 200 | 201 | def eval_step(self, batch): 202 | depth = batch['2d.depth'].to(self.device) 203 | img_real = batch['2d.img'].to(self.device) 204 | condition = batch['condition'].to(self.device) 205 | 206 | # Get model 207 | model_g = self.model_g 208 | model_g.eval() 209 | 210 | if self.multi_gpu: 211 | model_g = nn.DataParallel(model_g) 212 | 213 | # Predict 214 | with torch.no_grad(): 215 | img_fake = model_g(depth, condition) 216 | loss_val = self.compute_loss(img_fake, img_real) 217 | ssim, l1 = SSIM.calculate_ssim_l1_given_tensor(img_fake, img_real) 218 | featl1 = feature_l1.calculate_feature_l1_given_tensors( 219 | img_fake, img_real, img_real.size(0), True, 2048) 220 | 221 | loss_val_dict = {'loss_val': loss_val.item(), 'SSIM': ssim, 222 | 'featl1': featl1} 223 | 224 | return loss_val_dict 225 | 226 | def eval_step_old(self, batch): 227 | depth = batch['2d.depth'].to(self.device) 228 | img_real = batch['2d.img'].to(self.device) 229 | condition = batch['condition'].to(self.device) 230 | 231 | self.model_g.eval() 232 | with torch.no_grad(): 233 | img_fake = self.model_g(depth, condition) 234 | loss_val = self.compute_loss(img_fake, img_real) 235 | 236 | loss_val_dict = {'loss_val': loss_val.item()} 237 | return loss_val_dict 238 | 239 | def visualize(self, batch): 240 | # b atch_model, (img_real, _) = batch 241 | 242 | depth = batch['2d.depth'].to(self.device) 243 | img_real = batch['2d.img'].to(self.device) 244 | condition = batch['condition'].to(self.device) 245 | batch_size = depth.size(0) 246 | num_views = depth.size(1) 247 | # Save real images 248 | 249 | for j in range(batch_size): 250 | depth_ = depth[j] 251 | img_real_ = img_real[j] 252 | condition_ = condition[j].expand(num_views, condition.size(1), 253 | condition.size(2), condition.size(3)) 254 | save_image(img_real_, os.path.join(self.vis_dir, 'real_%i.png' % j)) 255 | # Create fake images and save 256 | self.model_g.eval() 257 | with torch.no_grad(): 258 | img_fake = self.model_g(depth_, condition_) 259 | save_image(img_fake.cpu(), os.path.join(self.vis_dir, 'fake_%i.png' % j)) 260 | save_image(condition[j].cpu(), os.path.join(self.vis_dir, 'condition_%i.png' % j)) 261 | 262 | def update_moving_average(self): 263 | param_dict_src = dict(self.model_g.named_parameters()) 264 | beta = self.ma_beta 265 | for p_name, p_tgt in self.model_g_ma.named_parameters(): 266 | p_src = param_dict_src[p_name] 267 | assert(p_src is not p_tgt) 268 | with torch.no_grad(): 269 | p_ma = beta * p_tgt + (1. 
- beta) * p_src 270 | p_tgt.copy_(p_ma) 271 | 272 | 273 | def compute_grad2(d_out, x_in): 274 | ''' 275 | Derive L2-Gradient penalty for regularizing the GAN 276 | ''' 277 | batch_size = x_in.size(0) 278 | grad_dout = autograd.grad( 279 | outputs=d_out.sum(), inputs=x_in, 280 | create_graph=True, retain_graph=True, only_inputs=True 281 | )[0] 282 | grad_dout2 = grad_dout.pow(2) 283 | assert(grad_dout2.size() == x_in.size()) 284 | reg = grad_dout2.view(batch_size, -1).sum(1) 285 | return reg 286 | -------------------------------------------------------------------------------- /mesh2tex/texnet/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex.texnet import ( 2 | training, generation, config, models 3 | ) 4 | 5 | __all__ = [ 6 | training, 7 | generation, 8 | config, 9 | models, 10 | ] 11 | -------------------------------------------------------------------------------- /mesh2tex/texnet/config.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributions as dist 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from mesh2tex import data, geometry 9 | from mesh2tex.texnet import training, generation 10 | from mesh2tex.texnet import models 11 | 12 | 13 | def get_models(cfg, dataset=None, device=None): 14 | # Get configs 15 | encoder = cfg['model']['encoder'] 16 | decoder = cfg['model']['decoder'] 17 | geometry_encoder = cfg['model']['geometry_encoder'] 18 | vae_encoder = cfg['model']['vae_encoder'] 19 | discriminator = cfg['model']['discriminator'] 20 | 21 | encoder_kwargs = cfg['model']['encoder_kwargs'] 22 | decoder_kwargs = cfg['model']['decoder_kwargs'] 23 | geometry_encoder_kwargs = cfg['model']['geometry_encoder_kwargs'] 24 | discriminator_kwargs = cfg['model']['discriminator_kwargs'] 25 | vae_encoder_kwargs = cfg['model']['vae_encoder_kwargs'] 26 | img_size = cfg['data']['img_size'] 27 | z_dim = cfg['model']['z_dim'] 28 | c_dim = cfg['model']['c_dim'] 29 | white_bg = cfg['model']['white_bg'] 30 | # Create generator 31 | 32 | if encoder == "idx": 33 | encoder = nn.Embedding(len(dataset), c_dim) 34 | elif encoder is not None: 35 | encoder = models.encoder_dict[encoder]( 36 | c_dim=c_dim, **encoder_kwargs 37 | ).to(device) 38 | 39 | decoder = models.decoder_dict[decoder]( 40 | c_dim=c_dim, z_dim=z_dim, **decoder_kwargs 41 | ).to(device) 42 | 43 | geometry_encoder = geometry.encoder_dict[geometry_encoder]( 44 | c_dim=c_dim, **geometry_encoder_kwargs 45 | ).to(device) 46 | 47 | if vae_encoder is not None: 48 | vae_encoder = models.vae_encoder_dict[vae_encoder]( 49 | img_size=img_size, c_dim=c_dim, z_dim=z_dim, **vae_encoder_kwargs 50 | ).to(device) 51 | 52 | p0_z = get_prior_z(cfg, device) 53 | 54 | generator = models.TextureNetwork( 55 | decoder, geometry_encoder, encoder, vae_encoder, p0_z, white_bg 56 | ) 57 | 58 | # Create discriminator 59 | discriminator = models.discriminator_dict[discriminator]( 60 | geometry_encoder, 61 | img_size=img_size, **discriminator_kwargs 62 | ).to(device) 63 | 64 | # Output dict 65 | models_out = { 66 | 'generator': generator, 67 | 'discriminator': discriminator, 68 | } 69 | 70 | return models_out 71 | 72 | 73 | def get_optimizers(models, cfg): 74 | model_g = models['generator'] 75 | model_d = models['discriminator'] 76 | 77 | lr_g = cfg['training']['lr_g'] 78 | lr_d = cfg['training']['lr_d'] 79 | optimizer_g = 
optim.RMSprop(model_g.parameters(), lr=lr_g) 80 | optimizer_d = optim.RMSprop(model_d.parameters(), lr=lr_d) 81 | 82 | optimizers = { 83 | 'generator': optimizer_g, 84 | 'discriminator': optimizer_d, 85 | } 86 | return optimizers 87 | 88 | 89 | def get_dataset(mode, cfg, input_sampling=True): 90 | # Config 91 | path_shapes = cfg['data']['path_shapes'] 92 | img_size = cfg['data']['img_size'] 93 | pc_subsampling = cfg['training']['pc_subsampling'] 94 | pcl_knn = cfg['data']['pcl_knn'] 95 | 96 | # Fields 97 | transform_img = transforms.Compose([ 98 | data.ResizeImage((img_size, img_size), order=0), 99 | ]) 100 | 101 | transform_img_input = transforms.Compose([ 102 | data.ResizeImage((224, 224), order=0), 103 | ]) 104 | 105 | transform_depth = torchvision.transforms.Compose([ 106 | data.ImageToDepthValue(), 107 | data.ResizeImage((img_size, img_size), order=0), 108 | ]) 109 | 110 | pcl_transform = [data.SubsamplePointcloud(pc_subsampling)] 111 | if pcl_knn is not None: 112 | pcl_transform += [data.ComputeKNNPointcloud(pcl_knn)] 113 | 114 | pcl_transform = transforms.Compose(pcl_transform) 115 | 116 | if mode == 'train': 117 | fields = { 118 | '2d': data.DepthImageField( 119 | 'image', 'depth', transform_img, transform_depth, 'png', 120 | 'exr', with_camera=True, random_view=True), 121 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 122 | 'condition': data.ImagesField('input_image', 123 | transform_img_input, 'jpg'), 124 | } 125 | mode_ = 'train' 126 | 127 | elif mode == 'val_eval': 128 | fields = { 129 | '2d': data.DepthImageField( 130 | 'image', 'depth', transform_img, transform_depth, 'png', 131 | 'exr', with_camera=True, random_view=True), 132 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 133 | 'condition': data.ImagesField('input_image', 134 | transform_img_input, 'jpg'), 135 | } 136 | mode_ = 'val' 137 | 138 | elif mode == 'val_vis': 139 | fields = { 140 | '2d': data.DepthImageVisualizeField( 141 | 'visualize/image', 'visualize/depth', 142 | transform_img, transform_depth, 'png', 143 | 'exr', with_camera=True, random_view=True 144 | ), 145 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 146 | 'condition': data.ImagesField('input_image', 147 | transform_img_input, 'jpg'), 148 | } 149 | mode_ = 'val' 150 | 151 | elif mode == 'test_eval': 152 | fields = { 153 | '2d': data.DepthImageVisualizeField( 154 | 'image', 'depth', 155 | transform_img, transform_depth, 'png', 156 | 'exr', with_camera=True, random_view=True 157 | ), 158 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 159 | 'condition': data.ImagesField('input_image', 160 | transform_img_input, 'jpg', 161 | random_view=input_sampling), 162 | 'idx': data.IndexField(), 163 | } 164 | mode_ = 'test' 165 | 166 | elif mode == 'test_vis': 167 | fields = { 168 | '2d': data.DepthImageVisualizeField( 169 | 'visualize/image', 'visualize/depth', 170 | transform_img, transform_depth, 'png', 171 | 'exr', with_camera=True, random_view=True 172 | ), 173 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 174 | 'condition': data.ImagesField('input_image', 175 | transform_img_input, 'jpg', 176 | random_view=input_sampling), 177 | 'idx': data.IndexField(), 178 | } 179 | mode_ = 'test' 180 | 181 | else: 182 | print('Invalid data loading mode') 183 | 184 | # Dataset 185 | if cfg['data']['shapes_multiclass']: 186 | ds_shapes = data.Shapes3dDataset( 187 | path_shapes, fields, split=mode_, no_except=True, 188 | ) 189 | else: 190 | ds_shapes = 
data.Shapes3dClassDataset( 191 | path_shapes, fields, split=mode_, no_except=False, 192 | ) 193 | 194 | if mode_ == 'val' or mode_ == 'test': 195 | ds = ds_shapes 196 | else: 197 | ds = data.CombinedDataset([ds_shapes, ds_shapes]) 198 | 199 | return ds 200 | 201 | 202 | def get_dataloader(mode, cfg): 203 | # Config 204 | batch_size = cfg['training']['batch_size'] 205 | with_shuffle = cfg['data']['with_shuffle'] 206 | 207 | ds_shapes = get_dataset(mode, cfg) 208 | data_loader = torch.utils.data.DataLoader( 209 | ds_shapes, batch_size=batch_size, num_workers=12, shuffle=with_shuffle) 210 | #gcollate_fn=data.collate_remove_none) 211 | 212 | return data_loader 213 | 214 | 215 | def get_meshloader(mode, cfg): 216 | # Config 217 | 218 | path_meshes = cfg['data']['path_meshes'] 219 | 220 | batch_size = cfg['training']['batch_size'] 221 | 222 | fields = { 223 | 'meshes': data.MeshField('mesh'), 224 | } 225 | 226 | ds_shapes = data.Shapes3dClassDataset( 227 | path_meshes, fields, split=None, no_except=True, 228 | ) 229 | 230 | data_loader = torch.utils.data.DataLoader( 231 | ds_shapes, batch_size=batch_size, num_workers=12, shuffle=True) 232 | # collate_fn=data.collate_remove_none) 233 | 234 | return data_loader 235 | 236 | 237 | def get_trainer(models, optimizers, cfg, device=None): 238 | out_dir = cfg['training']['out_dir'] 239 | 240 | print_every = cfg['training']['print_every'] 241 | visualize_every = cfg['training']['visualize_every'] 242 | checkpoint_every = cfg['training']['checkpoint_every'] 243 | validate_every = cfg['training']['validate_every'] 244 | backup_every = cfg['training']['backup_every'] 245 | 246 | model_selection_metric = cfg['training']['model_selection_metric'] 247 | model_selection_mode = cfg['training']['model_selection_mode'] 248 | 249 | ma_beta = cfg['training']['moving_average_beta'] 250 | multi_gpu = cfg['training']['multi_gpu'] 251 | gp_reg = cfg['training']['gradient_penalties_reg'] 252 | w_pix = cfg['training']['weight_pixelloss'] 253 | w_gan = cfg['training']['weight_ganloss'] 254 | w_vae = cfg['training']['weight_vaeloss'] 255 | experiment = cfg['training']['experiment'] 256 | model_url = cfg['model']['model_url'] 257 | trainer = training.Trainer( 258 | models['generator'], models['discriminator'], 259 | optimizers['generator'], optimizers['discriminator'], 260 | ma_beta=ma_beta, 261 | gp_reg=gp_reg, 262 | w_pix=w_pix, w_gan=w_gan, w_vae=w_vae, 263 | multi_gpu=multi_gpu, 264 | experiment=experiment, 265 | out_dir=out_dir, 266 | model_selection_metric=model_selection_metric, 267 | model_selection_mode=model_selection_mode, 268 | print_every=print_every, 269 | visualize_every=visualize_every, 270 | checkpoint_every=checkpoint_every, 271 | backup_every=backup_every, 272 | validate_every=validate_every, 273 | device=device, 274 | model_url=model_url 275 | ) 276 | 277 | return trainer 278 | 279 | 280 | def get_generator(model, cfg, device, **kwargs): 281 | 282 | generator = generation.Generator3D( 283 | model, 284 | device=device, 285 | ) 286 | return generator 287 | 288 | 289 | def get_prior_z(cfg, device, **kwargs): 290 | z_dim = cfg['model']['z_dim'] 291 | p0_z = dist.Normal( 292 | torch.zeros(z_dim, device=device), 293 | torch.ones(z_dim, device=device) 294 | ) 295 | 296 | return p0_z 297 | -------------------------------------------------------------------------------- /mesh2tex/texnet/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from trimesh.util import 
array_to_string 5 | from mesh2tex import geometry 6 | from torchvision.utils import save_image 7 | from torch.nn.functional import interpolate 8 | 9 | #TODO comment the generation functions 10 | 11 | 12 | class Generator3D(object): 13 | def __init__(self, model, device=None): 14 | 15 | self.model = model 16 | self.device = device 17 | 18 | def save_mesh(self, mesh, out_file, digits=10): 19 | ''' 20 | Saving meshes to OFF file 21 | 22 | ''' 23 | digits = int(digits) 24 | # prepend a 3 (face count) to each face 25 | if mesh.visual.face_colors is None: 26 | faces_stacked = np.column_stack(( 27 | np.ones(len(mesh.faces)) * 3, 28 | mesh.faces)).astype(np.int64) 29 | else: 30 | assert(mesh.visual.face_colors.shape[0] == mesh.faces.shape[0]) 31 | faces_stacked = np.column_stack(( 32 | np.ones(len(mesh.faces)) * 3, 33 | mesh.faces, mesh.visual.face_colors[:, :3])).astype(np.int64) 34 | export = 'OFF\n' 35 | # the header is vertex count, face count, edge number 36 | export += str(len(mesh.vertices)) + ' ' + str(len(mesh.faces)) + ' 0\n' 37 | export += array_to_string( 38 | mesh.vertices, col_delim=' ', row_delim='\n', digits=digits) + '\n' 39 | export += array_to_string(faces_stacked, col_delim=' ', row_delim='\n') 40 | 41 | with open(out_file, 'w') as f: 42 | f.write(export) 43 | 44 | return mesh 45 | 46 | def generate_images_4eval_condi(self, batch, out_dir, model_names): 47 | ''' 48 | Generate textures in the conditional setting (given image) 49 | 50 | ''' 51 | 52 | # Extract depth, gt, camera info, shape pc and condition 53 | depth = batch['2d.depth'].to(self.device) 54 | img_real = batch['2d.img'].to(self.device) 55 | cam_K = batch['2d.camera_mat'].to(self.device) 56 | cam_W = batch['2d.world_mat'].to(self.device) 57 | mesh_repr = geometry.get_representation(batch, self.device) 58 | mesh_points = mesh_repr['points'] 59 | mesh_normals = mesh_repr['normals'] 60 | condition = batch['condition'].to(self.device) 61 | 62 | # Determine constants and check 63 | batch_size = depth.size(0) 64 | num_views = depth.size(1) 65 | 66 | # Define Output folders 67 | out_dir_real = out_dir + "/real/" 68 | out_dir_fake = out_dir + "/fake/" 69 | out_dir_condition = out_dir + "/condition/" 70 | if not os.path.exists(out_dir_real): 71 | os.makedirs(out_dir_real) 72 | if not os.path.exists(out_dir_fake): 73 | os.makedirs(out_dir_fake) 74 | if not os.path.exists(out_dir_condition): 75 | os.makedirs(out_dir_condition) 76 | 77 | # Batch loop 78 | for j in range(batch_size): 79 | 80 | # Expand shape info to tensors 81 | # for all views of the same objects 82 | geom_repr = { 83 | 'points': mesh_points[j][:num_views].expand( 84 | num_views, mesh_points.size(1), 85 | mesh_points.size(2)), 86 | 'normals': mesh_normals[j][:num_views].expand( 87 | num_views, mesh_normals.size(1), 88 | mesh_normals.size(2)), 89 | } 90 | 91 | depth_ = depth[j][:num_views] 92 | img_real_ = img_real[j][:num_views] 93 | condition_ = condition[j][:num_views].expand( 94 | num_views, condition.size(1), 95 | condition.size(2), condition.size(3)) 96 | cam_K_ = cam_K[j][:num_views] 97 | cam_W_ = cam_W[j][:num_views] 98 | 99 | # Generate images and save 100 | self.model.eval() 101 | with torch.no_grad(): 102 | img_fake = self.model(depth_, cam_K_, cam_W_, 103 | geom_repr, condition_) 104 | 105 | save_image( 106 | condition[j].cpu(), 107 | os.path.join(out_dir_condition, 108 | '%s.png' % (model_names[j]))) 109 | 110 | for v in range(num_views): 111 | save_image( 112 | img_real_[v], 113 | os.path.join(out_dir_real, 114 | '%s%03d.png' % (model_names[j], v))) 
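                # The corresponding prediction is written next under the identical
                # '<model_name><view_idx>.png' name, just into fake/ instead of real/,
                # so the two folders stay aligned image-for-image for later comparison
                # (e.g. with the FID / SSIM-L1 utilities in mesh2tex/utils).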
115 | save_image( 116 | img_fake[v].cpu(), 117 | os.path.join(out_dir_fake, 118 | '%s%03d.png' % (model_names[j], v))) 119 | 120 | def generate_images_4eval_condi_hd(self, batch, out_dir, model_names): 121 | ''' 122 | Generate textures in hd images given condition 123 | 124 | ''' 125 | 126 | # Extract depth, gt, camera info, shape pc and condition 127 | depth = batch['2d.depth'] 128 | img_real = batch['2d.img'] 129 | cam_K = batch['2d.camera_mat'] 130 | cam_W = batch['2d.world_mat'] 131 | mesh_repr = geometry.get_representation(batch, self.device) 132 | mesh_points = mesh_repr['points'] 133 | mesh_normals = mesh_repr['normals'] 134 | condition = batch['condition'] 135 | 136 | # Determine constants and check 137 | batch_size = depth.size(0) 138 | num_views = depth.size(1) 139 | 140 | # Define Output folders 141 | out_dir_real = out_dir + "/real/" 142 | out_dir_fake = out_dir + "/fake/" 143 | out_dir_condition = out_dir + "/condition/" 144 | if not os.path.exists(out_dir_real): 145 | os.makedirs(out_dir_real) 146 | if not os.path.exists(out_dir_fake): 147 | os.makedirs(out_dir_fake) 148 | if not os.path.exists(out_dir_condition): 149 | os.makedirs(out_dir_condition) 150 | 151 | # Loop through batch and views, because of memory requirement 152 | viewbatchsize = 1 153 | viewbatchnum = int(num_views / viewbatchsize) 154 | for j in range(batch_size): 155 | for vidx in range(viewbatchnum): 156 | lower = vidx * viewbatchsize 157 | upper = (vidx + 1) * viewbatchsize 158 | 159 | # Expand shape info to tensors 160 | # for all views of the same objects 161 | geom_repr = { 162 | 'points': mesh_points[j][:4].expand( 163 | viewbatchsize, mesh_points.size(1), 164 | mesh_points.size(2)), 165 | 'normals': mesh_normals[j][:4].expand( 166 | viewbatchsize, mesh_normals.size(1), 167 | mesh_normals.size(2)), 168 | } 169 | 170 | depth_ = depth[j][lower:upper].to(self.device) 171 | img_real_ = img_real[j][lower:upper] 172 | if len(condition.size()) == 1: 173 | condition_ = condition[j:j+1].expand( 174 | viewbatchsize) 175 | else: 176 | condition_ = condition[j:j+1][:4].expand( 177 | viewbatchsize, condition.size(1), 178 | condition.size(2), condition.size(3)).to(self.device) 179 | cam_K_ = cam_K[j][lower:upper].to(self.device) 180 | cam_W_ = cam_W[j][lower:upper].to(self.device) 181 | 182 | # Generate images and save 183 | self.model.eval() 184 | with torch.no_grad(): 185 | img_fake = self.model(depth_, cam_K_, cam_W_, 186 | geom_repr, condition_) 187 | if len(condition.size()) != 1: 188 | save_image( 189 | condition[j].cpu(), 190 | os.path.join(out_dir_condition, 191 | '%s.png' % (model_names[j]))) 192 | 193 | for v in range(viewbatchsize): 194 | save_image( 195 | img_real_[v], 196 | os.path.join( 197 | out_dir_real, 198 | '%s%03d.png' % (model_names[j], 199 | vidx * viewbatchsize + v))) 200 | save_image( 201 | img_fake[v].cpu(), 202 | os.path.join( 203 | out_dir_fake, 204 | '%s%03d.png' % (model_names[j], 205 | vidx * viewbatchsize + v))) 206 | 207 | def generate_images_4eval_vae(self, batch, out_dir, model_names): 208 | ''' 209 | Generate texture using the VAE 210 | 211 | ''' 212 | # Extract depth, gt, camera info, shape pc and condition 213 | depth = batch['2d.depth'].to(self.device) 214 | img_real = batch['2d.img'].to(self.device) 215 | cam_K = batch['2d.camera_mat'].to(self.device) 216 | cam_W = batch['2d.world_mat'].to(self.device) 217 | mesh_repr = geometry.get_representation(batch, self.device) 218 | mesh_points = mesh_repr['points'] 219 | mesh_normals = mesh_repr['normals'] 220 | 221 | # Determine 
constants and check 222 | batch_size = depth.size(0) 223 | num_views = depth.size(1) 224 | if depth.size(1) >= 10: 225 | num_views = 10 226 | 227 | # Define Output folders 228 | out_dir_real = out_dir + "/real/" 229 | out_dir_fake = out_dir + "/fake/" 230 | if not os.path.exists(out_dir_real): 231 | os.makedirs(out_dir_real) 232 | if not os.path.exists(out_dir_fake): 233 | os.makedirs(out_dir_fake) 234 | 235 | # batch loop 236 | for j in range(batch_size): 237 | geom_repr = { 238 | 'points': mesh_points[j][:num_views].expand( 239 | num_views, mesh_points.size(1), mesh_points.size(2)), 240 | 'normals': mesh_normals[j][:num_views].expand( 241 | num_views, mesh_normals.size(1), mesh_normals.size(2)), 242 | } 243 | depth_ = depth[j][:num_views] 244 | img_real_ = img_real[j][:num_views] 245 | cam_K_ = cam_K[j][:num_views] 246 | cam_W_ = cam_W[j][:num_views] 247 | 248 | # Sample latent code 249 | z_ = np.random.normal(0, 1, 512) 250 | inter = torch.from_numpy(z_).float().to(self.device) 251 | z = inter.expand(num_views, 512) 252 | 253 | # Generate images and save 254 | self.model.eval() 255 | with torch.no_grad(): 256 | img_fake = self.model(depth_, cam_K_, cam_W_, 257 | geom_repr, z=z, sample=False) 258 | 259 | for v in range(num_views): 260 | save_image( 261 | img_real_[v], 262 | os.path.join(out_dir_real, '%s%03d.png' 263 | % (model_names[j], v))) 264 | save_image( 265 | img_fake[v].cpu(), 266 | os.path.join(out_dir_fake, '%s%03d.png' 267 | % (model_names[j], v))) 268 | 269 | def generate_images_4eval_vae_interpol(self, batch, out_dir, model_names): 270 | ''' 271 | Interpolates between latent encoding 272 | of first and second element of batch 273 | 274 | ''' 275 | # Extract depth, gt, camera info, shape pc and condition 276 | depth = batch['2d.depth'].to(self.device) 277 | img_real = batch['2d.img'].to(self.device) 278 | cam_K = batch['2d.camera_mat'].to(self.device) 279 | cam_W = batch['2d.world_mat'].to(self.device) 280 | mesh_repr = geometry.get_representation(batch, self.device) 281 | mesh_points = mesh_repr['points'] 282 | mesh_normals = mesh_repr['normals'] 283 | 284 | # Determine constants and check 285 | batch_size = depth.size(0) 286 | num_views = depth.size(1) 287 | if depth.size(1) >= 10: 288 | num_views = 10 289 | 290 | # Define Output folders 291 | out_dir_real = out_dir + "/real/" 292 | out_dir_fake = out_dir + "/fake/" 293 | if not os.path.exists(out_dir_real): 294 | os.makedirs(out_dir_real) 295 | if not os.path.exists(out_dir_fake): 296 | os.makedirs(out_dir_fake) 297 | 298 | # Derive latent texture code as starting point of interpolation 299 | geom_repr = { 300 | 'points': mesh_points[:1], 301 | 'normals': mesh_normals[:1], 302 | } 303 | self.model.eval() 304 | shape_encoding = self.model.encode_geometry(geom_repr) 305 | image_input = img_real[0][:1] 306 | img = interpolate(image_input, size=[128, 128]) 307 | latent_input = self.model.infer_z_transfer(img, shape_encoding) 308 | 309 | # Derive latent texture code as end point of interpolation 310 | geom_repr2 = { 311 | 'points': mesh_points[1:2], 312 | 'normals': mesh_normals[1:2], 313 | } 314 | shape_encoding2 = self.model.encode_geometry(geom_repr2) 315 | image_input2 = img_real[1][:1] 316 | img2 = interpolate(image_input2, size=[128, 128]) 317 | latent_input2 = self.model.infer_z_transfer(img2, shape_encoding2) 318 | 319 | # Derive stepsize 320 | steps = 20 321 | step = (latent_input2-latent_input)/steps 322 | 323 | # batch loop 324 | for j in range(1, batch_size): 325 | 326 | geom_repr = { 327 | 'points': 
mesh_points[j][:num_views].expand( 328 | num_views, mesh_points.size(1), mesh_points.size(2)), 329 | 'normals': mesh_normals[j][:num_views].expand( 330 | num_views, mesh_normals.size(1), mesh_normals.size(2)), 331 | } 332 | 333 | depth_ = depth[j][:num_views] 334 | img_real_ = img_real[j][:num_views] 335 | cam_K_ = cam_K[j][:num_views] 336 | cam_W_ = cam_W[j][:num_views] 337 | 338 | self.model.eval() 339 | # steps loop 340 | for num in range(steps): 341 | inter = latent_input + step*num 342 | z = inter.expand(num_views, 512) 343 | with torch.no_grad(): 344 | img_fake = self.model(depth_, cam_K_, cam_W_, 345 | geom_repr, z=z, sample=False) 346 | for v in range(1): 347 | save_image( 348 | img_real_[v], 349 | os.path.join( 350 | out_dir_real, '%s%03d_%03d.png' 351 | % (model_names[j], v, num))) 352 | save_image( 353 | img_fake[v].cpu(), 354 | os.path.join( 355 | out_dir_fake, '%s%03d_%03d.png' 356 | % (model_names[j], v, num))) 357 | 358 | def generate_images_4eval_gan(self, batch, out_dir, model_names): 359 | ''' 360 | Generate Texture using a GAN 361 | 362 | ''' 363 | # Extract depth, gt, camera info, shape pc and condition 364 | depth = batch['2d.depth'].to(self.device) 365 | img_real = batch['2d.img'].to(self.device) 366 | cam_K = batch['2d.camera_mat'].to(self.device) 367 | cam_W = batch['2d.world_mat'].to(self.device) 368 | mesh_repr = geometry.get_representation(batch, self.device) 369 | mesh_points = mesh_repr['points'] 370 | mesh_normals = mesh_repr['normals'] 371 | 372 | # Determine constants and check 373 | batch_size = depth.size(0) 374 | num_views = depth.size(1) 375 | if depth.size(1) >= 10: 376 | num_views = 10 377 | 378 | # Define Output folders 379 | out_dir_real = out_dir + "/real/" 380 | out_dir_fake = out_dir + "/fake/" 381 | out_dir_condition = out_dir + "/condition/" 382 | if not os.path.exists(out_dir_real): 383 | os.makedirs(out_dir_real) 384 | if not os.path.exists(out_dir_fake): 385 | os.makedirs(out_dir_fake) 386 | if not os.path.exists(out_dir_condition): 387 | os.makedirs(out_dir_condition) 388 | 389 | # batch loop 390 | for j in range(batch_size): 391 | 392 | geom_repr = { 393 | 'points': mesh_points[j][:num_views].expand( 394 | num_views, mesh_points.size(1), 395 | mesh_points.size(2)), 396 | 'normals': mesh_normals[j][:num_views].expand( 397 | num_views, mesh_normals.size(1), 398 | mesh_normals.size(2)), 399 | } 400 | 401 | depth_ = depth[j][:num_views] 402 | img_real_ = img_real[j][:num_views] 403 | cam_K_ = cam_K[j][:num_views] 404 | cam_W_ = cam_W[j][:num_views] 405 | 406 | self.model.eval() 407 | with torch.no_grad(): 408 | img_fake = self.model(depth_, cam_K_, cam_W_, 409 | geom_repr, sample=False) 410 | for v in range(num_views): 411 | save_image( 412 | img_real_[v], 413 | os.path.join( 414 | out_dir_real, '%s%03d.png' % (model_names[j], v))) 415 | save_image( 416 | img_fake[v].cpu(), 417 | os.path.join( 418 | out_dir_fake, '%s%03d.png' % (model_names[j], v))) 419 | 420 | 421 | def make_3d_grid(bb_min, bb_max, shape): 422 | ''' 423 | Outputs gird points of a 3d grid 424 | 425 | ''' 426 | size = shape[0] * shape[1] * shape[2] 427 | 428 | pxs = torch.linspace(bb_min[0], bb_max[0], shape[0]) 429 | pys = torch.linspace(bb_min[1], bb_max[1], shape[1]) 430 | pzs = torch.linspace(bb_min[2], bb_max[2], shape[2]) 431 | 432 | pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size) 433 | pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size) 434 | pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size) 435 | p = torch.stack([pxs, pys, pzs], 
dim=1) 436 | 437 | return p 438 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import distributions as dist 5 | import trimesh 6 | from mesh2tex.texnet.models import ( 7 | image_encoder, decoder, discriminator, vae_encoder 8 | ) 9 | 10 | encoder_dict = { 11 | 'resnet18': image_encoder.Resnet18, 12 | } 13 | 14 | decoder_dict = { 15 | 'each_layer_c': decoder.DecoderEachLayerC, 16 | 'each_layer_c_larger': decoder.DecoderEachLayerCLarger, 17 | } 18 | 19 | discriminator_dict = { 20 | 'resnet_conditional': discriminator.Resnet_Conditional, 21 | } 22 | 23 | vae_encoder_dict = { 24 | 'resnet': vae_encoder.Resnet, 25 | } 26 | 27 | 28 | class TextureNetwork(nn.Module): 29 | def __init__(self, decoder, geometry_encoder, encoder=None, 30 | vae_encoder=None, p0_z=None, white_bg=True): 31 | super().__init__() 32 | 33 | if p0_z is None: 34 | p0_z = dist.Normal(torch.tensor([]), torch.tensor([])) 35 | 36 | self.decoder = decoder 37 | self.encoder = encoder 38 | self.geometry_encoder = geometry_encoder 39 | self.vae_encoder = vae_encoder 40 | self.p0_z = p0_z 41 | self.white_bg = white_bg 42 | 43 | def forward(self, depth, cam_K, cam_W, geometry, 44 | condition=None, z=None, sample=True): 45 | """Generate an image . 46 | 47 | Args: 48 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 49 | representing depth of at pixels 50 | cam_K (torch.FloatTensor): tensor of size B x 3 x 4 representing 51 | camera projectin matrix 52 | cam_W (torch.FloatTensor): tensor of size B x 3 x 4 representing 53 | camera world matrix 54 | geometry (dict): representation of geometry 55 | condition 56 | z 57 | sample (Boolean): wether to sample latent code or take MAP 58 | Returns: 59 | img (torch.FloatTensor): tensor of size B x 3 x N x M representing 60 | output image 61 | """ 62 | batch_size, _, N, M = depth.size() 63 | assert(depth.size(1) == 1) 64 | assert(cam_K.size() == (batch_size, 3, 4)) 65 | assert(cam_W.size() == (batch_size, 3, 4)) 66 | 67 | loc3d, mask = self.depth_map_to_3d(depth, cam_K, cam_W) 68 | geom_descr = self.encode_geometry(geometry) 69 | 70 | if self.encoder is not None: 71 | z = self.encode(condition) 72 | z = z.cuda() 73 | elif z is None: 74 | z = self.get_z_from_prior((batch_size,), sample=sample) 75 | 76 | loc3d = loc3d.view(batch_size, 3, N * M) 77 | x = self.decode(loc3d, geom_descr, z) 78 | x = x.view(batch_size, 3, N, M) 79 | 80 | if self.white_bg is False: 81 | x_bg = torch.zeros_like(x) 82 | else: 83 | x_bg = torch.ones_like(x) 84 | 85 | img = (mask * x).permute(0, 1, 3, 2) + (1 - mask.permute(0, 1, 3, 2)) * x_bg 86 | 87 | return img 88 | 89 | def load_mesh2facecenter(in_path): 90 | mesh = trimesh.load(in_path, process=False) 91 | faces_center = mesh.triangles_center 92 | return mesh, faces_center 93 | 94 | def elbo(self, image_real, depth, cam_K, cam_W, geometry): 95 | batch_size, _, N, M = depth.size() 96 | 97 | assert(depth.size(1) == 1) 98 | assert(cam_K.size() == (batch_size, 3, 4)) 99 | assert(cam_W.size() == (batch_size, 3, 4)) 100 | 101 | loc3d, mask = self.depth_map_to_3d(depth, cam_K, cam_W) 102 | geom_descr = self.encode_geometry(geometry) 103 | 104 | q_z = self.infer_z(image_real, geom_descr) 105 | z = q_z.rsample() 106 | 107 | loc3d = loc3d.view(batch_size, 3, N * M) 108 | x = self.decode(loc3d, geom_descr, z) 109 | x = 
x.view(batch_size, 3, N, M) 110 | 111 | if self.white_bg is False: 112 | x_bg = torch.zeros_like(x) 113 | else: 114 | x_bg = torch.ones_like(x) 115 | 116 | image_fake = (mask * x).permute(0, 1, 3, 2) + (1 - mask.permute(0, 1, 3, 2)) * x_bg 117 | 118 | recon_loss = F.mse_loss(image_fake, image_real).sum(dim=-1) 119 | kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1) 120 | elbo = recon_loss.mean() + kl.mean()/float(N*M*3) 121 | return elbo, recon_loss.mean(), kl.mean()/float(N*M*3), image_fake 122 | 123 | def encode(self, cond): 124 | """Encode mesh using sampled 3D location on the mesh. 125 | 126 | Args: 127 | input_image (torch.FloatTensor): tensor of size B x 3 x N x M 128 | input image 129 | 130 | Returns: 131 | c (torch.FloatTensor): tensor of size B x C with encoding of 132 | the input image 133 | """ 134 | z = self.encoder(cond) 135 | return z 136 | 137 | def encode_geometry(self, geometry): 138 | """Encode mesh using sampled 3D location on the mesh. 139 | 140 | Args: 141 | geometry (dict): representation of teometry 142 | Returns: 143 | geom_descr (dict): geometry discriptor 144 | 145 | """ 146 | geom_descr = self.geometry_encoder(geometry) 147 | return geom_descr 148 | 149 | def decode(self, loc3d, c, z): 150 | """Decode image from 3D locations, conditional encoding and latent 151 | encoding. 152 | 153 | Args: 154 | loc3d (torch.FloatTensor): tensor of size B x 3 x K 155 | with 3D locations of the query 156 | c (torch.FloatTensor): tensor of size B x C with the encoding of 157 | the 3D meshes 158 | z (torch.FloatTensor): tensor of size B x Z with latent codes 159 | 160 | Returns: 161 | rgb (torch.FloatTensor): tensor of size B x 3 x N representing 162 | color at given 3d locations 163 | """ 164 | rgb = self.decoder(loc3d, c, z) 165 | return rgb 166 | 167 | def depth_map_to_3d(self, depth, cam_K, cam_W): 168 | """Derive 3D locations of each pixel of a depth map. 169 | 170 | Args: 171 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 172 | with depth at every pixel 173 | cam_K (torch.FloatTensor): tensor of size B x 3 x 4 representing 174 | camera matrices 175 | cam_W (torch.FloatTensor): tensor of size B x 3 x 4 representing 176 | world matrices 177 | Returns: 178 | loc3d (torch.FloatTensor): tensor of size B x 3 x N x M 179 | representing color at given 3d locations 180 | mask (torch.FloatTensor): tensor of size B x 1 x N x M with 181 | a binary mask if the given pixel is present or not 182 | """ 183 | 184 | assert(depth.size(1) == 1) 185 | batch_size, _, N, M = depth.size() 186 | device = depth.device 187 | # Turn depth around. 
This also avoids problems with inplace operations 188 | depth = -depth .permute(0, 1, 3, 2) 189 | 190 | zero_one_row = torch.tensor([[0., 0., 0., 1.]]) 191 | zero_one_row = zero_one_row.expand(batch_size, 1, 4).to(device) 192 | 193 | # add row to world mat 194 | cam_W = torch.cat((cam_W, zero_one_row), dim=1) 195 | 196 | # clean depth image for mask 197 | mask = (depth.abs() != float("Inf")).float() 198 | depth[depth == float("Inf")] = 0 199 | depth[depth == -1*float("Inf")] = 0 200 | 201 | # 4d array to 2d array k=N*M 202 | d = depth.reshape(batch_size, 1, N * M) 203 | 204 | # create pixel location tensor 205 | px, py = torch.meshgrid([torch.arange(0, N), torch.arange(0, M)]) 206 | px, py = px.to(device), py.to(device) 207 | 208 | p = torch.cat(( 209 | px.expand(batch_size, 1, px.size(0), px.size(1)), 210 | (M - py).expand(batch_size, 1, py.size(0), py.size(1)) 211 | ), dim=1) 212 | p = p.reshape(batch_size, 2, py.size(0) * py.size(1)) 213 | p = (p.float() / M * 2) 214 | 215 | # create terms of mapping equation x = P^-1 * d*(qp - b) 216 | P = cam_K[:, :2, :2].float().to(device) 217 | q = cam_K[:, 2:3, 2:3].float().to(device) 218 | b = cam_K[:, :2, 2:3].expand(batch_size, 2, d.size(2)).to(device) 219 | Inv_P = torch.inverse(P).to(device) 220 | 221 | rightside = (p.float() * q.float() - b.float()) * d.float() 222 | x_xy = torch.bmm(Inv_P, rightside) 223 | 224 | # add depth and ones to location in world coord system 225 | x_world = torch.cat((x_xy, d, torch.ones_like(d)), dim=1) 226 | 227 | # derive loactoion in object coord via loc3d = W^-1 * x_world 228 | Inv_W = torch.inverse(cam_W) 229 | loc3d = torch.bmm( 230 | Inv_W.expand(batch_size, 4, 4), 231 | x_world 232 | ).reshape(batch_size, 4, N, M) 233 | 234 | loc3d = loc3d[:, :3].to(device) 235 | mask = mask.to(device) 236 | return loc3d, mask 237 | 238 | def get_z_from_prior(self, size=torch.Size([]), sample=True): 239 | """Draw latent code z from prior either using sampling or 240 | using the MAP. 241 | 242 | Args: 243 | size (torch.Size): size of sample to draw. 
244 | sample (Boolean): wether to sample or to use the MAP 245 | 246 | Return: 247 | z (torch.FloatTensor): tensor of shape *size x Z representing 248 | the latent code 249 | """ 250 | if sample: 251 | z = self.p0_z.sample(size) 252 | else: 253 | z = self.p0_z.mean 254 | z = z.expand(*size, *z.size()) 255 | 256 | return z 257 | 258 | def infer_z(self, image, c, **kwargs): 259 | if self.vae_encoder is not None: 260 | mean_z, logstd_z = self.vae_encoder(image, c, **kwargs) 261 | else: 262 | batch_size = image.size(0) 263 | mean_z = torch.empty(batch_size, 0).to(self._device) 264 | logstd_z = torch.empty(batch_size, 0).to(self._device) 265 | 266 | q_z = dist.Normal(mean_z, torch.exp(logstd_z)) 267 | return q_z 268 | 269 | def infer_z_transfer(self, image, c, **kwargs): 270 | if self.vae_encoder is not None: 271 | mean_z, logstd_z = self.vae_encoder(image, c, **kwargs) 272 | else: 273 | batch_size = image.size(0) 274 | mean_z = torch.empty(batch_size, 0).to(self._device) 275 | return mean_z 276 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mesh2tex import common 5 | from mesh2tex.layers import ( 6 | ResnetBlockPointwise, 7 | EqualizedLR 8 | ) 9 | 10 | 11 | class DecoderEachLayerC(nn.Module): 12 | def __init__(self, c_dim=128, z_dim=128, dim=3, 13 | hidden_size=128, leaky=True, 14 | resnet_leaky=True, eq_lr=False): 15 | super().__init__() 16 | self.c_dim = c_dim 17 | self.eq_lr = eq_lr 18 | 19 | # Submodules 20 | if not leaky: 21 | self.actvn = F.relu 22 | else: 23 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 24 | 25 | if not resnet_leaky: 26 | self.resnet_actvn = F.relu 27 | else: 28 | self.resnet_actvn = lambda x: F.leaky_relu(x, 0.2) 29 | 30 | self.conv_p = nn.Conv1d(dim, hidden_size, 1) 31 | 32 | self.block0 = ResnetBlockPointwise( 33 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 34 | self.block1 = ResnetBlockPointwise( 35 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 36 | self.block2 = ResnetBlockPointwise( 37 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 38 | self.block3 = ResnetBlockPointwise( 39 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 40 | self.block4 = ResnetBlockPointwise( 41 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 42 | 43 | self.fc_cz_0 = nn.Linear(c_dim + z_dim, hidden_size) 44 | self.fc_cz_1 = nn.Linear(c_dim + z_dim, hidden_size) 45 | self.fc_cz_2 = nn.Linear(c_dim + z_dim, hidden_size) 46 | self.fc_cz_3 = nn.Linear(c_dim + z_dim, hidden_size) 47 | self.fc_cz_4 = nn.Linear(c_dim + z_dim, hidden_size) 48 | 49 | self.conv_out = nn.Conv1d(hidden_size, 3, 1) 50 | 51 | if self.eq_lr: 52 | self.conv_p = EqualizedLR(self.conv_p) 53 | self.conv_out = EqualizedLR(self.conv_out) 54 | self.fc_cz_0 = EqualizedLR(self.fc_cz_0) 55 | self.fc_cz_1 = EqualizedLR(self.fc_cz_1) 56 | self.fc_cz_2 = EqualizedLR(self.fc_cz_2) 57 | self.fc_cz_3 = EqualizedLR(self.fc_cz_3) 58 | self.fc_cz_4 = EqualizedLR(self.fc_cz_4) 59 | 60 | # Initialization 61 | nn.init.zeros_(self.conv_out.weight) 62 | 63 | def forward(self, p, geom_descr, z, **kwargs): 64 | c = geom_descr['global'] 65 | batch_size, D, T = p.size() 66 | 67 | cz = torch.cat([c, z], dim=1) 68 | net = self.conv_p(p) 69 | net = net + self.fc_cz_0(cz).unsqueeze(2) 70 | net = self.block0(net) 71 | net = net + self.fc_cz_1(cz).unsqueeze(2) 72 | net = self.block1(net) 73 | 
net = net + self.fc_cz_2(cz).unsqueeze(2) 74 | net = self.block2(net) 75 | net = net + self.fc_cz_3(cz).unsqueeze(2) 76 | net = self.block3(net) 77 | net = net + self.fc_cz_4(cz).unsqueeze(2) 78 | net = self.block4(net) 79 | 80 | out = self.conv_out(self.actvn(net)) 81 | out = torch.sigmoid(out) 82 | 83 | return out 84 | 85 | 86 | class DecoderEachLayerCLarger(nn.Module): 87 | def __init__(self, c_dim=128, z_dim=128, dim=3, 88 | hidden_size=128, leaky=True, 89 | resnet_leaky=True, eq_lr=False): 90 | super().__init__() 91 | self.c_dim = c_dim 92 | self.eq_lr = eq_lr 93 | if not leaky: 94 | self.actvn = F.relu 95 | else: 96 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 97 | 98 | if not resnet_leaky: 99 | self.resnet_actvn = F.relu 100 | else: 101 | self.resnet_actvn = lambda x: F.leaky_relu(x, 0.2) 102 | 103 | # Submodules 104 | self.conv_p = nn.Conv1d(dim, hidden_size, 1) 105 | 106 | self.block0 = ResnetBlockPointwise( 107 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 108 | self.block1 = ResnetBlockPointwise( 109 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 110 | self.block2 = ResnetBlockPointwise( 111 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 112 | self.block3 = ResnetBlockPointwise( 113 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 114 | self.block4 = ResnetBlockPointwise( 115 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 116 | self.block5 = ResnetBlockPointwise( 117 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 118 | self.block6 = ResnetBlockPointwise( 119 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 120 | 121 | self.fc_cz_0 = nn.Linear(c_dim + z_dim, hidden_size) 122 | self.fc_cz_1 = nn.Linear(c_dim + z_dim, hidden_size) 123 | self.fc_cz_2 = nn.Linear(c_dim + z_dim, hidden_size) 124 | self.fc_cz_3 = nn.Linear(c_dim + z_dim, hidden_size) 125 | self.fc_cz_4 = nn.Linear(c_dim + z_dim, hidden_size) 126 | self.fc_cz_5 = nn.Linear(c_dim + z_dim, hidden_size) 127 | self.fc_cz_6 = nn.Linear(c_dim + z_dim, hidden_size) 128 | 129 | self.conv_out = nn.Conv1d(hidden_size, 3, 1) 130 | 131 | if self.eq_lr: 132 | self.conv_p = EqualizedLR(self.conv_p) 133 | self.conv_out = EqualizedLR(self.conv_out) 134 | self.fc_cz_0 = EqualizedLR(self.fc_cz_0) 135 | self.fc_cz_1 = EqualizedLR(self.fc_cz_1) 136 | self.fc_cz_2 = EqualizedLR(self.fc_cz_2) 137 | self.fc_cz_3 = EqualizedLR(self.fc_cz_3) 138 | self.fc_cz_4 = EqualizedLR(self.fc_cz_4) 139 | self.fc_cz_5 = EqualizedLR(self.fc_cz_5) 140 | self.fc_cz_6 = EqualizedLR(self.fc_cz_6) 141 | 142 | # Initialization 143 | nn.init.zeros_(self.conv_out.weight) 144 | 145 | def forward(self, p, geom_descr, z, **kwargs): 146 | c = geom_descr['global'] 147 | batch_size, D, T = p.size() 148 | 149 | cz = torch.cat([c, z], dim=1) 150 | 151 | net = self.conv_p(p) 152 | net = net + self.fc_cz_0(cz).unsqueeze(2) 153 | net = self.block0(net) 154 | net = net + self.fc_cz_1(cz).unsqueeze(2) 155 | net = self.block1(net) 156 | net = net + self.fc_cz_2(cz).unsqueeze(2) 157 | net = self.block2(net) 158 | net = net + self.fc_cz_3(cz).unsqueeze(2) 159 | net = self.block3(net) 160 | net = net + self.fc_cz_4(cz).unsqueeze(2) 161 | net = self.block4(net) 162 | net = net + self.fc_cz_5(cz).unsqueeze(2) 163 | net = self.block5(net) 164 | net = net + self.fc_cz_6(cz).unsqueeze(2) 165 | net = self.block6(net) 166 | 167 | out = self.conv_out(self.actvn(net)) 168 | out = torch.sigmoid(out) 169 | 170 | return out 171 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/discriminator.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from mesh2tex.layers import ResnetBlockConv2d, EqualizedLR, pixel_norm 6 | 7 | 8 | class Resnet_Conditional(nn.Module): 9 | def __init__(self, geometry_encoder, img_size, c_dim=128, embed_size=256, 10 | nfilter=64, nfilter_max=1024, 11 | leaky=True, eq_lr=False, pixel_norm=False, 12 | factor=1.): 13 | super().__init__() 14 | self.embed_size = embed_size 15 | s0 = self.s0 = 4 16 | nf = self.nf = nfilter 17 | nf_max = self.nf_max = nfilter_max 18 | self.eq_lr = eq_lr 19 | self.use_pixel_norm = pixel_norm 20 | 21 | # Activation function 22 | if not leaky: 23 | self.actvn = F.relu 24 | else: 25 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 26 | 27 | # Submodules 28 | nlayers = int(np.log2(img_size / s0)) 29 | self.nf0 = min(nf_max, nf * 2**nlayers) 30 | 31 | blocks = [ 32 | ResnetBlockConv2d( 33 | nf, nf, actvn=self.actvn, 34 | eq_lr=eq_lr, 35 | factor=factor, 36 | pixel_norm=pixel_norm) 37 | ] 38 | 39 | for i in range(nlayers): 40 | nf0 = min(nf * 2**i, nf_max) 41 | nf1 = min(nf * 2**(i+1), nf_max) 42 | blocks += [ 43 | nn.AvgPool2d(3, stride=2, padding=1), 44 | ResnetBlockConv2d( 45 | nf0, nf1, actvn=self.actvn, eq_lr=eq_lr, 46 | factor=factor, 47 | pixel_norm=pixel_norm), 48 | ] 49 | 50 | self.conv_img = nn.Conv2d(4, 1*nf, 3, padding=1) 51 | self.resnet = nn.Sequential(*blocks) 52 | self.fc = nn.Linear(self.nf0*s0*s0, 1) 53 | 54 | if self.eq_lr: 55 | self.conv_img = EqualizedLR(self.conv_img) 56 | self.fc = EqualizedLR(self.fc) 57 | 58 | # Initialization 59 | nn.init.zeros_(self.fc.weight) 60 | 61 | def forward(self, x, depth, geom_descr): 62 | batch_size = x.size(0) 63 | 64 | depth = depth.clone() 65 | depth[depth == float("Inf")] = 0 66 | depth[depth == -1*float("Inf")] = 0 67 | 68 | x_and_depth = torch.cat([x, depth], dim=1) 69 | 70 | out = self.conv_img(x_and_depth) 71 | out = self.resnet(out) 72 | 73 | if self.use_pixel_norm: 74 | out = pixel_norm(out) 75 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 76 | out = self.fc(self.actvn(out)) 77 | out = out.squeeze() 78 | return out 79 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/image_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class Resnet18(nn.Module): 6 | ''' ResNet-18 conditioning network. 7 | ''' 8 | def __init__(self, c_dim=128, normalize=True, use_linear=True): 9 | ''' Initialization. 
10 | 11 | Args: 12 | c_dim (int): output dimension of the latent embedding 13 | normalize (bool): whether the input images should be normalized 14 | use_linear (bool): whether a final linear layer should be used 15 | ''' 16 | super().__init__() 17 | self.normalize = normalize 18 | self.use_linear = use_linear 19 | self.features = models.resnet18(pretrained=True) 20 | self.features.fc = nn.Sequential() 21 | if use_linear: 22 | self.fc = nn.Linear(512, c_dim) 23 | elif c_dim == 512: 24 | self.fc = nn.Sequential() 25 | else: 26 | raise ValueError('c_dim must be 512 if use_linear is False') 27 | 28 | def forward(self, x): 29 | if self.normalize: 30 | x = normalize_imagenet(x) 31 | net = self.features(x) 32 | out = self.fc(net) 33 | return out 34 | 35 | 36 | def normalize_imagenet(x): 37 | x = x.clone() 38 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 39 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 40 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 41 | return x 42 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/vae_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from mesh2tex.layers import ResnetBlockConv2d, EqualizedLR 6 | 7 | 8 | class Resnet(nn.Module): 9 | def __init__(self, img_size, z_dim=128, c_dim=128, embed_size=256, 10 | nfilter=32, nfilter_max=1024, leaky=True, eq_lr=False): 11 | super().__init__() 12 | self.embed_size = embed_size 13 | s0 = self.s0 = 4 14 | nf = self.nf = nfilter 15 | nf_max = self.nf_max = nfilter_max 16 | self.eq_lr = eq_lr 17 | self.c_dim = c_dim 18 | 19 | # Activation function 20 | if not leaky: 21 | self.actvn = F.relu 22 | else: 23 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 24 | 25 | # Submodules 26 | nlayers = int(np.log2(img_size / s0)) 27 | self.nf0 = min(nf_max, nf * 2**nlayers) 28 | 29 | blocks = [ 30 | ResnetBlockConv2d( 31 | nf, nf, actvn=self.actvn, eq_lr=eq_lr) 32 | ] 33 | 34 | for i in range(nlayers): 35 | nf0 = min(nf * 2**i, nf_max) 36 | nf1 = min(nf * 2**(i+1), nf_max) 37 | blocks += [ 38 | nn.AvgPool2d(3, stride=2, padding=1), 39 | ResnetBlockConv2d( 40 | nf0, nf1, actvn=self.actvn, eq_lr=eq_lr), 41 | ] 42 | 43 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1) 44 | self.resnet = nn.Sequential(*blocks) 45 | self.fc_mean = nn.Linear(self.nf0*s0*s0, z_dim) 46 | self.fc_logstd = nn.Linear(self.nf0*s0*s0, z_dim) 47 | self.fc_inject_c = nn.Linear(self.c_dim, 1*nf) 48 | if self.eq_lr: 49 | self.conv_img = EqualizedLR(self.conv_img) 50 | self.fc = EqualizedLR(self.fc) 51 | 52 | def forward(self, x, geom_descr): 53 | c = geom_descr['global'] 54 | batch_size = x.size(0) 55 | 56 | out = self.conv_img(x) 57 | add = self.fc_inject_c(c).view(out.size(0), self.nf, 1, 1) 58 | out = out + add 59 | out = self.resnet(out) 60 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 61 | 62 | mean = self.fc_mean(self.actvn(out)) 63 | logstd = self.fc_logstd(self.actvn(out)) 64 | return mean, logstd 65 | -------------------------------------------------------------------------------- /mesh2tex/texnet/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.autograd as autograd 8 | from torchvision.utils import save_image 9 | from mesh2tex import geometry 10 | from mesh2tex.training import BaseTrainer 
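# Overview of the Trainer defined below: the total objective is a weighted sum of up to
# three terms selected by the config weights w_pix, w_gan and w_vae — an L1/L2 pixel loss
# between generated and real views, an adversarial loss (binary cross-entropy for
# gan_type 'standard', or a Wasserstein objective for 'wgan') with an R1 gradient penalty
# on real images, and a VAE ELBO (reconstruction + KL). Combining the pixel and VAE
# losses is flagged as unsupported in __init__.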
11 | import mesh2tex.utils.FID.feature_l1 as feature_l1 12 | import mesh2tex.utils.SSIM_L1.ssim_l1_score as SSIM 13 | 14 | 15 | class Trainer(BaseTrainer): 16 | ''' 17 | Subclass of Basetrainer for defining train_step, eval_step and visualize 18 | ''' 19 | def __init__(self, model_g, model_d, 20 | optimizer_g, optimizer_d, 21 | ma_beta=0.99, gp_reg=10., 22 | w_pix=0., w_gan=0., w_vae=0., 23 | experiment='conditional', 24 | gan_type='standard', 25 | loss_type='L1', 26 | multi_gpu=False, 27 | **kwargs): 28 | 29 | # Initialize base trainer 30 | super().__init__(**kwargs) 31 | 32 | # Models and optimizers 33 | self.model_g = model_g 34 | self.model_d = model_d 35 | 36 | self.model_g_ma = copy.deepcopy(model_g) 37 | 38 | for p in self.model_g_ma.parameters(): 39 | p.requires_grad = False 40 | self.model_g_ma.eval() 41 | 42 | self.optimizer_g = optimizer_g 43 | self.optimizer_d = optimizer_d 44 | self.loss_type = loss_type 45 | self.experiment = experiment 46 | # Attributes 47 | self.gp_reg = gp_reg 48 | self.ma_beta = ma_beta 49 | self.gan_type = gan_type 50 | self.multi_gpu = multi_gpu 51 | self.w_pix = w_pix 52 | self.w_vae = w_vae 53 | self.w_gan = w_gan 54 | self.pix_loss = w_pix != 0 55 | self.vae_loss = w_vae != 0 56 | self.gan_loss = w_gan != 0 57 | if self.vae_loss and self.pix_loss: 58 | print('Not possible to combine pix and vae loss') 59 | # Checkpointer 60 | if self.gan_loss is True: 61 | self.checkpoint_io.register_modules( 62 | model_g=self.model_g, model_d=self.model_d, 63 | model_g_ma=self.model_g_ma, 64 | optimizer_g=self.optimizer_g, 65 | optimizer_d=self.optimizer_d, 66 | ) 67 | else: 68 | self.checkpoint_io.register_modules( 69 | model_g=self.model_g, model_d=self.model_d, 70 | model_g_ma=self.model_g_ma, 71 | optimizer_g=self.optimizer_g, 72 | ) 73 | 74 | print('w_pix: %f w_gan: %f w_vae: %f' 75 | % (self.w_pix, self.w_gan, self.w_vae)) 76 | 77 | def train_step(self, batch, epoch_it, it): 78 | ''' 79 | A single training step for the conditional or generative experiment 80 | Output: 81 | Losses 82 | ''' 83 | batch_model0, batch_model1 = batch 84 | if self.experiment == 'conditional': 85 | loss_con = self.train_step_cond(batch_model0) 86 | if self.gan_loss is True: 87 | loss_d = self.train_step_d(batch_model1) 88 | else: 89 | loss_d = 0 90 | 91 | losses = { 92 | 'loss_con': loss_con, 93 | 'loss_d': loss_d, 94 | } 95 | 96 | elif self.experiment == 'generative': 97 | loss_g = self.train_step_g(batch_model0) 98 | if self.gan_loss is True: 99 | loss_d = self.train_step_d(batch_model1) 100 | else: 101 | loss_d = 0 102 | losses = { 103 | 'loss_g': loss_g, 104 | 'loss_d': loss_d, 105 | } 106 | 107 | return losses 108 | 109 | def train_step_d(self, batch): 110 | ''' 111 | A single train step of the discriminator 112 | ''' 113 | model_d = self.model_d 114 | model_g = self.model_g 115 | 116 | model_d.train() 117 | model_g.train() 118 | 119 | if self.multi_gpu: 120 | model_d = nn.DataParallel(model_d) 121 | model_g = nn.DataParallel(model_g) 122 | 123 | self.optimizer_d.zero_grad() 124 | 125 | # Get data 126 | depth = batch['2d.depth'].to(self.device) 127 | img_real = batch['2d.img'].to(self.device) 128 | cam_K = batch['2d.camera_mat'].to(self.device) 129 | cam_W = batch['2d.world_mat'].to(self.device) 130 | mesh_repr = geometry.get_representation(batch, self.device) 131 | condition = batch['condition'].to(self.device) 132 | 133 | # Loss on real 134 | img_real.requires_grad_() 135 | d_real = model_d(img_real, depth, mesh_repr) 136 | 137 | dloss_real = self.compute_gan_loss(d_real, 1) 
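# The real-image loss keeps its computation graph (retain_graph=True below) because the
# R1 regularizer that follows re-uses it: compute_grad2, defined at the end of this file,
# takes the gradient of d_real w.r.t. img_real and penalizes its squared L2 norm per
# sample, scaled by gp_reg.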
138 | dloss_real.backward(retain_graph=True) 139 | 140 | # R1 Regularizer 141 | reg = self.gp_reg * compute_grad2(d_real, img_real).mean() 142 | reg.backward() 143 | 144 | # Loss on fake 145 | with torch.no_grad(): 146 | if self.vae_loss is True: 147 | loss_vae, re, kl, img_fake = model_g.elbo(img_real, depth, 148 | cam_K, cam_W, 149 | mesh_repr) 150 | elif self.gan_loss is True: 151 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr, condition) 152 | 153 | d_fake = model_d(img_fake, depth, mesh_repr) 154 | 155 | dloss_fake = self.compute_gan_loss(d_fake, 0) 156 | dloss_fake.backward() 157 | 158 | # Gradient step 159 | self.optimizer_d.step() 160 | 161 | return self.w_gan * (dloss_fake.item() + dloss_real.item()) 162 | 163 | def train_step_g(self, batch): 164 | ''' 165 | A single train step of the generator part of generative model 166 | (VAE: Encoder+Decoder and GAN: Generator) 167 | ''' 168 | model_d = self.model_d 169 | model_g = self.model_g 170 | 171 | model_d.train() 172 | model_g.train() 173 | 174 | if self.multi_gpu: 175 | model_d = nn.DataParallel(model_d) 176 | model_g = nn.DataParallel(model_g) 177 | 178 | self.optimizer_g.zero_grad() 179 | 180 | # Get data 181 | depth = batch['2d.depth'].to(self.device) 182 | img_real = batch['2d.img'].to(self.device) 183 | cam_K = batch['2d.camera_mat'].to(self.device) 184 | cam_W = batch['2d.world_mat'].to(self.device) 185 | mesh_repr = geometry.get_representation(batch, self.device) 186 | 187 | # Loss on fake 188 | loss_vae = 0 189 | loss_gan = 0 190 | 191 | # Forward part and loss derivation for given experiment 192 | if self.vae_loss is True: 193 | loss_vae, re, kl, img_fake = model_g.elbo(img_real, depth, 194 | cam_K, cam_W, 195 | mesh_repr) 196 | if self.gan_loss is True: 197 | d_fake = model_d(img_fake, depth, mesh_repr) 198 | loss_gan = self.compute_gan_loss(d_fake, 1) 199 | elif self.gan_loss is True: 200 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr) 201 | d_fake = model_d(img_fake, depth, mesh_repr) 202 | loss_gan = self.compute_gan_loss(d_fake, 1) 203 | 204 | # weighting of losses (128*128*3=49152) 205 | loss = self.w_vae * loss_vae + self.w_gan * loss_gan 206 | loss.backward() 207 | 208 | # Gradient step 209 | self.optimizer_g.step() 210 | 211 | # Update moving average 212 | # self.update_moving_average() 213 | 214 | return loss.item() 215 | 216 | def train_step_cond(self, batch): 217 | ''' 218 | A single train step of the conditional model 219 | with or w/o generator part of adversarial loss 220 | ''' 221 | model_d = self.model_d 222 | model_g = self.model_g 223 | 224 | model_d.train() 225 | model_g.train() 226 | 227 | if self.multi_gpu: 228 | model_d = nn.DataParallel(model_d) 229 | model_g = nn.DataParallel(model_g) 230 | 231 | self.optimizer_g.zero_grad() 232 | 233 | # Get data 234 | depth = batch['2d.depth'].to(self.device) 235 | img_real = batch['2d.img'].to(self.device) 236 | cam_K = batch['2d.camera_mat'].to(self.device) 237 | cam_W = batch['2d.world_mat'].to(self.device) 238 | mesh_repr = geometry.get_representation(batch, self.device) 239 | condition = batch['condition'].to(self.device) 240 | 241 | # Loss on fake 242 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr, condition) 243 | loss_pix = 0 244 | loss_gan = 0 245 | 246 | if self.pix_loss is True: 247 | loss_pix = self.compute_loss(img_fake, img_real) 248 | if self.gan_loss is True: 249 | d_fake = model_d(img_fake, depth, mesh_repr) 250 | loss_gan = self.compute_gan_loss(d_fake, 1) 251 | 252 | # weighting 253 | loss = self.w_pix * loss_pix + self.w_gan * 
loss_gan 254 | loss.backward() 255 | 256 | # Gradient step 257 | self.optimizer_g.step() 258 | 259 | # Update moving average 260 | # self.update_moving_average() 261 | 262 | return loss.item() 263 | 264 | def compute_loss(self, img_fake, img_real): 265 | ''' 266 | Compute Pixelloss 267 | ''' 268 | if self.loss_type == 'L2': 269 | loss = F.mse_loss(img_fake, img_real) 270 | elif self.loss_type == 'L1': 271 | loss = F.l1_loss(img_fake, img_real) 272 | else: 273 | raise NotImplementedError 274 | 275 | return loss 276 | 277 | def compute_gan_loss(self, d_out, target): 278 | ''' 279 | Compute GAN loss (standart cross entropy or wasserstein distance) 280 | !!! Without Regularizer 281 | ''' 282 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 283 | 284 | if self.gan_type == 'standard': 285 | loss = F.binary_cross_entropy_with_logits(d_out, targets) 286 | elif self.gan_type == 'wgan': 287 | loss = (2*target - 1) * d_out.mean() 288 | else: 289 | raise NotImplementedError 290 | 291 | return loss 292 | 293 | def eval_step(self, batch): 294 | ''' 295 | Evaluation step with L1, SSIM, featl1 metrics 296 | ''' 297 | depth = batch['2d.depth'].to(self.device) 298 | img_real = batch['2d.img'].to(self.device) 299 | cam_K = batch['2d.camera_mat'].to(self.device) 300 | cam_W = batch['2d.world_mat'].to(self.device) 301 | mesh_repr = geometry.get_representation(batch, self.device) 302 | condition = batch['condition'].to(self.device) 303 | 304 | # Get model 305 | model_g = self.model_g 306 | model_g.eval() 307 | 308 | if self.multi_gpu: 309 | model_g = nn.DataParallel(model_g) 310 | 311 | # Predict 312 | with torch.no_grad(): 313 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr, condition) 314 | 315 | # Derive metrics 316 | loss_val = self.compute_loss(img_fake, img_real) 317 | ssim, l1 = SSIM.calculate_ssim_l1_given_tensor(img_fake, img_real) 318 | featl1 = feature_l1.calculate_feature_l1_given_tensors( 319 | img_fake, img_real, img_real.size(0), True, 2048) 320 | 321 | loss_val_dict = {'loss_val': loss_val.item(), 'SSIM': ssim, 322 | 'featl1': featl1} 323 | return loss_val_dict 324 | 325 | def visualize(self, batch): 326 | ''' 327 | Visualization step 328 | ''' 329 | depth = batch['2d.depth'].to(self.device) 330 | img_real = batch['2d.img'].to(self.device) 331 | cam_K = batch['2d.camera_mat'].to(self.device) 332 | cam_W = batch['2d.world_mat'].to(self.device) 333 | mesh_repr = geometry.get_representation(batch, self.device) 334 | condition = batch['condition'].to(self.device) 335 | 336 | # determine constants 337 | batch_size = depth.size(0) 338 | num_views = depth.size(1) 339 | assert(num_views == 5) 340 | 341 | mesh_points = mesh_repr['points'] 342 | mesh_normals = mesh_repr['normals'] 343 | 344 | # batch loop 345 | for j in range(batch_size): 346 | 347 | # Gather input information on shape, depth, camera and condition 348 | geom_repr = { 349 | 'points': mesh_points[j].expand(num_views, 350 | mesh_points.size(1), 351 | mesh_points.size(2)), 352 | 'normals': mesh_normals[j].expand(num_views, 353 | mesh_normals.size(1), 354 | mesh_normals.size(2)), 355 | } 356 | 357 | depth_ = depth[j] 358 | img_real_ = img_real[j] 359 | if len(condition.size()) == 1: 360 | condition_ = condition[j].expand(num_views) 361 | else: 362 | condition_ = condition[j].expand(num_views, 363 | condition.size(1), 364 | condition.size(2), 365 | condition.size(3)) 366 | cam_K_ = cam_K[j] 367 | cam_W_ = cam_W[j] 368 | 369 | # save real images 370 | save_image(img_real_, 371 | os.path.join(self.vis_dir, 'real_%i.png' % j)) 
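# The fake renderings below share one condition per shape: condition_ was expanded across
# the five visualization views above, so the generator receives the same conditioning
# input for every viewpoint. The condition is stored next to the real/fake images, as a
# PNG when it is an image and as a text file otherwise.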
372 | 373 | # predict fake images and save 374 | self.model_g.eval() 375 | with torch.no_grad(): 376 | img_fake = self.model_g(depth_, cam_K_, cam_W_, 377 | geom_repr, condition_) 378 | save_image(img_fake.cpu(), 379 | os.path.join(self.vis_dir, 'fake_%i.png' % j)) 380 | 381 | # save condition images 382 | if len(condition.size()) != 1: 383 | save_image(condition[j].cpu(), 384 | os.path.join(self.vis_dir, 'condition_%i.png' % j)) 385 | else: 386 | np.savetxt(os.path.join(self.vis_dir, 'condition_%i.txt' % j), 387 | condition_.cpu()) 388 | 389 | def update_moving_average(self): 390 | ''' 391 | Update moving average 392 | ''' 393 | param_dict_src = dict(self.model_g.named_parameters()) 394 | beta = self.ma_beta 395 | for p_name, p_tgt in self.model_g_ma.named_parameters(): 396 | p_src = param_dict_src[p_name] 397 | assert(p_src is not p_tgt) 398 | with torch.no_grad(): 399 | p_ma = beta * p_tgt + (1. - beta) * p_src 400 | p_tgt.copy_(p_ma) 401 | 402 | 403 | def compute_grad2(d_out, x_in): 404 | ''' 405 | Derive L2-Gradient penalty for regularizing the GAN 406 | ''' 407 | batch_size = x_in.size(0) 408 | grad_dout = autograd.grad( 409 | outputs=d_out.sum(), inputs=x_in, 410 | create_graph=True, retain_graph=True, only_inputs=True 411 | )[0] 412 | grad_dout2 = grad_dout.pow(2) 413 | assert(grad_dout2.size() == x_in.size()) 414 | reg = grad_dout2.view(batch_size, -1).sum(1) 415 | return reg 416 | -------------------------------------------------------------------------------- /mesh2tex/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import time 4 | import logging 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | from tqdm import tqdm 8 | from mesh2tex.checkpoints import CheckpointIO 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | 13 | class BaseTrainer(object): 14 | def __init__(self, 15 | out_dir, 16 | model_selection_metric, model_selection_mode, 17 | print_every, visualize_every, checkpoint_every, 18 | backup_every, validate_every, device=None, model_url=None): 19 | # Directories 20 | self.out_dir = out_dir 21 | self.vis_dir = os.path.join(out_dir, 'vis') 22 | self.log_dir = os.path.join(out_dir, 'log') 23 | self.model_url = model_url 24 | 25 | self.model_selection_metric = model_selection_metric 26 | if model_selection_mode == 'maximize': 27 | self.model_selection_sign = 1 28 | elif model_selection_mode == 'minimize': 29 | self.model_selection_sign = -1 30 | else: 31 | raise ValueError('model_selection_mode must be ' 32 | 'either maximize or minimize.') 33 | 34 | self.print_every = print_every 35 | self.visualize_every = visualize_every 36 | self.checkpoint_every = checkpoint_every 37 | self.backup_every = backup_every 38 | self.validate_every = validate_every 39 | self.device = device 40 | 41 | # Checkpointer 42 | self.checkpoint_io = CheckpointIO(out_dir) 43 | 44 | # Create directories 45 | all_dirs = [self.out_dir, self.vis_dir, self.log_dir] 46 | for directory in all_dirs: 47 | if not os.path.exists(directory): 48 | os.makedirs(directory) 49 | 50 | def train(self, train_loader, val_loader, vis_loader, 51 | exit_after=None, n_epochs=None): 52 | """ 53 | Main training method with epoch loop, validation and model selection 54 | 55 | args: 56 | train_loader 57 | val_loader (Validation) 58 | vis_loader (Visualsation during training) 59 | """ 60 | # Load if checkpoint exist 61 | epoch_it, it, metric_val_best = self.init_training() 62 | print('Current best 
validation metric (%s): %.8f' 63 | % (self.model_selection_metric, metric_val_best)) 64 | 65 | # for tensorboard 66 | summary_writer = SummaryWriter(os.path.join(self.log_dir)) 67 | 68 | if self.visualize_every > 0: 69 | if vis_loader is None: 70 | data_vis = next(iter(val_loader)) 71 | else: 72 | data_vis = next(iter(vis_loader)) 73 | 74 | # Main training loop 75 | t0 = time.time() 76 | while (n_epochs is None) or (epoch_it < n_epochs): 77 | epoch_it += 1 78 | 79 | for batch in train_loader: 80 | it += 1 81 | 82 | losses = self.train_step(batch, epoch_it=epoch_it, it=it) 83 | 84 | if isinstance(losses, dict): 85 | loss_str = [] 86 | for k, v in losses.items(): 87 | summary_writer.add_scalar('train/%s' % k, v, it) 88 | loss_str.append('%s=%.4f' % (k, v)) 89 | loss_str = ' '.join(loss_str) 90 | else: 91 | summary_writer.add_scalar('train/loss', losses, it) 92 | loss_str = ('loss=%.4f' % losses) 93 | 94 | # Print output 95 | if self.print_every > 0 and (it % self.print_every) == 0: 96 | print('[Epoch %02d] it=%03d, %s' 97 | % (epoch_it, it, loss_str)) 98 | 99 | # Visualize output 100 | if (self.visualize_every > 0 101 | and (it % self.visualize_every) == 0): 102 | print('Visualizing') 103 | try: 104 | self.visualize(data_vis) 105 | except NotImplementedError: 106 | LOGGER.warn('Visualizing method not implemented.') 107 | 108 | # Save checkpoint 109 | if (self.checkpoint_every > 0 110 | and (it % self.checkpoint_every) == 0): 111 | print('Saving checkpoint') 112 | self.checkpoint_io.save( 113 | 'model.pt', epoch_it=epoch_it, it=it, 114 | loss_val_best=metric_val_best) 115 | 116 | # Backup if necessary 117 | if (self.backup_every > 0 118 | and (it % self.backup_every) == 0): 119 | print('Backup checkpoint') 120 | self.checkpoint_io.save( 121 | 'model_%d.pt' % it, epoch_it=epoch_it, it=it, 122 | loss_val_best=metric_val_best) 123 | 124 | # Run validation and select if better 125 | if self.validate_every > 0 and (it % self.validate_every) == 0: 126 | try: 127 | eval_dict = self.evaluate(val_loader) 128 | print(eval_dict) 129 | except NotImplementedError: 130 | LOGGER.warn('Evaluation method not implemented.') 131 | eval_dict = {} 132 | 133 | for k, v in eval_dict.items(): 134 | summary_writer.add_scalar('val/%s' % k, v, it) 135 | 136 | if self.model_selection_metric is not None: 137 | metric_val = eval_dict[self.model_selection_metric] 138 | print( 139 | 'Validation metric (%s): %.4f' 140 | % (self.model_selection_metric, metric_val)) 141 | 142 | improvement = ( 143 | self.model_selection_sign 144 | * (metric_val - metric_val_best) 145 | ) 146 | if improvement > 0: 147 | metric_val_best = metric_val 148 | print('New best model (loss %.4f)' 149 | % metric_val_best) 150 | self.checkpoint_io.save( 151 | 'model_best.pt', epoch_it=epoch_it, it=it, 152 | loss_val_best=metric_val_best) 153 | 154 | # Exit if necessary 155 | if exit_after > 0 and (time.time() - t0) >= exit_after: 156 | print('Time limit reached. Exiting.') 157 | self.checkpoint_io.save( 158 | 'model.pt', epoch_it=epoch_it, it=it, 159 | loss_val_best=metric_val_best) 160 | exit(3) 161 | 162 | print('Maximum number of epochs reached. 
Exiting.') 163 | self.checkpoint_io.save( 164 | 'model.pt', epoch_it=epoch_it, it=it, 165 | loss_val_best=metric_val_best) 166 | 167 | def evaluate(self, val_loader): 168 | ''' 169 | Evaluate model with validation data using eval_step 170 | 171 | args: 172 | data loader 173 | ''' 174 | 175 | eval_list = defaultdict(list) 176 | 177 | for data in tqdm(val_loader): 178 | eval_step_dict = self.eval_step(data) 179 | 180 | for k, v in eval_step_dict.items(): 181 | eval_list[k].append(v) 182 | 183 | eval_dict = {k: np.mean(v) for k, v in eval_list.items()} 184 | return eval_dict 185 | 186 | def init_training(self): 187 | ''' 188 | Init training by loading the latest checkpoint 189 | ''' 190 | try: 191 | if self.model_url is not None: 192 | load_dict = self.checkpoint_io.load(self.model_url) 193 | else: 194 | load_dict = self.checkpoint_io.load('model.pt') 195 | except FileExistsError: 196 | load_dict = dict() 197 | epoch_it = load_dict.get('epoch_it', -1) 198 | it = load_dict.get('it', -1) 199 | metric_val_best = load_dict.get( 200 | 'loss_val_best', -self.model_selection_sign * np.inf) 201 | 202 | return epoch_it, it, metric_val_best 203 | 204 | def train_step(self, *args, **kwargs): 205 | raise NotImplementedError 206 | 207 | def eval_step(self, *args, **kwargs): 208 | raise NotImplementedError 209 | 210 | def visualize(self, *args, **kwargs): 211 | raise NotImplementedError 212 | -------------------------------------------------------------------------------- /mesh2tex/utils/FID/feature_l1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | import numpy as np 5 | from imageio import imread 6 | from torch.autograd import Variable 7 | from torch.nn.functional import adaptive_avg_pool2d 8 | 9 | from mesh2tex.utils.FID.inception import InceptionV3 10 | 11 | 12 | def get_activations(images, model, batch_size=64, dims=2048, 13 | cuda=False, verbose=False): 14 | """Calculates the activations of the pool_3 layer for all images. 15 | 16 | Params: 17 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 18 | must lie between 0 and 1. 19 | -- model : Instance of inception model 20 | -- batch_size : the images numpy array is split into batches with 21 | batch size batch_size. A reasonable batch size depends 22 | on the hardware. 23 | -- dims : Dimensionality of features returned by Inception 24 | -- cuda : If set to True, use GPU 25 | -- verbose : If set to True and parameter out_step is given, the number 26 | of calculated batches is reported. 27 | Returns: 28 | -- A numpy array of dimension (num images, dims) that contains the 29 | activations of the given tensor when feeding inception with the 30 | query tensor. 31 | """ 32 | model.eval() 33 | 34 | d0 = images.shape[0] 35 | if batch_size > d0: 36 | print(('Warning: batch size is bigger than the data size. 
' 37 | 'Setting batch size to data size')) 38 | batch_size = d0 39 | 40 | n_batches = d0 // batch_size 41 | n_used_imgs = n_batches * batch_size 42 | 43 | pred_arr = np.empty((n_used_imgs, dims)) 44 | for i in range(n_batches): 45 | if verbose: 46 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 47 | end='', flush=True) 48 | start = i * batch_size 49 | end = start + batch_size 50 | 51 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 52 | batch = Variable(batch, volatile=True) 53 | if cuda: 54 | batch = batch.cuda() 55 | 56 | pred = model(batch)[0] 57 | 58 | # If model output is not scalar, apply global spatial average pooling. 59 | # This happens if you choose a dimensionality not equal 2048. 60 | if pred.shape[2] != 1 or pred.shape[3] != 1: 61 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 62 | 63 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 64 | 65 | if verbose: 66 | print(' done') 67 | 68 | return pred_arr 69 | 70 | 71 | def _compute_statistics_of_path(path0, path1, model, batch_size, dims, cuda): 72 | path0 = pathlib.Path(path0) 73 | path1 = pathlib.Path(path1) 74 | 75 | files_list = os.listdir(path0) 76 | files0 = [os.path.join(path0, f) for f in files_list] 77 | files1 = [os.path.join(path1, f) for f in files_list] 78 | assert(len(files0) == len(files1)) 79 | 80 | # First set of images 81 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files0]) 82 | # Bring images to shape (B, 3, H, W) 83 | imgs = imgs.transpose((0, 3, 1, 2))[:, :3] 84 | # Rescale images to be between 0 and 1 85 | imgs /= 255 86 | feat0 = get_activations(imgs, model, batch_size, dims, cuda, False) 87 | 88 | # Second set of images 89 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files1]) 90 | # Bring images to shape (B, 3, H, W) 91 | imgs = imgs.transpose((0, 3, 1, 2))[:, :3] 92 | # Rescale images to be between 0 and 1 93 | imgs /= 255 94 | feat1 = get_activations(imgs, model, batch_size, dims, cuda, False) 95 | 96 | feature_l1 = np.mean(np.abs(feat0 - feat1)) 97 | return feature_l1 98 | 99 | 100 | def _compute_statistics_of_tensors(images_fake, images_real, model, batch_size, 101 | dims, cuda): 102 | 103 | # First set of images 104 | imgs = images_fake.cpu().numpy().astype(np.float32) 105 | feat0 = get_activations(imgs, model, batch_size, dims, cuda, False) 106 | 107 | # Second set of images 108 | imgs = images_real.cpu().numpy().astype(np.float32) 109 | feat1 = get_activations(imgs, model, batch_size, dims, cuda, False) 110 | 111 | feature_l1 = np.mean(np.abs(feat0 - feat1)) 112 | return feature_l1 113 | 114 | 115 | def calculate_feature_l1_given_paths(paths, batch_size, cuda, dims): 116 | """Calculates the FID of two paths""" 117 | for p in paths: 118 | if not os.path.exists(p): 119 | raise RuntimeError('Invalid path: %s' % p) 120 | 121 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 122 | 123 | model = InceptionV3([block_idx]) 124 | if cuda: 125 | model.cuda() 126 | 127 | feature_l1 = _compute_statistics_of_path( 128 | paths[0], paths[1], model, batch_size, dims, cuda) 129 | 130 | return feature_l1 131 | 132 | 133 | def calculate_feature_l1_given_tensors(tensor1, tensor2, 134 | batch_size, cuda, dims): 135 | """Calculates the FID of two image tensors""" 136 | 137 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 138 | 139 | model = InceptionV3([block_idx]) 140 | if cuda: 141 | model.cuda() 142 | 143 | feature_l1 = _compute_statistics_of_tensors( 144 | tensor1, tensor2, model, batch_size, dims, cuda) 145 | 146 | 
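# FeatureL1 is the mean absolute difference between InceptionV3 activations of the fake
# and the real images (final average-pooling / "pool_3" features, 2048-d when dims == 2048).
# eval_step in mesh2tex/texnet/training.py calls this helper with
# batch_size=img_real.size(0), cuda=True and dims=2048.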
return feature_l1 147 | -------------------------------------------------------------------------------- /mesh2tex/utils/FID/fid_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 4 | 5 | import torch 6 | import numpy as np 7 | from scipy.misc import imread 8 | from scipy import linalg 9 | from torch.autograd import Variable 10 | from torch.nn.functional import adaptive_avg_pool2d 11 | 12 | from mesh2tex.utils.FID.inception import InceptionV3 13 | 14 | 15 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 16 | parser.add_argument('path', type=str, nargs=2, 17 | help=('Path to the generated images or ' 18 | 'to .npz statistic files')) 19 | parser.add_argument('--batch-size', type=int, default=64, 20 | help='Batch size to use') 21 | parser.add_argument('--dims', type=int, default=2048, 22 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 23 | help=('Dimensionality of Inception features to use. ' 24 | 'By default, uses pool3 features')) 25 | parser.add_argument('-c', '--gpu', default='', type=str, 26 | help='GPU to use (leave blank for CPU only)') 27 | 28 | 29 | def get_activations(images, model, batch_size=64, dims=2048, 30 | cuda=False, verbose=False): 31 | """Calculates the activations of the pool_3 layer for all images. 32 | 33 | Params: 34 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 35 | must lie between 0 and 1. 36 | -- model : Instance of inception model 37 | -- batch_size : the images numpy array is split into batches with 38 | batch size batch_size. A reasonable batch size depends 39 | on the hardware. 40 | -- dims : Dimensionality of features returned by Inception 41 | -- cuda : If set to True, use GPU 42 | -- verbose : If set to True and parameter out_step is given, the number 43 | of calculated batches is reported. 44 | Returns: 45 | -- A numpy array of dimension (num images, dims) that contains the 46 | activations of the given tensor when feeding inception with the 47 | query tensor. 48 | """ 49 | model.eval() 50 | 51 | d0 = images.shape[0] 52 | if batch_size > d0: 53 | print(('Warning: batch size is bigger than the data size. ' 54 | 'Setting batch size to data size')) 55 | batch_size = d0 56 | 57 | n_batches = d0 // batch_size 58 | n_used_imgs = n_batches * batch_size 59 | 60 | pred_arr = np.empty((n_used_imgs, dims)) 61 | for i in range(n_batches): 62 | if verbose: 63 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 64 | end='', flush=True) 65 | start = i * batch_size 66 | end = start + batch_size 67 | 68 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 69 | batch = Variable(batch, volatile=True) 70 | if cuda: 71 | batch = batch.cuda() 72 | 73 | pred = model(batch)[0] 74 | 75 | # If model output is not scalar, apply global spatial average pooling. 76 | # This happens if you choose a dimensionality not equal 2048. 77 | if pred.shape[2] != 1 or pred.shape[3] != 1: 78 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 79 | 80 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 81 | 82 | if verbose: 83 | print(' done') 84 | 85 | return pred_arr 86 | 87 | 88 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 89 | """Numpy implementation of the Frechet Distance. 
90 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 91 | and X_2 ~ N(mu_2, C_2) is 92 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 93 | 94 | Stable version by Dougal J. Sutherland. 95 | 96 | Params: 97 | -- mu1 : Numpy array containing the activations of a layer of the 98 | inception net (like returned by the function 'get_predictions') 99 | for generated samples. 100 | -- mu2 : The sample mean over activations, precalculated on an 101 | representative data set. 102 | -- sigma1: The covariance matrix over activations for generated samples. 103 | -- sigma2: The covariance matrix over activations, precalculated on an 104 | representative data set. 105 | 106 | Returns: 107 | -- : The Frechet Distance. 108 | """ 109 | 110 | mu1 = np.atleast_1d(mu1) 111 | mu2 = np.atleast_1d(mu2) 112 | 113 | sigma1 = np.atleast_2d(sigma1) 114 | sigma2 = np.atleast_2d(sigma2) 115 | 116 | assert mu1.shape == mu2.shape, \ 117 | 'Training and test mean vectors have different lengths' 118 | assert sigma1.shape == sigma2.shape, \ 119 | 'Training and test covariances have different dimensions' 120 | 121 | diff = mu1 - mu2 122 | 123 | # Product might be almost singular 124 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 125 | if not np.isfinite(covmean).all(): 126 | msg = ('fid calculation produces singular product; ' 127 | 'adding %s to diagonal of cov estimates') % eps 128 | print(msg) 129 | offset = np.eye(sigma1.shape[0]) * eps 130 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 131 | 132 | # Numerical error might give slight imaginary component 133 | if np.iscomplexobj(covmean): 134 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 135 | m = np.max(np.abs(covmean.imag)) 136 | raise ValueError('Imaginary component {}'.format(m)) 137 | covmean = covmean.real 138 | 139 | tr_covmean = np.trace(covmean) 140 | 141 | return (diff.dot(diff) + np.trace(sigma1) + 142 | np.trace(sigma2) - 2 * tr_covmean) 143 | 144 | 145 | def calculate_activation_statistics(images, model, batch_size=64, 146 | dims=2048, cuda=False, verbose=False): 147 | """Calculation of the statistics used by the FID. 148 | Params: 149 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 150 | must lie between 0 and 1. 151 | -- model : Instance of inception model 152 | -- batch_size : The images numpy array is split into batches with 153 | batch size batch_size. A reasonable batch size 154 | depends on the hardware. 155 | -- dims : Dimensionality of features returned by Inception 156 | -- cuda : If set to True, use GPU 157 | -- verbose : If set to True and parameter out_step is given, the 158 | number of calculated batches is reported. 159 | Returns: 160 | -- mu : The mean over samples of the activations of the pool_3 layer of 161 | the inception model. 162 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 163 | the inception model. 
164 | """ 165 | act = get_activations(images, model, batch_size, dims, cuda, verbose) 166 | mu = np.mean(act, axis=0) 167 | sigma = np.cov(act, rowvar=False) 168 | return mu, sigma 169 | 170 | 171 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 172 | if path.endswith('.npz'): 173 | f = np.load(path) 174 | m, s = f['mu'][:], f['sigma'][:] 175 | f.close() 176 | else: 177 | path = pathlib.Path(path) 178 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 179 | 180 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 181 | 182 | # Bring images to shape (B, 3, H, W) 183 | imgs = imgs.transpose((0, 3, 1, 2))[:, :3] 184 | 185 | # Rescale images to be between 0 and 1 186 | imgs /= 255 187 | 188 | m, s = calculate_activation_statistics(imgs, model, batch_size, 189 | dims, cuda) 190 | 191 | return m, s 192 | 193 | 194 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 195 | """Calculates the FID of two paths""" 196 | for p in paths: 197 | if not os.path.exists(p): 198 | raise RuntimeError('Invalid path: %s' % p) 199 | 200 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 201 | 202 | model = InceptionV3([block_idx]) 203 | if cuda: 204 | model.cuda() 205 | 206 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 207 | dims, cuda) 208 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 209 | dims, cuda) 210 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 211 | 212 | return fid_value 213 | 214 | 215 | if __name__ == '__main__': 216 | args = parser.parse_args() 217 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 218 | 219 | fid_value = calculate_fid_given_paths(args.path, 220 | args.batch_size, 221 | args.gpu != '', 222 | args.dims) 223 | print('FID: ', fid_value) 224 | -------------------------------------------------------------------------------- /mesh2tex/utils/FID/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. 
As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, normalizes the input to the statistics the pretrained 43 | Inception network expects 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | nn.MaxPool2d(kernel_size=3, stride=2) 68 | ] 69 | self.blocks.append(nn.Sequential(*block0)) 70 | 71 | # Block 1: maxpool1 to maxpool2 72 | if self.last_needed_block >= 1: 73 | block1 = [ 74 | inception.Conv2d_3b_1x1, 75 | inception.Conv2d_4a_3x3, 76 | nn.MaxPool2d(kernel_size=3, stride=2) 77 | ] 78 | self.blocks.append(nn.Sequential(*block1)) 79 | 80 | # Block 2: maxpool2 to aux classifier 81 | if self.last_needed_block >= 2: 82 | block2 = [ 83 | inception.Mixed_5b, 84 | inception.Mixed_5c, 85 | inception.Mixed_5d, 86 | inception.Mixed_6a, 87 | inception.Mixed_6b, 88 | inception.Mixed_6c, 89 | inception.Mixed_6d, 90 | inception.Mixed_6e, 91 | ] 92 | self.blocks.append(nn.Sequential(*block2)) 93 | 94 | # Block 3: aux classifier to final avgpool 95 | if self.last_needed_block >= 3: 96 | block3 = [ 97 | inception.Mixed_7a, 98 | inception.Mixed_7b, 99 | inception.Mixed_7c, 100 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 101 | ] 102 | self.blocks.append(nn.Sequential(*block3)) 103 | 104 | for param in self.parameters(): 105 | param.requires_grad = requires_grad 106 | 107 | def forward(self, inp): 108 | """Get Inception feature maps 109 | 110 | Parameters 111 | ---------- 112 | inp : torch.autograd.Variable 113 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 114 | range (0, 1) 115 | 116 | Returns 117 | ------- 118 | List of torch.autograd.Variable, corresponding to the selected output 119 | block, sorted ascending by index 120 | """ 121 | outp = [] 122 | x = inp 123 | 124 | if self.resize_input: 125 | x = F.upsample(x, size=(299, 299), mode='bilinear') 126 | 127 | if self.normalize_input: 128 | x = x.clone() 129 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 130 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 131 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 132 | 133 | for idx, block in enumerate(self.blocks): 134 | x = block(x) 135 | if idx in self.output_blocks: 136 | outp.append(x) 137 | 138 | if idx == self.last_needed_block: 139 | break 140 | 141 | return outp 142 | -------------------------------------------------------------------------------- /mesh2tex/utils/SSIM_L1/ssim_l1_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.measure import compare_ssim as ssim 3 | import imageio 4 | import os 5 | 6 | 7 | def calculate_ssim_l1_given_paths(paths): 8 | file_list = os.listdir(paths[0]) 9 | ssim_value = 0 10 | l1_value = 0 11 | for f in file_list: 12 | # assert(i[0] == i[1]) 13 | fake = load_img(paths[0] + f) 14 | real = load_img(paths[1] + f) 15 | ssim_value += np.mean( 16 | ssim(fake, real, multichannel=True)) 17 | l1_value += np.mean(abs(fake - real)) 18 | 19 | ssim_value = ssim_value/float(len(file_list)) 20 | l1_value = l1_value/float(len(file_list)) 21 | 22 | return ssim_value, l1_value 23 | 24 | 25 | def calculate_ssim_l1_given_tensor(images_fake, images_real): 26 | bs = images_fake.size(0) 27 | images_fake = images_fake.permute(0, 2, 3, 1).cpu().numpy() 28 | images_real = images_real.permute(0, 2, 3, 1).cpu().numpy() 29 | 30 | ssim_value = 0 31 | l1_value = 0 32 | for i in range(bs): 33 | # assert(i[0] == i[1]) 34 | fake = images_fake[i] 35 | real = images_real[i] 36 | ssim_value += np.mean( 37 | ssim(fake, real, multichannel=True)) 38 | l1_value += np.mean(abs(fake - real)) 39 | ssim_value = ssim_value/float(bs) 40 | l1_value = l1_value/float(bs) 41 | 42 | return ssim_value, l1_value 43 | 44 | 45 | def load_img(path): 46 | img = imageio.imread(path) 47 | img = img.astype(np.float64) / 255 48 | if img.ndim == 2: 49 | img = np.stack([img, img, img], axis=-1) 50 | elif img.shape[2] == 1: 51 | img = np.concatenate([img, img, img], axis=-1) 52 | elif img.shape[2] == 4: 53 | img = img[:, :, :3] 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /mesh2tex/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/mesh2tex/utils/__init__.py -------------------------------------------------------------------------------- /mesh2tex/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from plyfile import PlyElement, PlyData 3 | import numpy as np 4 | 5 | 6 | def export_pointcloud(vertices, out_file, as_text=True): 7 | assert(vertices.shape[1] == 3) 8 | vertices = vertices.astype(np.float32) 9 | vector_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 10 | vertices = vertices.view(dtype=vector_dtype).flatten() 11 | plyel = PlyElement.describe(vertices, 'vertex') 12 | plydata = PlyData([plyel], text=as_text) 13 | plydata.write(out_file) 14 | 15 | 16 
| def load_pointcloud(in_file):
17 |     plydata = PlyData.read(in_file)
18 |     vertices = np.stack([
19 |         plydata['vertex']['x'],
20 |         plydata['vertex']['y'],
21 |         plydata['vertex']['z']
22 |     ], axis=1)
23 |     return vertices
24 | 
25 | 
26 | def read_off(file):
27 |     """
28 |     Reads vertices and faces from an off file.
29 | 
30 |     :param file: path to file to read
31 |     :type file: str
32 |     :return: vertices and faces as lists of tuples
33 |     :rtype: [(float)], [(int)]
34 |     """
35 | 
36 |     assert os.path.exists(file), 'file %s not found' % file
37 | 
38 |     with open(file, 'r') as fp:
39 |         lines = fp.readlines()
40 |         lines = [line.strip() for line in lines]
41 | 
42 |         # Fix for ModelNet bug where 'OFF' and the number of vertices and faces
43 |         # are all in the first line.
44 |         if len(lines[0]) > 3:
45 |             assert lines[0][:3] == 'OFF' or lines[0][:3] == 'off', \
46 |                 'invalid OFF file %s' % file
47 | 
48 |             parts = lines[0][3:].split(' ')
49 |             assert len(parts) == 3
50 | 
51 |             num_vertices = int(parts[0])
52 |             assert num_vertices > 0
53 | 
54 |             num_faces = int(parts[1])
55 |             assert num_faces > 0
56 | 
57 |             start_index = 1
58 |         # This is the regular case!
59 |         else:
60 |             assert lines[0] == 'OFF' or lines[0] == 'off', \
61 |                 'invalid OFF file %s' % file
62 | 
63 |             parts = lines[1].split(' ')
64 |             assert len(parts) == 3
65 | 
66 |             num_vertices = int(parts[0])
67 |             assert num_vertices > 0
68 | 
69 |             num_faces = int(parts[1])
70 |             assert num_faces > 0
71 | 
72 |             start_index = 2
73 | 
74 |         vertices = []
75 |         for i in range(num_vertices):
76 |             vertex = lines[start_index + i].split(' ')
77 |             vertex = [float(point.strip()) for point in vertex if point != '']
78 |             assert len(vertex) == 3
79 | 
80 |             vertices.append(vertex)
81 | 
82 |         faces = []
83 |         for i in range(num_faces):
84 |             face = lines[start_index + num_vertices + i].split(' ')
85 |             face = [index.strip() for index in face if index != '']
86 | 
87 |             # check to be sure
88 |             for index in face:
89 |                 assert index != '', \
90 |                     'found empty vertex index: %s (%s)' \
91 |                     % (lines[start_index + num_vertices + i], file)
92 | 
93 |             face = [int(index) for index in face]
94 | 
95 |             assert face[0] == len(face) - 1, \
96 |                 'face should have %d vertices but has %d (%s)' \
97 |                 % (face[0], len(face) - 1, file)
98 |             assert face[0] == 3, \
99 |                 'only triangular meshes supported (%s)' % file
100 |             for index in face:
101 |                 assert index >= 0 and index < num_vertices, \
102 |                     'vertex %d (of %d vertices) does not exist (%s)' \
103 |                     % (index, num_vertices, file)
104 | 
105 |             assert len(face) > 1
106 | 
107 |             faces.append(face)
108 | 
109 |         return vertices, faces
110 | 
111 |     assert False, 'could not open %s' % file
-------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------
1 | # Texture Fields
2 |
3 | 4 | 5 |
6 | 
7 | This repository contains code for the paper 'Texture Fields: Learning Texture Representations in Function Space'.
8 | 
9 | You can find detailed usage instructions for training your own models and using pretrained models below.
10 | 
11 | If you find our code or paper useful, please consider citing
12 | 
13 |     @inproceedings{OechsleICCV2019,
14 |         title = {Texture Fields: Learning Texture Representations in Function Space},
15 |         author = {Oechsle, Michael and Mescheder, Lars and Niemeyer, Michael and Strauss, Thilo and Geiger, Andreas},
16 |         booktitle = {Proceedings IEEE International Conf. on Computer Vision (ICCV)},
17 |         year = {2019}
18 |     }
19 | 
20 | ## Installation
21 | The simplest way to run our implementation is to use [anaconda](https://www.anaconda.com/).
22 | 
23 | You can create an anaconda environment called `texturefields` with
24 | ```
25 | conda env create -f environment.yaml
26 | conda activate texturefields
27 | ```
28 | 
29 | ## Demo
30 | 
31 | If you just want to quickly test our method on the single view reconstruction task, you can run our demo with
32 | ```
33 | python generate.py configs/singleview/texfields/car_demo.yaml
34 | ```
35 | 
36 | The script uses a pre-trained model to reconstruct the texture of the car object provided in `demo/`. You can find the predicted images in `out/demo/fake/`.
37 | 
38 | ## Dataset
39 | To download the preprocessed data, run the following script.
40 | ```
41 | source ./scripts/download_data.sh
42 | ```
43 | We currently provide data only for the car category, with a download size of 33 GB.
44 | The dataset is downloaded and extracted to the `data/` folder. For each 3D object we provide 17 input views and 10 random views with corresponding depth maps and camera information. The train, test and validation splits are located in the main subfolder of the category. For visualization, we provide renderings from fixed views in the `visualize` subfolder.
45 | Data structure:
46 | ```
47 | data/shapenet/data_cars/{ModelID}/
48 |     input_image/
49 |     image/
50 |     depth/
51 |     visualize/
52 |         image/
53 |         depth/
54 |     pointcloud.npz
55 | ```
56 | 
57 | ## Usage
58 | You can use our implementation for training, generation and evaluation. For each mode there is a corresponding script that needs to be run.
59 | 
60 | ### Generation
61 | #### Single View Texture Reconstruction
62 | To test our method, you can generate novel views by running
63 | ```
64 | python generate.py CONFIG.yaml
65 | ```
66 | where `CONFIG.yaml` is the path to a config file.
67 | For generation, you can choose whether the views are rendered from random viewpoints or from fixed viewpoints on a circle around the object. To use random views, add the following to the config file:
68 | ```
69 | test:
70 |   dataset_split: 'test_eval'
71 | ```
72 | For fixed views, replace the option with `'test_vis'`; a minimal override config is sketched below.
73 | For the evaluation we use the random views.
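As a rough sketch only (the file name is hypothetical; the provided eval configs under `configs/singleview/texfields/` are the authoritative reference), such a fixed-view override could look like:
```
# hypothetical override config, e.g. my_car_eval_fix.yaml
inherit_from: configs/singleview/texfields/car.yaml
test:
  dataset_split: 'test_vis'
```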
74 | Examples of config files can be found in
75 | ```
76 | configs/singleview/texfields/car.yaml
77 | configs/singleview/texfields/car_eval_rnd.yaml
78 | ```
79 | 
80 | #### Generative Model
81 | You can run our generative models by executing
82 | ```
83 | python generate.py configs/VAE/000_eval_fix.yaml
84 | ```
85 | for the VAE and
86 | ```
87 | python generate.py configs/GAN/000_eval_fix.yaml
88 | ```
89 | for the GAN. The predicted novel textures are written to
90 | ```
91 | out/VAE/car/eval_fix/
92 | out/GAN/car/eval_fix/
93 | ```
94 | 
95 | ### Evaluation
96 | You can evaluate the performance of our method on the single view reconstruction task by running
97 | ```
98 | python generate.py configs/singleview/texfields/car_eval_rnd.yaml
99 | python evaluate.py configs/singleview/texfields/car_eval_rnd.yaml
100 | ```
101 | The scripts write the results into the respective output folder:
102 | ```
103 | out/singleview/car/eval_rnd/
104 | ```
105 | 
106 | ### Training
107 | To train a model from scratch, run
108 | ```
109 | python train.py CONFIG.yaml
110 | ```
111 | in the conda environment.
112 | Please leave the following option empty in the config file
113 | ```
114 | model:
115 |   model_url:
116 | ```
117 | so that no pre-trained model is loaded.
118 | The training process can be visualized with tensorboard. The logfiles are saved to the `logs` folder in the output directory.
119 | ```
120 | tensorboard --logdir ./out --port 6006
121 | ```
-------------------------------------------------------------------------------- /scripts/download_data.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd data
3 | mkdir shapenet
4 | cd shapenet
5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/data/data_cars.zip
6 | unzip data_cars.zip
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | """Base file for starting training
2 | """
3 | 
4 | import torch
5 | import argparse
6 | from mesh2tex import config
7 | import matplotlib
8 | 
9 | matplotlib.use('Agg')
10 | 
11 | 
12 | parser = argparse.ArgumentParser(
13 |     description='Train a Texture Field.'
14 | )
15 | parser.add_argument('config', type=str, help='Path to config file.')
16 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
17 | parser.add_argument('--exit-after', type=int, default=-1,
18 |                     help='Checkpoint and exit after specified '
19 |                          'number of seconds with exit code 2.')
20 | args = parser.parse_args()
21 | cfg = config.load_config(args.config, 'configs/default.yaml')
22 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
23 | device = torch.device("cuda" if is_cuda else "cpu")
24 | exit_after = args.exit_after
25 | 
26 | 
27 | models = config.get_models(cfg, device=device)
28 | optimizers = config.get_optimizers(models, cfg)
29 | 
30 | 
31 | train_loader = config.get_dataloader('train', cfg)
32 | val_loader = config.get_dataloader('val_eval', cfg)
33 | 
34 | if cfg['training']['vis_fixviews'] is True:
35 |     vis_loader = config.get_dataloader('val_vis', cfg)
36 | else:
37 |     vis_loader = None
38 | 
39 | 
40 | trainer = config.get_trainer(models, optimizers, cfg, device=device)
41 | 
42 | trainer.train(train_loader, val_loader, vis_loader,
43 |               exit_after=exit_after, n_epochs=None)
44 | 
--------------------------------------------------------------------------------
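The FID and SSIM/L1 utilities above (`mesh2tex/utils/FID/fid_score.py` and `mesh2tex/utils/SSIM_L1/ssim_l1_score.py`) can also be called directly from Python. The following is a minimal usage sketch only, not part of the repository: the output folder names are hypothetical, the import paths simply mirror the directory layout (assuming Python 3 resolves the `FID` and `SSIM_L1` folders as namespace packages), and `evaluate.py` remains the supported entry point.
```
# Usage sketch (assumptions: hypothetical folder names, namespace-package imports).
import torch

from mesh2tex.utils.FID.fid_score import calculate_fid_given_paths
from mesh2tex.utils.SSIM_L1.ssim_l1_score import calculate_ssim_l1_given_paths

fake_dir = 'out/singleview/car/eval_rnd/fake/'  # hypothetical output folders
real_dir = 'out/singleview/car/eval_rnd/real/'

# FID: both folders are scanned for *.jpg / *.png files.
fid = calculate_fid_given_paths([fake_dir, real_dir], batch_size=50,
                                cuda=torch.cuda.is_available(), dims=2048)

# SSIM / L1: both folders must contain identically named images; keep the
# trailing slashes, since the paths are concatenated with the file names.
ssim_val, l1_val = calculate_ssim_l1_given_paths([fake_dir, real_dir])

print('FID: %.2f, SSIM: %.4f, L1: %.4f' % (fid, ssim_val, l1_val))
```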