├── .gitignore ├── LICENSE.md ├── configs ├── GAN │ ├── 000.yaml │ └── 000_eval_fix.yaml ├── VAE │ ├── 000.yaml │ └── 000_eval_fix.yaml ├── default.yaml └── singleview │ ├── NVS │ └── car.yaml │ └── texfields │ ├── car.yaml │ ├── car_demo.yaml │ ├── car_eval_fix.yaml │ └── car_eval_rnd.yaml ├── demo ├── cc067578ad92517bbe25370c898e25a5 │ ├── input_image │ │ └── 000.jpg │ ├── pointcloud.npz │ └── visualize │ │ ├── depth │ │ ├── 000.exr │ │ ├── 001.exr │ │ ├── 002.exr │ │ ├── 003.exr │ │ ├── 004.exr │ │ └── cameras.npz │ │ └── image │ │ ├── 000.png │ │ ├── 001.png │ │ ├── 002.png │ │ ├── 003.png │ │ └── 004.png └── test.lst ├── environment.yaml ├── evaluate.py ├── generate.py ├── gfx └── clips │ ├── blackroof.gif │ └── header.png ├── mesh2tex ├── .gitignore ├── __init__.py ├── checkpoints.py ├── common.py ├── config.py ├── data │ ├── __init__.py │ ├── core.py │ ├── fields.py │ └── transforms.py ├── eval.py ├── geometry │ ├── __init__.py │ └── pointnet.py ├── layers.py ├── nvs │ ├── __init__.py │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── discriminator.py │ │ └── encoder.py │ └── training.py ├── texnet │ ├── __init__.py │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── discriminator.py │ │ ├── image_encoder.py │ │ └── vae_encoder.py │ └── training.py ├── training.py └── utils │ ├── FID │ ├── feature_l1.py │ ├── fid_score.py │ └── inception.py │ ├── SSIM_L1 │ └── ssim_l1_score.py │ ├── __init__.py │ └── io.py ├── readme.md ├── scripts └── download_data.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/synthetic_combined 2 | data/shapenet -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Michael Oechsle, Lars Mescheder, Michael Niemeyer, Thilo Strauss, Andreas Geiger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/GAN/000.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 128 6 | training: 7 | out_dir: 'out/GAN/car' 8 | batch_size: 32 9 | model_selection_metric: null 10 | model_selection_mode: maximize 11 | print_every: 10 12 | visualize_every: 100 13 | checkpoint_every: 1000 14 | validate_every: 1000 15 | backup_every: 10000 16 | moving_average_beta: 0.99 17 | gradient_penalties_reg: 10. 18 | lr_g: 0.0001 19 | lr_d: 0.0001 20 | multi_gpu: false 21 | vis_fixviews: True 22 | weight_pixelloss: 0. 23 | weight_ganloss: 1. 24 | weight_vaeloss: 0. 25 | experiment: 'generative' 26 | model: 27 | decoder: each_layer_c 28 | encoder: 29 | geometry_encoder: simple 30 | discriminator: resnet_conditional 31 | vae_encoder: 32 | encoder_kwargs: 33 | vae_encoder_kwargs: 34 | decoder_kwargs: 35 | leaky: True 36 | geometry_encoder_kwargs: 37 | leaky: True 38 | generator_bg_kwargs: 39 | leaky: True 40 | discriminator_kwargs: 41 | leaky: True 42 | z_dim: 64 43 | c_dim: 512 44 | white_bg: True 45 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/gan_car-360b7ce7.pt' 46 | generation: 47 | batch_size: 10 48 | test: 49 | model_file: model.pt 50 | vis_dir: 'out/GAN/car/eval_fix' 51 | dataset_split: 'test_vis' 52 | with_occnet: False 53 | 54 | -------------------------------------------------------------------------------- /configs/GAN/000_eval_fix.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/GAN/000.yaml 3 | data: 4 | with_shuffle: False 5 | training: 6 | vis_fixviews: True 7 | generation: 8 | batch_size: 1 9 | test: 10 | model_file: model.pt 11 | vis_dir: 'out/GAN/car/eval_fix' 12 | dataset_split: 'test_vis' 13 | with_occnet: False 14 | generation_mode: 'gan' -------------------------------------------------------------------------------- /configs/VAE/000.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 128 6 | training: 7 | out_dir: 'out/VAE/car' 8 | batch_size: 3 9 | model_selection_metric: null 10 | model_selection_mode: maximize 11 | print_every: 100 12 | visualize_every: 1000 13 | checkpoint_every: 1000 14 | validate_every: 1000 15 | backup_every: 1000000 16 | moving_average_beta: 0 17 | pc_subsampling: 2048 18 | vis_fixviews: True 19 | weight_pixelloss: 1. 20 | weight_ganloss: 0. 21 | weight_vaeloss: 10. 22 | experiment: 'generative' 23 | gradient_penalties_reg: 0. 
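  # Note: the adversarial term is disabled in this VAE configuration (weight_ganloss: 0.),
  # so the gradient penalty regularizer is set to zero as well.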
24 | model: 25 | decoder: each_layer_c 26 | encoder: 27 | vae_encoder: resnet 28 | geometry_encoder: simple 29 | decoder_kwargs: 30 | leaky: True 31 | resnet_leaky: True 32 | encoder_kwargs: {} 33 | vae_encoder_kwargs: 34 | leaky: True 35 | geometry_encoder_kwargs: 36 | leaky: True 37 | z_dim: 512 38 | c_dim: 512 39 | white_bg: True 40 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/vae_car-f141e128.pt' 41 | 42 | generation: 43 | batch_size: 3 44 | test: 45 | model_file: model.pt 46 | vis_dir: 'out/VAE/car/eval_vis' 47 | dataset_split: 'test_vis' 48 | with_occnet: False 49 | -------------------------------------------------------------------------------- /configs/VAE/000_eval_fix.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/VAE/000.yaml 3 | data: 4 | img_size: 128 5 | training: 6 | vis_fixviews: True 7 | generation: 8 | batch_size: 3 9 | test: 10 | model_file: model.pt 11 | vis_dir: 'out/VAE/car/eval_fix' 12 | dataset_split: 'test_vis' 13 | with_occnet: False 14 | generation_mode: 'vae' -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_imgs: null 4 | path_shapes: null 5 | dataset_imgs_type: image_folder 6 | shapes_multiclass: false 7 | classes_shapes: null 8 | classes_imgs: null 9 | img_size: 64 10 | pcl_knn: null 11 | with_shuffle: True 12 | training: 13 | out_dir: 'out' 14 | batch_size: 64 15 | model_selection_metric: none 16 | model_selection_mode: maximize 17 | print_every: 100 18 | visualize_every: 1000 19 | checkpoint_every: 1000 20 | validate_every: 10000 21 | backup_every: 10000 22 | moving_average_beta: null 23 | lr_g: 0.0001 24 | lr_d: 0.0001 25 | gradient_penalties_reg: 10. 26 | multi_gpu: false 27 | pc_subsampling: 2048 28 | vis_fixviews: 29 | weight_pixelloss: 0.0 30 | weight_ganloss: 0.0 31 | weight_vaeloss: 0.0 32 | model: 33 | decoder: simple 34 | geometry_encoder: simple 35 | generator_bg: resnet 36 | discriminator: resnet_conditional 37 | vae_encoder: 38 | decoder_kwargs: 39 | resnet_leaky: True 40 | geometry_encoder_kwargs: {} 41 | generator_bg_kwargs: {} 42 | discriminator_kwargs: {} 43 | vae_encoder_kwargs: {} 44 | z_dim: 128 45 | c_dim: 128 46 | white_bg: False 47 | gan_setting: conditional 48 | model_url: 49 | test: 50 | model_file: model_best.pt 51 | vis_dir: 52 | for_eval: False 53 | dataset_split: 'test' 54 | for_vis: False 55 | with_occnet: False 56 | interpol: False 57 | generate_grid: False 58 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/NVS/car.yaml: -------------------------------------------------------------------------------- 1 | method: nvs 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 256 6 | training: 7 | out_dir: 'out/nvs/car' 8 | batch_size: 32 9 | model_selection_metric: loss_val 10 | model_selection_mode: minimize 11 | print_every: 100 12 | visualize_every: 10000 13 | checkpoint_every: 10000 14 | validate_every: 10000 15 | backup_every: 100000 16 | moving_average_beta: 0 17 | pc_subsampling: 2048 18 | vis_fixviews: True 19 | gradient_penalties_reg: 0. 20 | weight_pixelloss: 1. 21 | weight_ganloss: 0. 
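  # The novel-view-synthesis baseline is trained with the pixel reconstruction loss only
  # (weight_pixelloss: 1., weight_ganloss: 0.).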
22 | experiment: 'conditional' 23 | model: 24 | decoder: each_layer_c 25 | encoder: resnet18 26 | geometry_encoder: null 27 | generator_bg: null 28 | discriminator: resnet 29 | decoder_kwargs: {} 30 | encoder_kwargs: {} 31 | geometry_encoder_kwargs: {} 32 | generator_bg_kwargs: {} 33 | discriminator_kwargs: {} 34 | z_dim: 0 35 | c_dim: 512 36 | white_bg: True 37 | generation: 38 | batch_size: 32 39 | test: 40 | model_file: model_best.pt 41 | vis_dir: 'out/nvs/car/eval_fix' 42 | dataset_split: 'test_vis' 43 | with_occnet: False 44 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | data: 3 | path_shapes: data/shapenet/synthetic_cars_nospecular/ 4 | dataset_imgs_type: image_folder 5 | img_size: 128 6 | training: 7 | out_dir: 'out/singleview/car' 8 | batch_size: 18 9 | model_selection_metric: loss_val 10 | model_selection_mode: minimize 11 | print_every: 100 12 | visualize_every: 1000 13 | checkpoint_every: 1000 14 | validate_every: 1000 15 | backup_every: 100000 16 | moving_average_beta: 0 17 | pc_subsampling: 2048 18 | vis_fixviews: True 19 | weight_pixelloss: 1. 20 | weight_ganloss: 0. 21 | experiment: 'conditional' 22 | gradient_penalties_reg: 0. 23 | model: 24 | decoder: each_layer_c_larger 25 | encoder: resnet18 26 | geometry_encoder: resnet 27 | decoder_kwargs: 28 | leaky: True 29 | resnet_leaky: False 30 | encoder_kwargs: {} 31 | geometry_encoder_kwargs: {} 32 | generator_bg_kwargs: {} 33 | discriminator_kwargs: {} 34 | z_dim: 512 35 | c_dim: 512 36 | white_bg: True 37 | model_url: 38 | generation: 39 | batch_size: 1 40 | test: 41 | model_file: model_best.pt 42 | vis_dir: 'out/singleview/car/eval_fix/' 43 | dataset_split: 'test_vis' 44 | with_occnet: False 45 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car_demo.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/singleview/texfields/car.yaml 3 | data: 4 | path_shapes: data/demo 5 | dataset_imgs_type: image_folder 6 | img_size: 256 7 | model: 8 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/car-b3b2a506.pt' 9 | generation: 10 | batch_size: 1 11 | test: 12 | model_file: model_best.pt 13 | vis_dir: 'out/demo' 14 | dataset_split: 'test_vis' 15 | with_occnet: False 16 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car_eval_fix.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: configs/singleview/texfields/car.yaml 3 | data: 4 | # path_shapes: data/synthetic_combined/02958343 5 | dataset_imgs_type: image_folder 6 | img_size: 256 7 | model: 8 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/car-b3b2a506.pt' 9 | generation: 10 | batch_size: 1 11 | test: 12 | vis_dir: 'out/singleview/car/eval_fix' 13 | dataset_split: 'test_vis' 14 | with_occnet: False 15 | generation_mode: 'HD' -------------------------------------------------------------------------------- /configs/singleview/texfields/car_eval_rnd.yaml: -------------------------------------------------------------------------------- 1 | method: texnet 2 | inherit_from: 
configs/singleview/texfields/car.yaml 3 | data: 4 | # path_shapes: data//02958343/ 5 | dataset_imgs_type: image_folder 6 | img_size: 256 7 | model: 8 | model_url: 'https://s3.eu-central-1.amazonaws.com/avg-projects/texture_fields/models/car-b3b2a506.pt' 9 | generation: 10 | batch_size: 100 11 | test: 12 | model_file: model_best.pt 13 | vis_dir: 'out/singleview/car/eval_rnd' 14 | dataset_split: 'test_eval' 15 | with_occnet: False 16 | generation_mode: 'HD' -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/input_image/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/input_image/000.jpg -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/pointcloud.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/pointcloud.npz -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/000.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/000.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/001.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/001.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/002.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/002.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/003.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/003.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/004.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/004.exr -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/cameras.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/depth/cameras.npz -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/000.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/001.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/002.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/003.png -------------------------------------------------------------------------------- /demo/cc067578ad92517bbe25370c898e25a5/visualize/image/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/demo/cc067578ad92517bbe25370c898e25a5/visualize/image/004.png -------------------------------------------------------------------------------- /demo/test.lst: -------------------------------------------------------------------------------- 1 | cc067578ad92517bbe25370c898e25a5 -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: texturefields 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - imageio=2.4.1 8 | - numpy=1.15.4 9 | - numpy-base=1.15.4 10 | - matplotlib=3.0.3 11 | - matplotlib-base=3.0.3 12 | - pandas=0.23.4 13 | - pillow=5.3.0 14 | - pyembree=0.1.4 15 | - pytest=4.0.2 16 | - python=3.6.7 17 | - pytorch=1.0.0 18 | - pyyaml=3.13 19 | - scikit-image=0.14.1 20 | - scipy=1.1.0 21 | - tensorboardx=1.4 22 | - torchvision=0.2.1 23 | - tqdm=4.28.1 24 | - trimesh=2.37.7 25 | - pip: 26 | - h5py==2.9.0 27 | - plyfile==0.7 28 | - lmdb -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import os 4 | import glob 5 | 6 | from mesh2tex import config 7 | from mesh2tex.eval import evaluate_generated_images 8 | 9 | categories = {'02958343': 'cars', '03001627': 'chairs', 10 | '02691156': 'airplanes', '04379243': 
'tables', 11 | '02828884': 'benches', '02933112': 'cabinets', 12 | '04256520': 'sofa', '03636649': 'lamps', 13 | '04530566': 'vessels'} 14 | 15 | parser = argparse.ArgumentParser( 16 | description='Generate Color for given mesh.' 17 | ) 18 | 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | 21 | args = parser.parse_args() 22 | cfg = config.load_config(args.config, 'configs/default.yaml') 23 | base_path = cfg['test']['vis_dir'] 24 | 25 | 26 | if cfg['data']['shapes_multiclass']: 27 | category_paths = glob.glob(os.path.join(base_path, '*')) 28 | else: 29 | category_paths = [base_path] 30 | 31 | for category_path in category_paths: 32 | cat_id = os.path.basename(category_path) 33 | category = categories.get(cat_id, cat_id) 34 | path1 = os.path.join(category_path, 'fake/') 35 | path2 = os.path.join(category_path, 'real/') 36 | print('Evaluating %s (%s)' % (category, category_path)) 37 | 38 | evaluation = evaluate_generated_images('all', path1, path2) 39 | name = base_path 40 | 41 | df = pd.DataFrame(evaluation, index=[category]) 42 | df.to_pickle(os.path.join(category_path, 'eval.pkl')) 43 | df.to_csv(os.path.join(category_path, 'eval.csv')) 44 | 45 | print('Evaluation finished') 46 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | from mesh2tex import data 6 | from mesh2tex import config 7 | from mesh2tex.checkpoints import CheckpointIO 8 | 9 | # Get arguments and Config 10 | parser = argparse.ArgumentParser( 11 | description='Generate Color for given mesh.') 12 | parser.add_argument('config', type=str, help='Path to config file.') 13 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 14 | args = parser.parse_args() 15 | cfg = config.load_config(args.config, 'configs/default.yaml') 16 | 17 | # Define device 18 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 19 | device = torch.device("cuda" if is_cuda else "cpu") 20 | 21 | # Read config 22 | out_dir = cfg['training']['out_dir'] 23 | vis_dir = cfg['test']['vis_dir'] 24 | split = cfg['test']['dataset_split'] 25 | if split != 'test_vis' and split != 'test_eval': 26 | print('Are you sure not using test data?') 27 | batch_size = cfg['generation']['batch_size'] 28 | gen_mode = cfg['test']['generation_mode'] 29 | model_url = cfg['model']['model_url'] 30 | 31 | # Dataset 32 | dataset = config.get_dataset(split, cfg, input_sampling=False) 33 | if cfg['data']['shapes_multiclass']: 34 | datasets = dataset.datasets_classes 35 | else: 36 | datasets = [dataset] 37 | 38 | # Load Model 39 | models = config.get_models(cfg, device=device, dataset=dataset) 40 | model_g = models['generator'] 41 | checkpoint_io = CheckpointIO(out_dir, model_g=model_g) 42 | if model_url is None: 43 | checkpoint_io.load(cfg['test']['model_file']) 44 | else: 45 | checkpoint_io.load(cfg['model']['model_url']) 46 | 47 | # Assign Generator 48 | generator = config.get_generator(model_g, cfg, device) 49 | 50 | # data iteration loop 51 | for i_ds, ds in enumerate(datasets): 52 | ds_id = ds.metadata.get('id', str(i_ds)) 53 | ds_name = ds.metadata.get('name', 'n/a') 54 | 55 | if cfg['data']['shapes_multiclass']: 56 | out_dir = os.path.join(vis_dir, ds_id) 57 | else: 58 | out_dir = vis_dir 59 | 60 | test_loader = torch.utils.data.DataLoader( 61 | ds, batch_size=batch_size, num_workers=12, shuffle=False, 62 | 
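        # collate_remove_none (mesh2tex/data/core.py) drops samples whose fields
        # failed to load (i.e. returned None) before applying the default collation.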
collate_fn=data.collate_remove_none) 63 | 64 | batch_counter = 0 65 | 66 | def get_batch_size(batch): 67 | batch_size = next(iter(batch.values())).shape[0] 68 | return batch_size 69 | 70 | for batch in tqdm(test_loader): 71 | offset_batch = batch_size * batch_counter 72 | 73 | model_names = [ 74 | ds.get_model(i) for i in batch['idx'] 75 | ] 76 | 77 | if gen_mode == 'interpolate': 78 | out = generator.generate_images_4eval_vae_interpol(batch, 79 | out_dir, 80 | model_names) 81 | elif gen_mode == 'vae': 82 | out = generator.generate_images_4eval_vae(batch, 83 | out_dir, 84 | model_names) 85 | elif gen_mode == 'gan': 86 | out = generator.generate_images_4eval_gan(batch, 87 | out_dir, 88 | model_names) 89 | elif gen_mode == 'interpolate_rotation': 90 | out = generator.generate_images_4eval_vae_inter_rot(batch, 91 | out_dir, 92 | model_names) 93 | elif gen_mode == 'HD': 94 | generator.generate_images_4eval_condi_hd(batch, 95 | out_dir, 96 | model_names) 97 | 98 | elif gen_mode == 'SD': 99 | generator.generate_images_4eval_condi(batch, 100 | out_dir, 101 | model_names) 102 | 103 | elif gen_mode == 'grid': 104 | out = generator.generate_grid(batch, 105 | out_dir, 106 | model_names) 107 | 108 | elif gen_mode == 'test': 109 | out = generator.generate_images_occnet(batch, 110 | out_dir, 111 | model_names) 112 | else: 113 | print('Modes: HD, grid, interpolate, interpolate_rotation, test') 114 | 115 | batch_counter += 1 116 | -------------------------------------------------------------------------------- /gfx/clips/blackroof.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/gfx/clips/blackroof.gif -------------------------------------------------------------------------------- /gfx/clips/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/gfx/clips/header.png -------------------------------------------------------------------------------- /mesh2tex/.gitignore: -------------------------------------------------------------------------------- 1 | .pyc 2 | *.pyc -------------------------------------------------------------------------------- /mesh2tex/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex import geometry 2 | 3 | 4 | __all__ = [ 5 | geometry 6 | ] -------------------------------------------------------------------------------- /mesh2tex/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | 6 | 7 | class CheckpointIO(object): 8 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 9 | self.module_dict = kwargs 10 | self.checkpoint_dir = checkpoint_dir 11 | 12 | if not os.path.exists(checkpoint_dir): 13 | os.makedirs(checkpoint_dir) 14 | 15 | def register_modules(self, **kwargs): 16 | self.module_dict.update(kwargs) 17 | 18 | def save(self, filename, **kwargs): 19 | filename = os.path.join(self.checkpoint_dir, filename) 20 | 21 | outdict = kwargs 22 | for k, v in self.module_dict.items(): 23 | outdict[k] = v.state_dict() 24 | torch.save(outdict, filename) 25 | 26 | def load(self, filename): 27 | '''Loads a module dictionary from local file or url. 
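        A URL is fetched through torch's model zoo cache (load_url); a plain
        filename is resolved relative to the checkpoint directory (load_file).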
28 | 29 | Args: 30 | filename (str): name of saved module dictionary 31 | ''' 32 | print(filename) 33 | if is_url(filename): 34 | return self.load_url(filename) 35 | else: 36 | return self.load_file(filename) 37 | 38 | def load_url(self, url): 39 | '''Load a module dictionary from url. 40 | 41 | Args: 42 | url (str): url to saved model 43 | ''' 44 | print(url) 45 | print('=> Loading checkpoint from url...') 46 | out_dict = model_zoo.load_url(url, progress=True) 47 | # scalars = self.parse_state_dict(state_dict) 48 | for k, v in self.module_dict.items(): 49 | print("Start loading: %s" % k) 50 | if k in out_dict: 51 | # print(out_dict[k]) 52 | v.load_state_dict(out_dict[k]) 53 | print("Finished: %s" % k) 54 | else: 55 | print('Warning: Could not find %s in checkpoint!' % k) 56 | scalars = {k: v for k, v in out_dict.items() 57 | if k not in self.module_dict} 58 | return scalars 59 | 60 | def load_file(self, filename): 61 | filename = os.path.join(self.checkpoint_dir, filename) 62 | 63 | if os.path.exists(filename): 64 | 65 | print('=> Loading checkpoint...') 66 | out_dict = torch.load(filename) 67 | for k, v in self.module_dict.items(): 68 | print("Start loading: %s" % k) 69 | if k in out_dict: 70 | # print(out_dict[k]) 71 | v.load_state_dict(out_dict[k]) 72 | print("Finished: %s" % k) 73 | else: 74 | print('Warning: Could not find %s in checkpoint!' % k) 75 | scalars = {k: v for k, v in out_dict.items() 76 | if k not in self.module_dict} 77 | return scalars 78 | else: 79 | raise FileExistsError 80 | 81 | def is_url(url): 82 | scheme = urllib.parse.urlparse(url).scheme 83 | return scheme in ('http', 'https') 84 | 85 | -------------------------------------------------------------------------------- /mesh2tex/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial import cKDTree as KDTree 3 | 4 | 5 | def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1): 6 | indices = [] 7 | distances = [] 8 | 9 | for (p1, p2) in zip(points_src, points_tgt): 10 | p1 = p1.detach().cpu().numpy().T 11 | p2 = p2.detach().cpu().numpy().T 12 | 13 | kdtree = KDTree(p2) 14 | dist, idx = kdtree.query(p1, k=k, n_jobs=-1) 15 | indices.append(idx) 16 | distances.append(dist) 17 | 18 | indices = torch.LongTensor(indices) 19 | distances = torch.FloatTensor(distances) 20 | 21 | return indices, distances 22 | 23 | 24 | def normalize_imagenet(x): 25 | x = x.clone() 26 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 27 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 28 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 29 | return x 30 | -------------------------------------------------------------------------------- /mesh2tex/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from mesh2tex import texnet 3 | from mesh2tex import nvs 4 | method_dict = { 5 | 'texnet': texnet, 6 | 'nvs': nvs, 7 | } 8 | 9 | 10 | # General config 11 | def load_config(path, default_path=None): 12 | # Load configuration from file itself 13 | with open(path, 'r') as f: 14 | cfg_special = yaml.load(f) 15 | 16 | # Check if we should inherit from a config 17 | inherit_from = cfg_special.get('inherit_from') 18 | 19 | # If yes, load this config first as default 20 | # If no, use the default_path 21 | if inherit_from is not None: 22 | cfg = load_config(inherit_from, default_path) 23 | elif default_path is not None: 24 | with open(default_path, 'r') as f: 25 | cfg = yaml.load(f) 26 | else: 27 | cfg = dict() 28 | 29 | # Include main 
configuration 30 | update_recursive(cfg, cfg_special) 31 | 32 | return cfg 33 | 34 | 35 | def update_recursive(dict1, dict2): 36 | for k, v in dict2.items(): 37 | if k not in dict1: 38 | dict1[k] = dict() 39 | if isinstance(v, dict): 40 | update_recursive(dict1[k], v) 41 | else: 42 | dict1[k] = v 43 | 44 | 45 | # Individual configs 46 | def get_models(cfg, dataset=None, device=None): 47 | method = cfg['method'] 48 | models = method_dict[method].config.get_models(cfg, 49 | dataset=dataset, 50 | device=device) 51 | return models 52 | 53 | 54 | def get_optimizers(models, cfg): 55 | method = cfg['method'] 56 | optimizers = method_dict[method].config.get_optimizers(models, cfg) 57 | return optimizers 58 | 59 | 60 | def get_dataset(split, cfg, input_sampling=True): 61 | method = cfg['method'] 62 | dataset = method_dict[method].config.get_dataset(split, cfg, 63 | input_sampling) 64 | return dataset 65 | 66 | 67 | def get_dataloader(split, cfg): 68 | method = cfg['method'] 69 | dataloader = method_dict[method].config.get_dataloader(split, cfg) 70 | return dataloader 71 | 72 | 73 | def get_meshloader(split, cfg): 74 | method = cfg['method'] 75 | dataloader = method_dict[method].config.get_meshloader(split, cfg) 76 | return dataloader 77 | 78 | 79 | def get_generator(model, cfg, device): 80 | method = cfg['method'] 81 | generator = method_dict[method].config.get_generator(model, cfg, device) 82 | return generator 83 | 84 | 85 | def get_trainer(models, optimizers, cfg, device=None): 86 | method = cfg['method'] 87 | print("method: " + method) 88 | trainer = method_dict[method].config.get_trainer( 89 | models, optimizers, cfg, device=device) 90 | return trainer 91 | -------------------------------------------------------------------------------- /mesh2tex/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from mesh2tex.data.core import ( 3 | Shapes3dDataset, Shapes3dClassDataset, 4 | CombinedDataset, 5 | collate_remove_none, worker_init_fn 6 | ) 7 | from mesh2tex.data.fields import ( 8 | ImagesField, PointCloudField, 9 | DepthImageField, MeshField, 10 | DepthImageVisualizeField, IndexField, 11 | ) 12 | 13 | from mesh2tex.data.transforms import ( 14 | PointcloudNoise, SubsamplePointcloud, 15 | ComputeKNNPointcloud, 16 | ImageToGrayscale, ResizeImage, 17 | ImageToDepthValue 18 | ) 19 | 20 | 21 | __all__ = [ 22 | # Core 23 | Shapes3dDataset, 24 | Shapes3dClassDataset, 25 | CombinedDataset, 26 | collate_remove_none, 27 | worker_init_fn, 28 | # Fields 29 | ImagesField, 30 | PointCloudField, 31 | DepthImageField, 32 | MeshField, 33 | DepthImageVisualizeField, 34 | IndexField, 35 | # Transforms 36 | PointcloudNoise, 37 | SubsamplePointcloud, 38 | ComputeKNNPointcloud, 39 | ImageToGrayscale, 40 | ImageToDepthValue, 41 | ResizeImage, 42 | ] 43 | -------------------------------------------------------------------------------- /mesh2tex/data/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from torch.utils import data 4 | import numpy as np 5 | import yaml 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | # Fields 12 | class Field(object): 13 | def load(self, data_path, idx): 14 | raise NotImplementedError 15 | 16 | def check_complete(self, files): 17 | raise NotImplementedError 18 | 19 | 20 | class Shapes3dDataset(data.Dataset): 21 | def __init__(self, dataset_folder, fields, split=None, 22 | classes=None, no_except=True, transform=None): 23 | # Read metadata file 24 
| metadata_file = os.path.join(dataset_folder, 'metadata.yaml') 25 | if os.path.exists(metadata_file): 26 | with open(metadata_file, 'r') as f: 27 | metadata = yaml.load(f) 28 | else: 29 | metadata = {} 30 | 31 | # If classes is None, use all subfolders 32 | if classes is None: 33 | classes = os.listdir(dataset_folder) 34 | classes = [c for c in classes 35 | if os.path.isdir(os.path.join(dataset_folder, c))] 36 | 37 | # Get all sub-datasets 38 | self.datasets_classes = [] 39 | for c in classes: 40 | subpath = os.path.join(dataset_folder, c) 41 | if not os.path.isdir(subpath): 42 | logger.warning('Class %s does not exist in dataset.' % c) 43 | 44 | metadata_c = metadata.get(c, {'id': c, 'name': 'n/a'}) 45 | dataset = Shapes3dClassDataset(subpath, fields, split, 46 | metadata_c, no_except, 47 | transform=transform) 48 | self.datasets_classes.append(dataset) 49 | 50 | self._concat_dataset = data.ConcatDataset(self.datasets_classes) 51 | 52 | def __len__(self): 53 | return len(self._concat_dataset) 54 | 55 | def __getitem__(self, idx): 56 | return self._concat_dataset[idx] 57 | 58 | 59 | class Shapes3dClassDataset(data.Dataset): 60 | def __init__(self, dataset_folder, fields, split=None, 61 | metadata=dict(), no_except=True, transform=None): 62 | self.dataset_folder = dataset_folder 63 | self.fields = fields 64 | self.metadata = metadata 65 | self.no_except = no_except 66 | self.transform = transform 67 | # Get (filtered) model list 68 | if split is None: 69 | models = [ 70 | f for f in os.listdir(dataset_folder) 71 | if os.path.isdir(os.path.join(dataset_folder, f)) 72 | ] 73 | else: 74 | split_file = os.path.join(dataset_folder, split + '.lst') 75 | with open(split_file, 'r') as f: 76 | models = f.read().split('\n') 77 | 78 | # self.models = list(filter(self.test_model_complete, models)) 79 | self.models = models 80 | 81 | def test_model_complete(self, model): 82 | model_path = os.path.join(self.dataset_folder, model) 83 | files = os.listdir(model_path) 84 | for field_name, field in self.fields.items(): 85 | if not field.check_complete(files): 86 | logger.warn('Field "%s" is incomplete: %s' 87 | % (field_name, model_path)) 88 | return False 89 | else: 90 | return True 91 | 92 | def __len__(self): 93 | return len(self.models) 94 | 95 | def __getitem__(self, idx): 96 | model = self.models[idx] 97 | model_path = os.path.join(self.dataset_folder, model) 98 | data = {} 99 | for field_name, field in self.fields.items(): 100 | try: 101 | field_data = field.load(model_path, idx) 102 | except Exception: 103 | if self.no_except: 104 | logger.warn( 105 | 'Error occured when loading field %s of model %s' 106 | % (field_name, model) 107 | ) 108 | return None 109 | else: 110 | raise 111 | if isinstance(field_data, dict): 112 | 113 | for k, v in field_data.items(): 114 | if k is None: 115 | data[field_name] = v 116 | else: 117 | data['%s.%s' % (field_name, k)] = v 118 | else: 119 | data[field_name] = field_data 120 | if self.transform is not None: 121 | data = self.transform(data) 122 | return data 123 | 124 | def get_model(self, idx): 125 | return self.models[idx] 126 | 127 | 128 | class CombinedDataset(data.Dataset): 129 | def __init__(self, datasets, idx_main=0): 130 | self.datasets = datasets 131 | self.idx_main = idx_main 132 | 133 | def __len__(self): 134 | return len(self.datasets[self.idx_main]) 135 | 136 | def __getitem__(self, idx): 137 | out = [] 138 | for it, ds in enumerate(self.datasets): 139 | if it != self.idx_main: 140 | x_idx = np.random.randint(0, len(ds)) 141 | else: 142 | x_idx = idx 
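            # Only the main dataset is indexed deterministically; every other
            # dataset contributes a uniformly sampled random item.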
143 | out.append(ds[x_idx]) 144 | return out 145 | 146 | 147 | # Collater 148 | def collate_remove_none(batch): 149 | "Puts each data field into a tensor with outer dimension batch size" 150 | batch = list(filter(check_element_valid, batch)) 151 | return data.dataloader.default_collate(batch) 152 | 153 | 154 | def check_element_valid(batch): 155 | if batch is None: 156 | return False 157 | elif isinstance(batch, list): 158 | for b in batch: 159 | if not check_element_valid(b): 160 | return False 161 | elif isinstance(batch, dict): 162 | for b in batch.values(): 163 | if not check_element_valid(b): 164 | return False 165 | return True 166 | 167 | 168 | # Worker initialization to ensure true randomeness 169 | def worker_init_fn(worker_id): 170 | random_data = os.urandom(4) 171 | base_seed = int.from_bytes(random_data, byteorder="big") 172 | np.random.seed(base_seed + worker_id) 173 | -------------------------------------------------------------------------------- /mesh2tex/data/fields.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import numpy as np 5 | import trimesh 6 | import imageio 7 | from mesh2tex.data.core import Field 8 | 9 | 10 | # Make sure loading xlr works 11 | imageio.plugins.freeimage.download() 12 | 13 | 14 | # Basic index field 15 | class IndexField(Field): 16 | def load(self, model_path, idx): 17 | return idx 18 | 19 | def check_complete(self, files): 20 | return True 21 | 22 | 23 | class MeshField(Field): 24 | def __init__(self, folder_name, transform=None): 25 | self.folder_name = folder_name 26 | self.transform = transform 27 | 28 | def load(self, model_path, idx): 29 | folder_path = os.path.join(model_path, self.folder_name) 30 | file_path = os.path.join(folder_path, 'model.off') 31 | mesh = trimesh.load(file_path, process=False) 32 | if self.transform is not None: 33 | mesh = self.transform(mesh) 34 | 35 | data = { 36 | 'vertices': np.array(mesh.vertices), 37 | 'faces': np.array(mesh.faces), 38 | } 39 | 40 | return data 41 | 42 | def check_complete(self, files): 43 | complete = (self.folder_name in files) 44 | return complete 45 | 46 | # Image field 47 | class ImagesField(Field): 48 | def __init__(self, folder_name, transform=None, 49 | extension='jpg', random_view=True, 50 | with_camera=False, 51 | imageio_kwargs=dict()): 52 | self.folder_name = folder_name 53 | self.transform = transform 54 | self.extension = extension 55 | self.random_view = random_view 56 | self.with_camera = with_camera 57 | self.imageio_kwargs = dict() 58 | 59 | def load(self, model_path, idx): 60 | folder = os.path.join(model_path, self.folder_name) 61 | files = glob.glob(os.path.join(folder, '*.%s' % self.extension)) 62 | files.sort() 63 | if self.random_view: 64 | idx_img = random.randint(0, len(files)-1) 65 | else: 66 | idx_img = 0 67 | filename = files[idx_img] 68 | 69 | image = imageio.imread(filename, **self.imageio_kwargs) 70 | image = np.asarray(image) 71 | 72 | if len(image.shape) == 2: 73 | image = image.reshape(image.shape[0], image.shape[1], 1) 74 | image = np.concatenate([image, image, image], axis=2) 75 | 76 | if image.shape[2] == 4: 77 | image = image[:, :, :3] 78 | 79 | if image.dtype == np.uint8: 80 | image = image.astype(np.float32) / 255 81 | else: 82 | image = image.astype(np.float32) 83 | 84 | if self.transform is not None: 85 | image = self.transform(image) 86 | image = image.transpose(2, 0, 1) 87 | data = { 88 | None: image 89 | } 90 | 91 | if self.with_camera: 92 | camera_file = 
os.path.join(folder, 'cameras.npz') 93 | camera_dict = np.load(camera_file) 94 | Rt = camera_dict['world_mat_%d' % idx_img].astype(np.float32) 95 | K = camera_dict['camera_mat_%d' % idx_img].astype(np.float32) 96 | data['world_mat'] = Rt 97 | data['camera_mat'] = K 98 | 99 | return data 100 | 101 | def check_complete(self, files): 102 | complete = (self.folder_name in files) 103 | # TODO: check camera 104 | return complete 105 | 106 | 107 | # 3D Fields 108 | class PointCloudField(Field): 109 | def __init__(self, file_name, transform=None, with_transforms=False): 110 | self.file_name = file_name 111 | self.transform = transform 112 | self.with_transforms = with_transforms 113 | 114 | def load(self, model_path, idx): 115 | file_path = os.path.join(model_path, self.file_name) 116 | 117 | pointcloud_dict = np.load(file_path) 118 | 119 | points = pointcloud_dict['points'].astype(np.float32) 120 | normals = pointcloud_dict['normals'].astype(np.float32) 121 | 122 | data = { 123 | None: points.T, 124 | 'normals': normals.T, 125 | } 126 | 127 | if self.with_transforms: 128 | data['loc'] = pointcloud_dict['loc'].astype(np.float32) 129 | data['scale'] = pointcloud_dict['scale'].astype(np.float32) 130 | 131 | if self.transform is not None: 132 | data = self.transform(data) 133 | 134 | return data 135 | 136 | def check_complete(self, files): 137 | complete = (self.file_name in files) 138 | return complete 139 | 140 | 141 | class DepthImageVisualizeField(Field): 142 | def __init__(self, folder_name_img, folder_name_depth, transform_img=None, transform_depth=None, 143 | extension_img='jpg', extension_depth='exr', random_view=True, 144 | with_camera=False, 145 | imageio_kwargs=dict()): 146 | self.folder_name_img = folder_name_img 147 | self.folder_name_depth = folder_name_depth 148 | self.transform_depth = transform_depth 149 | self.transform_img = transform_img 150 | self.extension_img = extension_img 151 | self.extension_depth = extension_depth 152 | self.random_view = random_view 153 | self.with_camera = with_camera 154 | self.imageio_kwargs = dict() 155 | 156 | def load(self, model_path, idx): 157 | folder_img = os.path.join(model_path, self.folder_name_img) 158 | files_img = glob.glob(os.path.join(folder_img, '*.%s' % self.extension_img)) 159 | files_img.sort() 160 | folder_depth = os.path.join(model_path, self.folder_name_depth) 161 | files_depth = glob.glob(os.path.join(folder_depth, '*.%s' % self.extension_depth)) 162 | files_depth.sort() 163 | if self.random_view: 164 | idx_img = random.randint(0, len(files_img)-1) 165 | else: 166 | idx_img = 0 167 | 168 | image_all = [] 169 | depth_all = [] 170 | Rt = [] 171 | K = [] 172 | camera_file = os.path.join(folder_depth, 'cameras.npz') 173 | camera_dict = np.load(camera_file) 174 | 175 | for i in range(len(files_img)): 176 | filename_img = files_img[i] 177 | filename_depth = files_depth[i] 178 | 179 | image = imageio.imread(filename_img, **self.imageio_kwargs) 180 | image = np.asarray(image) 181 | 182 | if image.shape[2] == 4: 183 | image = image[:,:,:3] 184 | 185 | depth = imageio.imread(filename_depth, **self.imageio_kwargs) 186 | 187 | depth = np.asarray(depth) 188 | 189 | if image.dtype == np.uint8: 190 | image = image.astype(np.float32) / 255 191 | else: 192 | image = image.astype(np.float32) 193 | 194 | if self.transform_img is not None: 195 | image = self.transform_img(image) 196 | 197 | if self.transform_depth is not None: 198 | depth = self.transform_depth(depth) 199 | 200 | image = image.transpose(2, 0, 1) 201 | depth = depth.transpose(2, 0, 
1) 202 | image_all.append(image) 203 | depth_all.append(depth) 204 | 205 | camera_file = os.path.join(folder_depth, 'cameras.npz') 206 | camera_dict = np.load(camera_file) 207 | Rt.append(camera_dict['world_mat_%d' % i].astype(np.float32)) 208 | K.append(camera_dict['camera_mat_%d' % i].astype(np.float32)) 209 | 210 | data = { 211 | 'img': np.stack(image_all), 212 | 'depth': np.stack(depth_all) 213 | } 214 | 215 | data['world_mat'] = np.stack(Rt) 216 | data['camera_mat'] = np.stack(K) 217 | 218 | return data 219 | 220 | 221 | # Image field 222 | class DepthImageField(Field): 223 | def __init__(self, folder_name_img, folder_name_depth, transform_img=None, transform_depth=None, 224 | extension_img='jpg', extension_depth='exr', random_view=True, 225 | with_camera=False, 226 | imageio_kwargs=dict()): 227 | self.folder_name_img = folder_name_img 228 | self.folder_name_depth = folder_name_depth 229 | self.transform_depth = transform_depth 230 | self.transform_img = transform_img 231 | self.extension_img = extension_img 232 | self.extension_depth = extension_depth 233 | self.random_view = random_view 234 | self.with_camera = with_camera 235 | self.imageio_kwargs = dict() 236 | 237 | def load(self, model_path, idx): 238 | folder_img = os.path.join(model_path, self.folder_name_img) 239 | files_img = glob.glob(os.path.join(folder_img, '*.%s' % self.extension_img)) 240 | files_img.sort() 241 | folder_depth = os.path.join(model_path, self.folder_name_depth) 242 | files_depth = glob.glob(os.path.join(folder_depth, '*.%s' % self.extension_depth)) 243 | files_depth.sort() 244 | if self.random_view: 245 | idx_img = random.randint(0, len(files_img)-1) 246 | else: 247 | idx_img = 0 248 | 249 | filename_img = files_img[idx_img] 250 | filename_depth = files_depth[idx_img] 251 | 252 | image = imageio.imread(filename_img, **self.imageio_kwargs) 253 | image = np.asarray(image) 254 | if image.shape[2] == 4: 255 | image = image[:,:,:3] 256 | 257 | depth = imageio.imread(filename_depth, **self.imageio_kwargs) 258 | 259 | depth = np.asarray(depth) 260 | 261 | if image.dtype == np.uint8: 262 | image = image.astype(np.float32) / 255 263 | else: 264 | image = image.astype(np.float32) 265 | 266 | if self.transform_img is not None: 267 | image = self.transform_img(image) 268 | 269 | if self.transform_depth is not None: 270 | depth = self.transform_depth(depth) 271 | 272 | image = image.transpose(2, 0, 1) 273 | #TODO adapt depth transpose 274 | depth = depth.transpose(2, 0, 1) 275 | 276 | data = { 277 | 'img': image, 278 | 'depth': depth 279 | } 280 | 281 | if self.with_camera: 282 | camera_file = os.path.join(folder_depth, 'cameras.npz') 283 | camera_dict = np.load(camera_file) 284 | Rt = camera_dict['world_mat_%d' % idx_img].astype(np.float32) 285 | K = camera_dict['camera_mat_%d' % idx_img].astype(np.float32) 286 | data['world_mat'] = Rt 287 | data['camera_mat'] = K 288 | 289 | return data 290 | 291 | def check_complete(self, files): 292 | complete = (self.folder_name_img in files) 293 | # TODO: check camera 294 | return complete -------------------------------------------------------------------------------- /mesh2tex/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import resize 3 | from scipy.spatial import cKDTree as KDTree 4 | 5 | 6 | class PointcloudNoise(object): 7 | def __init__(self, stddev): 8 | self.stddev = stddev 9 | 10 | def __call__(self, data): 11 | data_out = data.copy() 12 | points = data[None] 13 | 
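        # Perturb the point cloud with i.i.d. Gaussian noise of the configured
        # standard deviation; the copy leaves the input dict unmodified.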
noise = self.stddev * np.random.randn(*points.shape) 14 | noise = noise.astype(np.float32) 15 | data_out[None] = points + noise 16 | return data_out 17 | 18 | 19 | class SubsamplePointcloud(object): 20 | def __init__(self, N): 21 | self.N = N 22 | 23 | def __call__(self, data): 24 | data_out = data.copy() 25 | points = data[None] 26 | normals = data['normals'] 27 | 28 | indices = np.random.randint(points.shape[1], size=self.N) 29 | data_out[None] = points[:, indices] 30 | data_out['normals'] = normals[:, indices] 31 | 32 | return data_out 33 | 34 | 35 | class ComputeKNNPointcloud(object): 36 | def __init__(self, K): 37 | self.K = K 38 | 39 | def __call__(self, data): 40 | data_out = data.copy() 41 | points = data[None] 42 | kdtree = KDTree(points.T) 43 | knn_idx = kdtree.query(points.T, k=self.K)[1] 44 | knn_idx = knn_idx.T 45 | 46 | data_out['knn_idx'] = knn_idx 47 | 48 | return data_out 49 | 50 | 51 | class ImageToGrayscale(object): 52 | def __call__(self, img): 53 | r, g, b = img[..., 0:1], img[..., 1:2], img[..., 2:3] 54 | out = 0.2990 * r + 0.5870 * g + 0.1140 * b 55 | return out 56 | 57 | 58 | class ImageToDepthValue(object): 59 | def __call__(self, img): 60 | return img[..., :1] 61 | 62 | 63 | class ResizeImage(object): 64 | def __init__(self, size, order=1): 65 | self.size = size 66 | self.order = order 67 | 68 | def __call__(self, img): 69 | img_out = resize(img, self.size, order=self.order, 70 | clip=False, mode='constant', 71 | anti_aliasing=False) 72 | img_out = img_out.astype(img.dtype) 73 | return img_out 74 | -------------------------------------------------------------------------------- /mesh2tex/eval.py: -------------------------------------------------------------------------------- 1 | import mesh2tex.utils.FID.fid_score as FID 2 | import mesh2tex.utils.FID.feature_l1 as feature_l1 3 | import mesh2tex.utils.SSIM_L1.ssim_l1_score as SSIM 4 | 5 | 6 | def evaluate_generated_images(metric, path_fake, path_real, batch_size=64): 7 | """ 8 | Start requested evaluation functions 9 | 10 | args: 11 | metric 12 | path_fake (Path to fake images) 13 | path_real (Path to real images) 14 | batch_size 15 | 16 | return: 17 | val_dict: dict with all metrics 18 | """ 19 | if metric == 'FID': 20 | paths = (path_fake, path_real) 21 | value = FID.calculate_fid_given_paths(paths, 22 | batch_size, 23 | True, 24 | 2048) 25 | val_dict = {'FID': value} 26 | 27 | elif metric == 'SSIM_L1': 28 | paths = (path_fake, path_real) 29 | value = SSIM.calculate_ssim_l1_given_paths(paths) 30 | val_dict = {'SSIM': value[0], 31 | 'L1': value[1]} 32 | elif metric == 'FeatL1': 33 | paths = (path_fake, path_real) 34 | value = feature_l1.calculate_feature_l1_given_paths( 35 | paths, batch_size, True, 2048) 36 | val_dict = {'FeatL1': value} 37 | 38 | elif metric == 'all': 39 | paths = (path_fake, path_real) 40 | value = SSIM.calculate_ssim_l1_given_paths(paths) 41 | value_FID = FID.calculate_fid_given_paths( 42 | paths, batch_size, True, 2048) 43 | value_FeatL1 = feature_l1.calculate_feature_l1_given_paths( 44 | paths, batch_size, True, 2048) 45 | val_dict = {'FID': value_FID, 46 | 'SSIM': value[0], 47 | 'L1': value[1], 48 | 'FeatL1': value_FeatL1} 49 | 50 | return val_dict 51 | -------------------------------------------------------------------------------- /mesh2tex/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex.geometry import ( 2 | pointnet 3 | ) 4 | 5 | 6 | encoder_dict = { 7 | 'simple': pointnet.SimplePointnet, 8 | 'resnet': 
pointnet.ResnetPointnetConv, 9 | } 10 | 11 | 12 | def get_representation(batch, device=None): 13 | mesh_points = batch['pointcloud'].to(device) 14 | mesh_normals = batch['pointcloud.normals'].to(device) 15 | geom_repr = { 16 | 'points': mesh_points, 17 | 'normals': mesh_normals, 18 | } 19 | if 'pointcloud.knn_idx' in batch: 20 | knn_idx = batch['pointcloud.knn_idx'].to(device) 21 | geom_repr['knn_idx'] = knn_idx 22 | 23 | return geom_repr 24 | 25 | -------------------------------------------------------------------------------- /mesh2tex/geometry/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mesh2tex.layers import EqualizedLR 5 | from mesh2tex.layers import ResnetBlockFC, ResnetBlockConv1D 6 | 7 | 8 | class SimplePointnet(nn.Module): 9 | def __init__(self, c_dim=128, hidden_dim=128, 10 | leaky=False, eq_lr=False): 11 | super().__init__() 12 | # Attributes 13 | self.c_dim = c_dim 14 | self.eq_lr = eq_lr 15 | 16 | # Activation function 17 | if not leaky: 18 | self.actvn = F.relu 19 | self.pool = maxpool 20 | else: 21 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 22 | self.pool = avgpool 23 | 24 | # Submodules 25 | self.conv_p = nn.Conv1d(6, 2*hidden_dim, 1) 26 | self.conv_0 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 27 | self.conv_1 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 28 | self.conv_2 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 29 | self.conv_3 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) 30 | self.fc_c = nn.Linear(hidden_dim, c_dim) 31 | 32 | if self.eq_lr: 33 | self.conv_p = EqualizedLR(self.conv_p) 34 | self.conv_0 = EqualizedLR(self.conv_0) 35 | self.conv_1 = EqualizedLR(self.conv_1) 36 | self.conv_2 = EqualizedLR(self.conv_2) 37 | self.conv_3 = EqualizedLR(self.conv_3) 38 | self.fc_c = EqualizedLR(self.fc_c) 39 | 40 | def forward(self, geometry): 41 | p = geometry['points'] 42 | n = geometry['normals'] 43 | 44 | # Encode position into batch_size x F x T 45 | pn = torch.cat([p, n], dim=1) 46 | net = self.conv_p(pn) 47 | 48 | # Always pool to batch_size x F x 1, 49 | # expand to batch_size x F x T 50 | # and concatenate to batch_size x 2F x T 51 | net = self.conv_0(self.actvn(net)) 52 | pooled = self.pool(net, dim=2, keepdim=True) 53 | pooled = pooled.expand(net.size()) 54 | net = torch.cat([net, pooled], dim=1) 55 | 56 | net = self.conv_1(self.actvn(net)) 57 | pooled = self.pool(net, dim=2, keepdim=True) 58 | pooled = pooled.expand(net.size()) 59 | net = torch.cat([net, pooled], dim=1) 60 | 61 | net = self.conv_2(self.actvn(net)) 62 | pooled = self.pool(net, dim=2, keepdim=True) 63 | pooled = pooled.expand(net.size()) 64 | net = torch.cat([net, pooled], dim=1) 65 | 66 | net = self.conv_3(self.actvn(net)) 67 | 68 | # Recude to batch_size x F 69 | net = self.pool(net, dim=2) 70 | 71 | c = self.fc_c(self.actvn(net)) 72 | 73 | geom_descr = { 74 | 'global': c, 75 | } 76 | 77 | return geom_descr 78 | 79 | 80 | class ResnetPointnet(nn.Module): 81 | def __init__(self, c_dim=128, dim=6, hidden_dim=128): 82 | super().__init__() 83 | self.c_dim = c_dim 84 | 85 | self.fc_pos = nn.Linear(dim, 2*hidden_dim) 86 | self.block_0 = ResnetBlockFC(2*hidden_dim, hidden_dim) 87 | self.block_1 = ResnetBlockFC(2*hidden_dim, hidden_dim) 88 | self.block_2 = ResnetBlockFC(2*hidden_dim, hidden_dim) 89 | self.block_3 = ResnetBlockFC(2*hidden_dim, hidden_dim) 90 | self.block_4 = ResnetBlockFC(2*hidden_dim, hidden_dim) 91 | self.fc_c = nn.Linear(hidden_dim, c_dim) 92 | 93 | self.actvn = 
nn.ReLU() 94 | self.pool = maxpool 95 | 96 | def forward(self, geometry): 97 | p = geometry['points'] 98 | n = geometry['normals'] 99 | batch_size, T, D = p.size() 100 | 101 | pn = torch.cat([p, n], dim=1) 102 | # output size: B x T X F 103 | net = self.fc_pos(pn) 104 | net = self.block_0(net) 105 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 106 | net = torch.cat([net, pooled], dim=2) 107 | 108 | net = self.block_1(net) 109 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 110 | net = torch.cat([net, pooled], dim=2) 111 | 112 | net = self.block_2(net) 113 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 114 | net = torch.cat([net, pooled], dim=2) 115 | 116 | net = self.block_3(net) 117 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 118 | net = torch.cat([net, pooled], dim=2) 119 | 120 | net = self.block_4(net) 121 | 122 | # Recude to B x F 123 | net = self.pool(net, dim=1) 124 | 125 | c = self.fc_c(self.actvn(net)) 126 | 127 | return c 128 | 129 | 130 | class ResnetPointnetConv(nn.Module): 131 | def __init__(self, c_dim=128, dim=6, hidden_dim=128): 132 | super().__init__() 133 | self.c_dim = c_dim 134 | 135 | self.fc_pos = nn.Conv1d(dim, 2*hidden_dim, 1) 136 | self.block_0 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 137 | self.block_1 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 138 | self.block_2 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 139 | self.block_3 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 140 | self.block_4 = ResnetBlockConv1D(2*hidden_dim, hidden_dim) 141 | self.fc_c = nn.Linear(hidden_dim, c_dim) 142 | 143 | self.actvn = nn.ReLU() 144 | self.pool = maxpool 145 | 146 | def forward(self, geometry): 147 | p = geometry['points'] 148 | n = geometry['normals'] 149 | batch_size, T, D = p.size() 150 | 151 | pn = torch.cat([p, n], dim=1) 152 | # output size: B x T X F 153 | net = self.fc_pos(pn) 154 | net = self.block_0(net) 155 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 156 | net = torch.cat([net, pooled], dim=1) 157 | 158 | net = self.block_1(net) 159 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 160 | net = torch.cat([net, pooled], dim=1) 161 | 162 | net = self.block_2(net) 163 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 164 | net = torch.cat([net, pooled], dim=1) 165 | 166 | net = self.block_3(net) 167 | pooled = self.pool(net, dim=2, keepdim=True).expand(net.size()) 168 | net = torch.cat([net, pooled], dim=1) 169 | 170 | net = self.block_4(net) 171 | 172 | # Recude to B x F 173 | net = self.pool(net, dim=2) 174 | 175 | c = self.fc_c(self.actvn(net)) 176 | 177 | geom_descr = { 178 | 'global': c, 179 | } 180 | 181 | return geom_descr 182 | 183 | 184 | def maxpool(x, dim=-1, keepdim=False): 185 | out, _ = x.max(dim=dim, keepdim=keepdim) 186 | return out 187 | 188 | 189 | def avgpool(x, dim=-1, keepdim=False): 190 | out = x.mean(dim=dim, keepdim=keepdim) 191 | return out 192 | -------------------------------------------------------------------------------- /mesh2tex/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResnetBlockFC(nn.Module): 7 | def __init__(self, size_in, size_out=None, size_h=None): 8 | super().__init__() 9 | # Attributes 10 | if size_out is None: 11 | size_out = size_in 12 | 13 | if size_h is None: 14 | size_h = min(size_in, size_out) 15 | 16 | self.size_in = size_in 17 | 
self.size_h = size_h 18 | self.size_out = size_out 19 | # Submodules 20 | self.fc_0 = nn.Linear(size_in, size_h) 21 | self.fc_1 = nn.Linear(size_h, size_out) 22 | self.actvn = nn.ReLU() 23 | 24 | if size_in == size_out: 25 | self.shortcut = None 26 | else: 27 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 28 | # Initialization 29 | nn.init.zeros_(self.fc_1.weight) 30 | 31 | def forward(self, x): 32 | net = self.fc_0(self.actvn(x)) 33 | dx = self.fc_1(self.actvn(net)) 34 | 35 | if self.shortcut is not None: 36 | x_s = self.shortcut(x) 37 | else: 38 | x_s = x 39 | 40 | return x_s + dx 41 | 42 | 43 | # Resnet Blocks 44 | class ResnetBlockConv1D(nn.Module): 45 | def __init__(self, size_in, size_out=None, size_h=None): 46 | super().__init__() 47 | # Attributes 48 | if size_out is None: 49 | size_out = size_in 50 | 51 | if size_h is None: 52 | size_h = min(size_in, size_out) 53 | 54 | self.size_in = size_in 55 | self.size_h = size_h 56 | self.size_out = size_out 57 | # Submodules 58 | self.fc_0 = nn.Conv1d(size_in, size_h, 1) 59 | self.fc_1 = nn.Conv1d(size_h, size_out, 1) 60 | self.actvn = nn.ReLU() 61 | 62 | if size_in == size_out: 63 | self.shortcut = None 64 | else: 65 | self.shortcut = nn.Conv1d(size_in, size_out, 1, bias=False) 66 | # Initialization 67 | nn.init.zeros_(self.fc_1.weight) 68 | 69 | def forward(self, x): 70 | net = self.fc_0(self.actvn(x)) 71 | dx = self.fc_1(self.actvn(net)) 72 | 73 | if self.shortcut is not None: 74 | x_s = self.shortcut(x) 75 | else: 76 | x_s = x 77 | 78 | return x_s + dx 79 | 80 | 81 | class ResnetBlockPointwise(nn.Module): 82 | def __init__(self, f_in, f_out=None, f_hidden=None, 83 | is_bias=True, actvn=F.relu, factor=1., eq_lr=False): 84 | super().__init__() 85 | # Filter dimensions 86 | if f_out is None: 87 | f_out = f_in 88 | 89 | if f_hidden is None: 90 | f_hidden = min(f_in, f_out) 91 | 92 | self.f_in = f_in 93 | self.f_hidden = f_hidden 94 | self.f_out = f_out 95 | 96 | self.factor = factor 97 | self.eq_lr = eq_lr 98 | 99 | # Activation function 100 | self.actvn = actvn 101 | 102 | # Submodules 103 | self.conv_0 = nn.Conv1d(f_in, f_hidden, 1) 104 | self.conv_1 = nn.Conv1d(f_hidden, f_out, 1, bias=is_bias) 105 | 106 | if self.eq_lr: 107 | self.conv_0 = EqualizedLR(self.conv_0) 108 | self.conv_1 = EqualizedLR(self.conv_1) 109 | 110 | if f_in == f_out: 111 | self.shortcut = nn.Sequential() 112 | else: 113 | self.shortcut = nn.Conv1d(f_in, f_out, 1, bias=False) 114 | if self.eq_lr: 115 | self.shortcut = EqualizedLR(self.shortcut) 116 | 117 | # Initialization 118 | nn.init.zeros_(self.conv_1.weight) 119 | 120 | def forward(self, x): 121 | net = self.conv_0(self.actvn(x)) 122 | dx = self.conv_1(self.actvn(net)) 123 | x_s = self.shortcut(x) 124 | return x_s + self.factor * dx 125 | 126 | 127 | class ResnetBlockConv2d(nn.Module): 128 | def __init__(self, f_in, f_out=None, f_hidden=None, 129 | is_bias=True, actvn=F.relu, factor=1., 130 | eq_lr=False, pixel_norm=False): 131 | super().__init__() 132 | # Filter dimensions 133 | if f_out is None: 134 | f_out = f_in 135 | 136 | if f_hidden is None: 137 | f_hidden = min(f_in, f_out) 138 | 139 | self.f_in = f_in 140 | self.f_hidden = f_hidden 141 | self.f_out = f_out 142 | self.factor = factor 143 | self.eq_lr = eq_lr 144 | self.use_pixel_norm = pixel_norm 145 | 146 | # Activation 147 | self.actvn = actvn 148 | 149 | # Submodules 150 | self.conv_0 = nn.Conv2d(self.f_in, self.f_hidden, 3, 151 | stride=1, padding=1) 152 | self.conv_1 = nn.Conv2d(self.f_hidden, self.f_out, 3, 153 | stride=1, padding=1, 
bias=is_bias) 154 | 155 | if self.eq_lr: 156 | self.conv_0 = EqualizedLR(self.conv_0) 157 | self.conv_1 = EqualizedLR(self.conv_1) 158 | 159 | if f_in == f_out: 160 | self.shortcut = nn.Sequential() 161 | else: 162 | self.shortcut = nn.Conv2d(f_in, f_out, 1, bias=False) 163 | if self.eq_lr: 164 | self.shortcut = EqualizedLR(self.shortcut) 165 | 166 | # Initialization 167 | nn.init.zeros_(self.conv_1.weight) 168 | 169 | def forward(self, x): 170 | x_s = self.shortcut(x) 171 | 172 | if self.use_pixel_norm: 173 | x = pixel_norm(x) 174 | dx = self.conv_0(self.actvn(x)) 175 | 176 | if self.use_pixel_norm: 177 | dx = pixel_norm(dx) 178 | dx = self.conv_1(self.actvn(dx)) 179 | 180 | out = x_s + self.factor * dx 181 | 182 | return out 183 | 184 | def _shortcut(self, x): 185 | if self.learned_shortcut: 186 | x_s = self.conv_s(x) 187 | else: 188 | x_s = x 189 | return x_s 190 | 191 | 192 | class EqualizedLR(nn.Module): 193 | def __init__(self, module): 194 | super().__init__() 195 | self.module = module 196 | self._make_params() 197 | 198 | def _make_params(self): 199 | weight = self.module.weight 200 | 201 | height = weight.data.shape[0] 202 | width = weight.view(height, -1).data.shape[1] 203 | 204 | # Delete parameters in child 205 | del self.module._parameters['weight'] 206 | self.module.weight = None 207 | 208 | # Add parameters to myself 209 | self.weight = nn.Parameter(weight.data) 210 | 211 | # Inherit parameters 212 | self.factor = np.sqrt(2 / width) 213 | 214 | # Initialize 215 | nn.init.normal_(self.weight) 216 | 217 | # Inherit bias if available 218 | self.bias = self.module.bias 219 | self.module.bias = None 220 | 221 | if self.bias is not None: 222 | del self.module._parameters['bias'] 223 | nn.init.zeros_(self.bias) 224 | 225 | def forward(self, *args, **kwargs): 226 | self.module.weight = self.factor * self.weight 227 | if self.bias is not None: 228 | self.module.bias = 1. 
* self.bias 229 | out = self.module.forward(*args, **kwargs) 230 | self.module.weight = None 231 | self.module.bias = None 232 | return out 233 | 234 | 235 | def pixel_norm(x): 236 | sigma = x.norm(dim=1, keepdim=True) 237 | out = x / (sigma + 1e-5) 238 | return out 239 | -------------------------------------------------------------------------------- /mesh2tex/nvs/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex.nvs import ( 2 | training, generation, config, models 3 | ) 4 | 5 | __all__ = [ 6 | training, 7 | generation, 8 | config, 9 | models, 10 | ] 11 | -------------------------------------------------------------------------------- /mesh2tex/nvs/config.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.distributions as dist 4 | import torch.optim as optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from mesh2tex import data, geometry 8 | from mesh2tex.nvs import training, generation 9 | from mesh2tex.nvs import models 10 | 11 | 12 | def get_models(cfg, device=None, dataset=None): 13 | # Get configs 14 | encoder = cfg['model']['encoder'] 15 | decoder = cfg['model']['decoder'] 16 | discriminator = cfg['model']['discriminator'] 17 | 18 | encoder_kwargs = cfg['model']['encoder_kwargs'] 19 | decoder_kwargs = cfg['model']['decoder_kwargs'] 20 | discriminator_kwargs = cfg['model']['discriminator_kwargs'] 21 | 22 | img_size = cfg['data']['img_size'] 23 | z_dim = cfg['model']['z_dim'] 24 | c_dim = cfg['model']['c_dim'] 25 | white_bg = cfg['model']['white_bg'] 26 | # Create generator 27 | 28 | if encoder is not None: 29 | encoder = models.encoder_dict[encoder]( 30 | c_dim=c_dim, **encoder_kwargs 31 | ).to(device) 32 | 33 | decoder = models.decoder_dict[decoder]( 34 | c_dim=c_dim, white_bg=white_bg, **decoder_kwargs 35 | ).to(device) 36 | 37 | generator = models.NovelViewSynthesis( 38 | decoder, encoder 39 | ) 40 | 41 | p0_z = get_prior_z(cfg, device) 42 | 43 | # Create discriminator 44 | discriminator = models.discriminator_dict[discriminator]( 45 | img_size=img_size, **discriminator_kwargs 46 | ).to(device) 47 | 48 | # Output dict 49 | models_out = { 50 | 'generator': generator, 51 | 'discriminator': discriminator, 52 | } 53 | 54 | return models_out 55 | 56 | 57 | def get_optimizers(models, cfg): 58 | model_g = models['generator'] 59 | model_d = models['discriminator'] 60 | 61 | lr_g = cfg['training']['lr_g'] 62 | lr_d = cfg['training']['lr_d'] 63 | optimizer_g = optim.RMSprop(model_g.parameters(), lr=lr_g) 64 | optimizer_d = optim.RMSprop(model_d.parameters(), lr=lr_d) 65 | 66 | optimizers = { 67 | 'generator': optimizer_g, 68 | 'discriminator': optimizer_d, 69 | } 70 | return optimizers 71 | 72 | 73 | def get_dataset(mode, cfg, input_sampling=True): 74 | # Config 75 | path_shapes = cfg['data']['path_shapes'] 76 | img_size = cfg['data']['img_size'] 77 | 78 | # Transforms 79 | transform_img = transforms.Compose([ 80 | data.ResizeImage((img_size, img_size), order=0), 81 | ]) 82 | 83 | transform_img_input = transforms.Compose([ 84 | data.ResizeImage((224, 224), order=0), 85 | ]) 86 | 87 | transform_depth = torchvision.transforms.Compose([ 88 | data.ImageToDepthValue(), 89 | data.ResizeImage((img_size, img_size), order=0), 90 | ]) 91 | 92 | # Fields 93 | if mode == 'train': 94 | fields = { 95 | '2d': data.DepthImageField( 96 | 'image', 'depth', transform_img, transform_depth, 'png', 97 | 'exr', with_camera=True, random_view=True), 98 
| 'condition': data.ImagesField('input_image', 99 | transform_img_input, 'jpg'), 100 | } 101 | mode_ = 'train' 102 | 103 | elif mode == 'val_eval': 104 | fields = { 105 | '2d': data.DepthImageField( 106 | 'image', 'depth', transform_img, transform_depth, 'png', 107 | 'exr', with_camera=True, random_view=True), 108 | 'condition': data.ImagesField( 109 | 'input_image', 110 | transform_img_input, 'jpg' 111 | ), 112 | } 113 | mode_ = 'val' 114 | 115 | elif mode == 'val_vis': 116 | # elif for_vis is True or cfg['training']['vis_fixviews'] is True: 117 | fields = { 118 | '2d': data.DepthImageVisualizeField( 119 | 'visualize/image', 'visualize/depth', 120 | transform_img, transform_depth, 'png', 121 | 'exr', with_camera=True, random_view=True 122 | ), 123 | 'condition': data.ImagesField( 124 | 'input_image', 125 | transform_img_input, 'jpg' 126 | ), 127 | } 128 | mode_ = 'val' 129 | 130 | elif mode == 'test_eval': 131 | # elif for_eval is True: 132 | fields = { 133 | '2d': data.DepthImageVisualizeField( 134 | 'image', 'depth', 135 | transform_img, transform_depth, 'png', 136 | 'exr', with_camera=True, random_view=True 137 | ), 138 | 'condition': data.ImagesField('input_image', 139 | transform_img_input, 'jpg', 140 | random_view=input_sampling 141 | ), 142 | 'idx': data.IndexField(), 143 | } 144 | mode_ = 'test' 145 | 146 | elif mode == 'test_vis': 147 | # elif for_vis is True or cfg['training']['vis_fixviews'] is True: 148 | fields = { 149 | '2d': data.DepthImageVisualizeField( 150 | 'visualize/image', 'visualize/depth', 151 | transform_img, transform_depth, 'png', 152 | 'exr', with_camera=True, random_view=True 153 | ), 154 | 'condition': data.ImagesField('input_image', 155 | transform_img_input, 'jpg', 156 | random_view=input_sampling 157 | ), 158 | 'idx': data.IndexField(), 159 | } 160 | mode_ = 'test' 161 | 162 | else: 163 | fields = { 164 | '2d': data.DepthImageField( 165 | 'image', 'depth', transform_img, transform_depth, 'png', 166 | 'exr', with_camera=True, random_view=True), 167 | 'condition': data.ImagesField('input_image', 168 | transform_img_input, 'jpg'), 169 | } 170 | 171 | if cfg['data']['shapes_multiclass']: 172 | ds_shapes = data.Shapes3dDataset( 173 | path_shapes, fields, split=mode_, no_except=True, 174 | ) 175 | else: 176 | ds_shapes = data.Shapes3dClassDataset( 177 | path_shapes, fields, split=mode_, no_except=True, 178 | ) 179 | 180 | if mode_ == 'val' or mode_ == 'test': 181 | ds = ds_shapes 182 | else: 183 | ds = data.CombinedDataset([ds_shapes, ds_shapes]) 184 | 185 | return ds 186 | return ds_shapes 187 | 188 | 189 | def get_dataloader(mode, cfg): 190 | # Config 191 | batch_size = cfg['training']['batch_size'] 192 | if mode != 'train': 193 | batch_size = cfg['generation']['batch_size'] 194 | ds_shapes = get_dataset(mode, cfg) 195 | 196 | data_loader = torch.utils.data.DataLoader( 197 | ds_shapes, batch_size=batch_size, num_workers=4, shuffle=True, 198 | collate_fn=data.collate_remove_none) 199 | 200 | return data_loader 201 | 202 | 203 | def get_meshloader(mode, cfg): 204 | # Config 205 | 206 | path_meshes = cfg['data']['path_meshes'] 207 | 208 | batch_size = cfg['training']['batch_size'] 209 | 210 | fields = { 211 | 'meshes': data.MeshField('mesh'), 212 | } 213 | 214 | ds_shapes = data.Shapes3dClassDataset( 215 | path_meshes, fields, split=None, no_except=True, 216 | ) 217 | 218 | data_loader = torch.utils.data.DataLoader( 219 | ds_shapes, batch_size=batch_size, num_workers=4, shuffle=True, 220 | collate_fn=data.collate_remove_none) 221 | 222 | return data_loader 223 | 
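# NOTE: illustrative usage sketch of the factory functions above (compare
# train.py for the full setup). Everything outside this module -- the config
# path and the plain-YAML loading, which does not resolve the 'inherit_from'
# key used by some configs -- is an assumption made only for the example:
#
#     import yaml
#     import torch
#     from mesh2tex.nvs import config as nvs_config
#
#     with open('configs/singleview/NVS/car.yaml') as f:
#         cfg = yaml.safe_load(f)
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     models_out = nvs_config.get_models(cfg, device=device)
#     optimizers = nvs_config.get_optimizers(models_out, cfg)
#     train_loader = nvs_config.get_dataloader('train', cfg)
#     trainer = nvs_config.get_trainer(models_out, optimizers, cfg, device=device)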
224 | 225 | def get_trainer(models, optimizers, cfg, device=None): 226 | out_dir = cfg['training']['out_dir'] 227 | 228 | print_every = cfg['training']['print_every'] 229 | visualize_every = cfg['training']['visualize_every'] 230 | checkpoint_every = cfg['training']['checkpoint_every'] 231 | validate_every = cfg['training']['validate_every'] 232 | backup_every = cfg['training']['backup_every'] 233 | 234 | model_selection_metric = cfg['training']['model_selection_metric'] 235 | model_selection_mode = cfg['training']['model_selection_mode'] 236 | 237 | ma_beta = cfg['training']['moving_average_beta'] 238 | multi_gpu = cfg['training']['multi_gpu'] 239 | gp_reg = cfg['training']['gradient_penalties_reg'] 240 | w_pix = cfg['training']['weight_pixelloss'] 241 | w_gan = cfg['training']['weight_ganloss'] 242 | w_vae = cfg['training']['weight_vaeloss'] 243 | 244 | trainer = training.Trainer( 245 | models['generator'], models['discriminator'], 246 | optimizers['generator'], optimizers['discriminator'], 247 | ma_beta=ma_beta, 248 | multi_gpu=multi_gpu, 249 | gp_reg=gp_reg, 250 | w_pix=w_pix, w_gan=w_gan, w_vae=w_vae, 251 | out_dir=out_dir, 252 | model_selection_metric=model_selection_metric, 253 | model_selection_mode=model_selection_mode, 254 | print_every=print_every, 255 | visualize_every=visualize_every, 256 | checkpoint_every=checkpoint_every, 257 | backup_every=backup_every, 258 | validate_every=validate_every, 259 | device=device, 260 | ) 261 | 262 | return trainer 263 | 264 | 265 | def get_generator(model, cfg, device, **kwargs): 266 | 267 | generator = generation.Generator3D( 268 | model, 269 | device=device, 270 | ) 271 | return generator 272 | 273 | def get_prior_z(cfg, device, **kwargs): 274 | z_dim = cfg['model']['z_dim'] 275 | p0_z = dist.Normal( 276 | torch.zeros(z_dim, device=device), 277 | torch.ones(z_dim, device=device) 278 | ) 279 | 280 | return p0_z -------------------------------------------------------------------------------- /mesh2tex/nvs/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import trimesh 4 | import os 5 | from trimesh.util import array_to_string 6 | from mesh2tex import geometry 7 | from torchvision.utils import save_image 8 | 9 | 10 | class Generator3D(object): 11 | def __init__(self, model, device=None): 12 | 13 | self.model = model 14 | self.device = device 15 | 16 | def generate_images_4eval(self, batch, out_dir, model_names): 17 | depth = batch['2d.depth'].to(self.device) 18 | img_real = batch['2d.img'].to(self.device) 19 | condition = batch['condition'].to(self.device) 20 | batch_size = depth.size(0) 21 | num_views = depth.size(1) 22 | # assert(num_views == 5) 23 | # Save real images 24 | 25 | out_dir_real = out_dir + "/real/" 26 | out_dir_fake = out_dir + "/fake/" 27 | out_dir_condition = out_dir + "/condition/" 28 | if not os.path.exists(out_dir_real): 29 | os.makedirs(out_dir_real) 30 | if not os.path.exists(out_dir_fake): 31 | os.makedirs(out_dir_fake) 32 | if not os.path.exists(out_dir_condition): 33 | os.makedirs(out_dir_condition) 34 | 35 | for j in range(batch_size): 36 | save_image( 37 | condition[j].cpu(), 38 | os.path.join(out_dir_condition, '%s.png' % model_names[j])) 39 | 40 | for v in range(num_views): 41 | depth_ = depth[:, v] 42 | img_real_ = img_real[:, v] 43 | 44 | self.model.eval() 45 | 46 | with torch.no_grad(): 47 | img_fake_ = self.model(depth_, condition) 48 | 49 | for j in range(batch_size): 50 | save_image( 51 | img_real_[j].cpu(), 52 | 
os.path.join( 53 | out_dir_real, '%s%03d.png' % (model_names[j], v) 54 | )) 55 | save_image( 56 | img_fake_[j].cpu(), 57 | os.path.join( 58 | out_dir_fake, '%s%03d.png' % (model_names[j], v) 59 | )) 60 | 61 | def generate_images_4eval_condi_hd(self, batch, out_dir, model_names): 62 | depth = batch['2d.depth'] #.to(self.device) 63 | img_real = batch['2d.img'] #.to(self.device) 64 | condition = batch['condition'].to(self.device) 65 | batch_size = depth.size(0) 66 | num_views = depth.size(1) 67 | # if depth.size(1) >= 10: 68 | # num_views = 10 69 | # assert(num_views == 5) 70 | # Save real images 71 | out_dir_real = out_dir + "/real/" 72 | out_dir_fake = out_dir + "/fake/" 73 | out_dir_condition = out_dir + "/condition/" 74 | if not os.path.exists(out_dir_real): 75 | os.makedirs(out_dir_real) 76 | if not os.path.exists(out_dir_fake): 77 | os.makedirs(out_dir_fake) 78 | if not os.path.exists(out_dir_condition): 79 | os.makedirs(out_dir_condition) 80 | viewbatchsize = 2 81 | viewbatchnum = int(num_views / viewbatchsize) 82 | # points_batches = points.split(10, dim=0) 83 | for j in range(batch_size): 84 | for vidx in range(viewbatchnum): 85 | lower = vidx * viewbatchsize 86 | upper = (vidx + 1) * viewbatchsize 87 | 88 | depth_ = depth[j][lower:upper] 89 | img_real_ = img_real[j][lower:upper] 90 | condition_ = condition[j][:4].expand( 91 | viewbatchsize, condition.size(1), 92 | condition.size(2), condition.size(3)) 93 | 94 | self.model.eval() 95 | with torch.no_grad(): 96 | img_fake = self.model(depth_.to(self.device), condition_.to(self.device)) 97 | for v in range(viewbatchsize): 98 | save_image( 99 | img_real_[v], 100 | os.path.join(out_dir_real, 101 | '%s%03d.png' % (model_names[j], vidx * viewbatchsize + v))) 102 | save_image( 103 | img_fake[v].cpu(), 104 | os.path.join(out_dir_fake, 105 | '%s%03d.png' % (model_names[j], vidx * viewbatchsize + v))) 106 | save_image( 107 | condition[j].cpu(), 108 | os.path.join(out_dir_condition, 109 | '%s.png' % (model_names[j]))) 110 | -------------------------------------------------------------------------------- /mesh2tex/nvs/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import distributions as dist 4 | import trimesh 5 | from mesh2tex.nvs.models import ( 6 | encoder, decoder, discriminator 7 | ) 8 | 9 | encoder_dict = { 10 | 'resnet18': encoder.Resnet18, 11 | } 12 | 13 | decoder_dict = { 14 | 'each_layer_c': decoder.UNetEachLayerC, 15 | } 16 | 17 | discriminator_dict = { 18 | 'resnet': discriminator.Resnet, 19 | } 20 | 21 | 22 | class NovelViewSynthesis(nn.Module): 23 | def __init__(self, decoder, encoder): 24 | super().__init__() 25 | 26 | self.decoder = decoder 27 | self.encoder = encoder 28 | 29 | def forward(self, depth, condition): 30 | """Generate an image . 31 | 32 | Args: 33 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 34 | representing depth of at pixels 35 | Returns: 36 | img (torch.FloatTensor): tensor of size B x 3 x N x M representing 37 | output image 38 | """ 39 | batch_size, _, N, M = depth.size() 40 | 41 | assert(depth.size(1) == 1) 42 | 43 | c = self.encode(condition) 44 | img = self.decode(depth, c) 45 | 46 | return img 47 | 48 | def encode(self, cond): 49 | """Encode mesh using sampled 3D location on the mesh. 
50 | 51 | Args: 52 | input_image (torch.FloatTensor): tensor of size B x 3 x N x M 53 | input image 54 | 55 | Returns: 56 | c (torch.FloatTensor): tensor of size B x C with encoding of 57 | the input image 58 | """ 59 | z = self.encoder(cond) 60 | return z 61 | 62 | def decode(self, depth, c): 63 | """Decode image from 3D locations, conditional encoding and latent 64 | encoding. 65 | 66 | Args: 67 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 68 | representing depth of at pixels 69 | c (torch.FloatTensor): tensor of size B x C with the encoding of 70 | the 3D meshes 71 | 72 | Returns: 73 | rgb (torch.FloatTensor): tensor of size B x 3 x N representing 74 | color at given 3d locations 75 | """ 76 | rgb = self.decoder(depth, c) 77 | return rgb 78 | -------------------------------------------------------------------------------- /mesh2tex/nvs/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class UNetEachLayerC(nn.Module): 7 | def __init__(self, c_dim, white_bg=True, resnet_leaky=None): 8 | super().__init__() 9 | # Attributes 10 | self.c_dim = c_dim 11 | self.white_bg = white_bg 12 | 13 | # Submodules 14 | self.conv_0 = nn.Conv2d(1, 64, 3, padding=1, stride=2) 15 | self.conv_1 = nn.Conv2d(64, 128, 3, padding=1, stride=2) 16 | self.conv_2 = nn.Conv2d(128, 256, 3, padding=1, stride=2) 17 | self.conv_3 = nn.Conv2d(256, 512, 3, padding=1, stride=2) 18 | self.conv_4 = nn.Conv2d(512, 1024, 3, padding=1, stride=2) 19 | 20 | self.conv_trp_0 = nn.ConvTranspose2d(1024, 512, 3, padding=1, stride=2, output_padding=1) 21 | self.conv_trp_1 = nn.ConvTranspose2d(1024, 256, 3, padding=1, stride=2, output_padding=1) 22 | self.conv_trp_2 = nn.ConvTranspose2d(512, 128, 3, padding=1, stride=2, output_padding=1) 23 | self.conv_trp_3 = nn.ConvTranspose2d(256, 64, 3, padding=1, stride=2, output_padding=1) 24 | self.conv_trp_4 = nn.ConvTranspose2d(128, 3, 3, padding=1, stride=2, output_padding=1) 25 | 26 | self.fc_0 = nn.Linear(c_dim, 64) 27 | self.fc_1 = nn.Linear(c_dim, 128) 28 | self.fc_2 = nn.Linear(c_dim, 256) 29 | self.fc_3 = nn.Linear(c_dim, 512) 30 | self.fc_4 = nn.Linear(c_dim, 1024) 31 | 32 | def forward(self, depth, c): 33 | assert(c.size(0) == depth.size(0)) 34 | 35 | batch_size = depth.size(0) 36 | c_dim = self.c_dim 37 | 38 | mask = (depth != float('Inf')) 39 | depth = depth.clone() 40 | depth[~mask] = 0. 
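        # Pixels without geometry are encoded as infinite depth in the input
        # maps; they are zeroed here so the convolutions only see finite
        # values, while `mask` is kept so the white (or black) background can
        # be composited back in at the end of this method.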
41 | 42 | net = depth 43 | 44 | # Downsample 45 | # 64 x 128 x 128 46 | net0 = self.conv_0(net) + self.fc_0(c).view(batch_size, 64, 1, 1) 47 | net0 = F.relu(net0) 48 | # 128 x 64 x 64 49 | net1 = self.conv_1(net0) + self.fc_1(c).view(batch_size, 128, 1, 1) 50 | net1 = F.relu(net1) 51 | # 256 x 32 x 32 52 | net2 = self.conv_2(net1) + self.fc_2(c).view(batch_size, 256, 1, 1) 53 | net2 = F.relu(net2) 54 | # 512 x 16 x 16 55 | net3 = self.conv_3(net2) + self.fc_3(c).view(batch_size, 512, 1, 1) 56 | net3 = F.relu(net3) 57 | # 1024 x 8 x 8 58 | net4 = self.conv_4(net3) + self.fc_4(c).view(batch_size, 1024, 1, 1) 59 | net4 = F.relu(net4) 60 | 61 | # Upsample 62 | # 512 x 16 x 16 63 | net = F.relu(self.conv_trp_0(net4)) 64 | # 256 x 32 x 32 65 | net = torch.cat([net, net3], dim=1) 66 | net = F.relu(self.conv_trp_1(net)) 67 | # 128 x 64 x 64 68 | net = torch.cat([net, net2], dim=1) 69 | net = F.relu(self.conv_trp_2(net)) 70 | # 64 x 128 x 128 71 | net = torch.cat([net, net1], dim=1) 72 | net = F.relu(self.conv_trp_3(net)) 73 | # 3 x 256 x 256 74 | net = torch.cat([net, net0], dim=1) 75 | net = self.conv_trp_4(net) 76 | net = torch.sigmoid(net) 77 | 78 | if self.white_bg: 79 | mask = mask.float() 80 | net = mask * net + (1 - mask) * torch.ones_like(net) 81 | else: 82 | mask = mask.float() 83 | net = mask * net + (1 - mask) * torch.zeros_like(net) 84 | 85 | return net -------------------------------------------------------------------------------- /mesh2tex/nvs/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from mesh2tex.layers import ResnetBlockConv2d 6 | 7 | 8 | class Resnet(nn.Module): 9 | def __init__(self, img_size, embed_size=256, 10 | nfilter=64, nfilter_max=1024, leaky=True): 11 | super().__init__() 12 | self.embed_size = embed_size 13 | s0 = self.s0 = 4 14 | nf = self.nf = nfilter 15 | nf_max = self.nf_max = nfilter_max 16 | 17 | # Activation function 18 | if not leaky: 19 | self.actvn = F.relu 20 | else: 21 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 22 | 23 | # Submodules 24 | nlayers = int(np.log2(img_size / s0)) 25 | self.nf0 = min(nf_max, nf * 2**nlayers) 26 | 27 | blocks = [ 28 | ResnetBlockConv2d(nf, nf, actvn=self.actvn) 29 | ] 30 | 31 | for i in range(nlayers): 32 | nf0 = min(nf * 2**i, nf_max) 33 | nf1 = min(nf * 2**(i+1), nf_max) 34 | blocks += [ 35 | nn.AvgPool2d(3, stride=2, padding=1), 36 | ResnetBlockConv2d(nf0, nf1), 37 | ] 38 | 39 | self.conv_img = nn.Conv2d(3 +1 , 1*nf, 3, padding=1) 40 | self.resnet = nn.Sequential(*blocks) 41 | self.fc = nn.Linear(self.nf0*s0*s0, 1) 42 | 43 | # Initialization 44 | nn.init.zeros_(self.fc.weight) 45 | 46 | def forward(self, x, depth): 47 | batch_size = x.size(0) 48 | 49 | depth = depth.clone() 50 | depth[depth == float("Inf")] = 0 51 | depth[depth == -1*float("Inf")] = 0 52 | 53 | x_and_depth = torch.cat([x, depth], dim=1) 54 | 55 | out = self.conv_img(x_and_depth) 56 | out = self.resnet(out) 57 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 58 | out = self.fc(self.actvn(out)) 59 | out = out.squeeze() 60 | 61 | return out 62 | -------------------------------------------------------------------------------- /mesh2tex/nvs/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class Resnet18(nn.Module): 6 | ''' ResNet-18 conditioning network. 
7 | ''' 8 | def __init__(self, c_dim=128, normalize=True, use_linear=True): 9 | ''' Initialisation. 10 | 11 | Args: 12 | c_dim (int): output dimension of the latent embedding 13 | normalize (bool): whether the input images should be normalized 14 | use_linear (bool): whether a final linear layer should be used 15 | ''' 16 | super().__init__() 17 | self.normalize = normalize 18 | self.use_linear = use_linear 19 | self.features = models.resnet18(pretrained=True) 20 | self.features.fc = nn.Sequential() 21 | if use_linear: 22 | self.fc = nn.Linear(512, c_dim) 23 | elif c_dim == 512: 24 | self.fc = nn.Sequential() 25 | else: 26 | raise ValueError('c_dim must be 512 if use_linear is False') 27 | 28 | def forward(self, x): 29 | if self.normalize: 30 | x = normalize_imagenet(x) 31 | net = self.features(x) 32 | out = self.fc(net) 33 | return out 34 | 35 | 36 | def normalize_imagenet(x): 37 | x = x.clone() 38 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 39 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 40 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 41 | return x 42 | -------------------------------------------------------------------------------- /mesh2tex/nvs/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.autograd as autograd 7 | from torchvision.utils import save_image 8 | from mesh2tex import geometry 9 | from mesh2tex.training import BaseTrainer 10 | from mesh2tex.utils.io import export_pointcloud 11 | import mesh2tex.utils.FID.feature_l1 as feature_l1 12 | import mesh2tex.utils.SSIM_L1.ssim_l1_score as SSIM 13 | 14 | 15 | class Trainer(BaseTrainer): 16 | def __init__(self, model_g, model_d, 17 | optimizer_g, optimizer_d, 18 | ma_beta=0.99, 19 | loss_type='L1', 20 | gp_reg=10., 21 | w_pix=0., w_gan=0., w_vae=0., 22 | gan_type='standard', 23 | multi_gpu=False, 24 | **kwargs): 25 | # Initialize base trainer 26 | super().__init__(**kwargs) 27 | 28 | # Models and optimizers 29 | self.model_g = model_g 30 | self.model_d = model_d 31 | 32 | self.model_g_ma = copy.deepcopy(model_g) 33 | 34 | for p in self.model_g_ma.parameters(): 35 | p.requires_grad = False 36 | self.model_g_ma.eval() 37 | 38 | self.optimizer_g = optimizer_g 39 | self.optimizer_d = optimizer_d 40 | self.loss_type = loss_type 41 | 42 | # Attributes 43 | self.gp_reg = gp_reg 44 | self.ma_beta = ma_beta 45 | self.gan_type = gan_type 46 | self.multi_gpu = multi_gpu 47 | self.w_pix = w_pix 48 | self.w_vae = w_vae 49 | self.w_gan = w_gan 50 | self.pix_loss = w_pix != 0 51 | self.vae_loss = w_vae != 0 52 | self.gan_loss = w_gan != 0 53 | if self.vae_loss and self.pix_loss: 54 | print('Not possible to combine pix and vae loss') 55 | # Checkpointer 56 | if self.gan_loss is True: 57 | self.checkpoint_io.register_modules( 58 | model_g=self.model_g, model_d=self.model_d, 59 | model_g_ma=self.model_g_ma, 60 | optimizer_g=self.optimizer_g, 61 | optimizer_d=self.optimizer_d, 62 | ) 63 | else: 64 | self.checkpoint_io.register_modules( 65 | model_g=self.model_g, model_d=self.model_d, 66 | model_g_ma=self.model_g_ma, 67 | optimizer_g=self.optimizer_g, 68 | ) 69 | 70 | def train_step(self, batch, epoch_it, it): 71 | # Seperate batch into batch and model 72 | # batch_model, (img, _) = batch 73 | batch_model0, batch_model1 = batch 74 | loss_g = self.train_step_g(batch_model0) 75 | if self.gan_loss is True: 76 | loss_d = self.train_step_d(batch_model1) 77 | else: 78 | loss_d = 0 79 | 80 | losses = { 81 | 
'loss_g': loss_g, 82 | 'loss_d': loss_d, 83 | } 84 | return losses 85 | 86 | def train_step_d(self, batch): 87 | ''' 88 | A single train step of the discriminator 89 | ''' 90 | model_d = self.model_d 91 | model_g = self.model_g 92 | 93 | model_d.train() 94 | model_g.train() 95 | 96 | if self.multi_gpu: 97 | model_d = nn.DataParallel(model_d) 98 | model_g = nn.DataParallel(model_g) 99 | 100 | self.optimizer_d.zero_grad() 101 | 102 | # Get data 103 | depth = batch['2d.depth'].to(self.device) 104 | img_real = batch['2d.img'].to(self.device) 105 | condition = batch['condition'].to(self.device) 106 | 107 | # Loss on real 108 | img_real.requires_grad_() 109 | d_real = model_d(img_real, depth) 110 | 111 | dloss_real = self.compute_gan_loss(d_real, 1) 112 | dloss_real.backward(retain_graph=True) 113 | 114 | # R1 Regularizer 115 | reg = self.gp_reg * compute_grad2(d_real, img_real).mean() 116 | reg.backward() 117 | 118 | # Loss on fake 119 | with torch.no_grad(): 120 | if self.gan_loss is True: 121 | img_fake = model_g(depth, condition) 122 | 123 | d_fake = model_d(img_fake, depth) 124 | 125 | dloss_fake = self.compute_gan_loss(d_fake, 0) 126 | dloss_fake.backward() 127 | 128 | # Gradient step 129 | self.optimizer_d.step() 130 | 131 | return self.w_gan * (dloss_fake.item() + dloss_real.item()) 132 | 133 | def train_step_g(self, batch): 134 | model_d = self.model_d 135 | model_g = self.model_g 136 | 137 | model_d.train() 138 | model_g.train() 139 | 140 | if self.multi_gpu: 141 | model_d = nn.DataParallel(model_d) 142 | model_g = nn.DataParallel(model_g) 143 | 144 | self.optimizer_g.zero_grad() 145 | 146 | # Get data 147 | depth = batch['2d.depth'].to(self.device) 148 | img_real = batch['2d.img'].to(self.device) 149 | condition = batch['condition'].to(self.device) 150 | 151 | # Loss on fake 152 | img_fake = model_g(depth, condition) 153 | loss_pix = 0 154 | loss_gan = 0 155 | 156 | if self.pix_loss is True: 157 | loss_pix = self.compute_loss(img_fake, img_real) 158 | if self.gan_loss is True: 159 | d_fake = model_d(img_fake, depth) 160 | loss_gan = self.compute_gan_loss(d_fake, 1) 161 | 162 | # weighting 163 | loss = self.w_pix * loss_pix + self.w_gan * loss_gan 164 | loss.backward() 165 | 166 | # Gradient step 167 | self.optimizer_g.step() 168 | 169 | # Update moving average 170 | #self.update_moving_average() 171 | 172 | return loss.item() 173 | 174 | def compute_loss(self, img_fake, img_real): 175 | 176 | if self.loss_type == 'L2': 177 | loss = F.mse_loss(img_fake, img_real) 178 | elif self.loss_type == 'L1': 179 | loss = F.l1_loss(img_fake, img_real) 180 | else: 181 | raise NotImplementedError 182 | 183 | return loss 184 | 185 | def compute_gan_loss(self, d_out, target): 186 | ''' 187 | Compute GAN loss (standart cross entropy or wasserstein distance) 188 | !!! 
Without Regularizer 189 | ''' 190 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 191 | 192 | if self.gan_type == 'standard': 193 | loss = F.binary_cross_entropy_with_logits(d_out, targets) 194 | elif self.gan_type == 'wgan': 195 | loss = (2*target - 1) * d_out.mean() 196 | else: 197 | raise NotImplementedError 198 | 199 | return loss 200 | 201 | def eval_step(self, batch): 202 | depth = batch['2d.depth'].to(self.device) 203 | img_real = batch['2d.img'].to(self.device) 204 | condition = batch['condition'].to(self.device) 205 | 206 | # Get model 207 | model_g = self.model_g 208 | model_g.eval() 209 | 210 | if self.multi_gpu: 211 | model_g = nn.DataParallel(model_g) 212 | 213 | # Predict 214 | with torch.no_grad(): 215 | img_fake = model_g(depth, condition) 216 | loss_val = self.compute_loss(img_fake, img_real) 217 | ssim, l1 = SSIM.calculate_ssim_l1_given_tensor(img_fake, img_real) 218 | featl1 = feature_l1.calculate_feature_l1_given_tensors( 219 | img_fake, img_real, img_real.size(0), True, 2048) 220 | 221 | loss_val_dict = {'loss_val': loss_val.item(), 'SSIM': ssim, 222 | 'featl1': featl1} 223 | 224 | return loss_val_dict 225 | 226 | def eval_step_old(self, batch): 227 | depth = batch['2d.depth'].to(self.device) 228 | img_real = batch['2d.img'].to(self.device) 229 | condition = batch['condition'].to(self.device) 230 | 231 | self.model_g.eval() 232 | with torch.no_grad(): 233 | img_fake = self.model_g(depth, condition) 234 | loss_val = self.compute_loss(img_fake, img_real) 235 | 236 | loss_val_dict = {'loss_val': loss_val.item()} 237 | return loss_val_dict 238 | 239 | def visualize(self, batch): 240 | # b atch_model, (img_real, _) = batch 241 | 242 | depth = batch['2d.depth'].to(self.device) 243 | img_real = batch['2d.img'].to(self.device) 244 | condition = batch['condition'].to(self.device) 245 | batch_size = depth.size(0) 246 | num_views = depth.size(1) 247 | # Save real images 248 | 249 | for j in range(batch_size): 250 | depth_ = depth[j] 251 | img_real_ = img_real[j] 252 | condition_ = condition[j].expand(num_views, condition.size(1), 253 | condition.size(2), condition.size(3)) 254 | save_image(img_real_, os.path.join(self.vis_dir, 'real_%i.png' % j)) 255 | # Create fake images and save 256 | self.model_g.eval() 257 | with torch.no_grad(): 258 | img_fake = self.model_g(depth_, condition_) 259 | save_image(img_fake.cpu(), os.path.join(self.vis_dir, 'fake_%i.png' % j)) 260 | save_image(condition[j].cpu(), os.path.join(self.vis_dir, 'condition_%i.png' % j)) 261 | 262 | def update_moving_average(self): 263 | param_dict_src = dict(self.model_g.named_parameters()) 264 | beta = self.ma_beta 265 | for p_name, p_tgt in self.model_g_ma.named_parameters(): 266 | p_src = param_dict_src[p_name] 267 | assert(p_src is not p_tgt) 268 | with torch.no_grad(): 269 | p_ma = beta * p_tgt + (1. 
- beta) * p_src 270 | p_tgt.copy_(p_ma) 271 | 272 | 273 | def compute_grad2(d_out, x_in): 274 | ''' 275 | Derive L2-Gradient penalty for regularizing the GAN 276 | ''' 277 | batch_size = x_in.size(0) 278 | grad_dout = autograd.grad( 279 | outputs=d_out.sum(), inputs=x_in, 280 | create_graph=True, retain_graph=True, only_inputs=True 281 | )[0] 282 | grad_dout2 = grad_dout.pow(2) 283 | assert(grad_dout2.size() == x_in.size()) 284 | reg = grad_dout2.view(batch_size, -1).sum(1) 285 | return reg 286 | -------------------------------------------------------------------------------- /mesh2tex/texnet/__init__.py: -------------------------------------------------------------------------------- 1 | from mesh2tex.texnet import ( 2 | training, generation, config, models 3 | ) 4 | 5 | __all__ = [ 6 | training, 7 | generation, 8 | config, 9 | models, 10 | ] 11 | -------------------------------------------------------------------------------- /mesh2tex/texnet/config.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributions as dist 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from mesh2tex import data, geometry 9 | from mesh2tex.texnet import training, generation 10 | from mesh2tex.texnet import models 11 | 12 | 13 | def get_models(cfg, dataset=None, device=None): 14 | # Get configs 15 | encoder = cfg['model']['encoder'] 16 | decoder = cfg['model']['decoder'] 17 | geometry_encoder = cfg['model']['geometry_encoder'] 18 | vae_encoder = cfg['model']['vae_encoder'] 19 | discriminator = cfg['model']['discriminator'] 20 | 21 | encoder_kwargs = cfg['model']['encoder_kwargs'] 22 | decoder_kwargs = cfg['model']['decoder_kwargs'] 23 | geometry_encoder_kwargs = cfg['model']['geometry_encoder_kwargs'] 24 | discriminator_kwargs = cfg['model']['discriminator_kwargs'] 25 | vae_encoder_kwargs = cfg['model']['vae_encoder_kwargs'] 26 | img_size = cfg['data']['img_size'] 27 | z_dim = cfg['model']['z_dim'] 28 | c_dim = cfg['model']['c_dim'] 29 | white_bg = cfg['model']['white_bg'] 30 | # Create generator 31 | 32 | if encoder == "idx": 33 | encoder = nn.Embedding(len(dataset), c_dim) 34 | elif encoder is not None: 35 | encoder = models.encoder_dict[encoder]( 36 | c_dim=c_dim, **encoder_kwargs 37 | ).to(device) 38 | 39 | decoder = models.decoder_dict[decoder]( 40 | c_dim=c_dim, z_dim=z_dim, **decoder_kwargs 41 | ).to(device) 42 | 43 | geometry_encoder = geometry.encoder_dict[geometry_encoder]( 44 | c_dim=c_dim, **geometry_encoder_kwargs 45 | ).to(device) 46 | 47 | if vae_encoder is not None: 48 | vae_encoder = models.vae_encoder_dict[vae_encoder]( 49 | img_size=img_size, c_dim=c_dim, z_dim=z_dim, **vae_encoder_kwargs 50 | ).to(device) 51 | 52 | p0_z = get_prior_z(cfg, device) 53 | 54 | generator = models.TextureNetwork( 55 | decoder, geometry_encoder, encoder, vae_encoder, p0_z, white_bg 56 | ) 57 | 58 | # Create discriminator 59 | discriminator = models.discriminator_dict[discriminator]( 60 | geometry_encoder, 61 | img_size=img_size, **discriminator_kwargs 62 | ).to(device) 63 | 64 | # Output dict 65 | models_out = { 66 | 'generator': generator, 67 | 'discriminator': discriminator, 68 | } 69 | 70 | return models_out 71 | 72 | 73 | def get_optimizers(models, cfg): 74 | model_g = models['generator'] 75 | model_d = models['discriminator'] 76 | 77 | lr_g = cfg['training']['lr_g'] 78 | lr_d = cfg['training']['lr_d'] 79 | optimizer_g = 
optim.RMSprop(model_g.parameters(), lr=lr_g) 80 | optimizer_d = optim.RMSprop(model_d.parameters(), lr=lr_d) 81 | 82 | optimizers = { 83 | 'generator': optimizer_g, 84 | 'discriminator': optimizer_d, 85 | } 86 | return optimizers 87 | 88 | 89 | def get_dataset(mode, cfg, input_sampling=True): 90 | # Config 91 | path_shapes = cfg['data']['path_shapes'] 92 | img_size = cfg['data']['img_size'] 93 | pc_subsampling = cfg['training']['pc_subsampling'] 94 | pcl_knn = cfg['data']['pcl_knn'] 95 | 96 | # Fields 97 | transform_img = transforms.Compose([ 98 | data.ResizeImage((img_size, img_size), order=0), 99 | ]) 100 | 101 | transform_img_input = transforms.Compose([ 102 | data.ResizeImage((224, 224), order=0), 103 | ]) 104 | 105 | transform_depth = torchvision.transforms.Compose([ 106 | data.ImageToDepthValue(), 107 | data.ResizeImage((img_size, img_size), order=0), 108 | ]) 109 | 110 | pcl_transform = [data.SubsamplePointcloud(pc_subsampling)] 111 | if pcl_knn is not None: 112 | pcl_transform += [data.ComputeKNNPointcloud(pcl_knn)] 113 | 114 | pcl_transform = transforms.Compose(pcl_transform) 115 | 116 | if mode == 'train': 117 | fields = { 118 | '2d': data.DepthImageField( 119 | 'image', 'depth', transform_img, transform_depth, 'png', 120 | 'exr', with_camera=True, random_view=True), 121 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 122 | 'condition': data.ImagesField('input_image', 123 | transform_img_input, 'jpg'), 124 | } 125 | mode_ = 'train' 126 | 127 | elif mode == 'val_eval': 128 | fields = { 129 | '2d': data.DepthImageField( 130 | 'image', 'depth', transform_img, transform_depth, 'png', 131 | 'exr', with_camera=True, random_view=True), 132 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 133 | 'condition': data.ImagesField('input_image', 134 | transform_img_input, 'jpg'), 135 | } 136 | mode_ = 'val' 137 | 138 | elif mode == 'val_vis': 139 | fields = { 140 | '2d': data.DepthImageVisualizeField( 141 | 'visualize/image', 'visualize/depth', 142 | transform_img, transform_depth, 'png', 143 | 'exr', with_camera=True, random_view=True 144 | ), 145 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 146 | 'condition': data.ImagesField('input_image', 147 | transform_img_input, 'jpg'), 148 | } 149 | mode_ = 'val' 150 | 151 | elif mode == 'test_eval': 152 | fields = { 153 | '2d': data.DepthImageVisualizeField( 154 | 'image', 'depth', 155 | transform_img, transform_depth, 'png', 156 | 'exr', with_camera=True, random_view=True 157 | ), 158 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 159 | 'condition': data.ImagesField('input_image', 160 | transform_img_input, 'jpg', 161 | random_view=input_sampling), 162 | 'idx': data.IndexField(), 163 | } 164 | mode_ = 'test' 165 | 166 | elif mode == 'test_vis': 167 | fields = { 168 | '2d': data.DepthImageVisualizeField( 169 | 'visualize/image', 'visualize/depth', 170 | transform_img, transform_depth, 'png', 171 | 'exr', with_camera=True, random_view=True 172 | ), 173 | 'pointcloud': data.PointCloudField('pointcloud.npz', pcl_transform), 174 | 'condition': data.ImagesField('input_image', 175 | transform_img_input, 'jpg', 176 | random_view=input_sampling), 177 | 'idx': data.IndexField(), 178 | } 179 | mode_ = 'test' 180 | 181 | else: 182 | print('Invalid data loading mode') 183 | 184 | # Dataset 185 | if cfg['data']['shapes_multiclass']: 186 | ds_shapes = data.Shapes3dDataset( 187 | path_shapes, fields, split=mode_, no_except=True, 188 | ) 189 | else: 190 | ds_shapes = 
data.Shapes3dClassDataset( 191 | path_shapes, fields, split=mode_, no_except=False, 192 | ) 193 | 194 | if mode_ == 'val' or mode_ == 'test': 195 | ds = ds_shapes 196 | else: 197 | ds = data.CombinedDataset([ds_shapes, ds_shapes]) 198 | 199 | return ds 200 | 201 | 202 | def get_dataloader(mode, cfg): 203 | # Config 204 | batch_size = cfg['training']['batch_size'] 205 | with_shuffle = cfg['data']['with_shuffle'] 206 | 207 | ds_shapes = get_dataset(mode, cfg) 208 | data_loader = torch.utils.data.DataLoader( 209 | ds_shapes, batch_size=batch_size, num_workers=12, shuffle=with_shuffle) 210 | #gcollate_fn=data.collate_remove_none) 211 | 212 | return data_loader 213 | 214 | 215 | def get_meshloader(mode, cfg): 216 | # Config 217 | 218 | path_meshes = cfg['data']['path_meshes'] 219 | 220 | batch_size = cfg['training']['batch_size'] 221 | 222 | fields = { 223 | 'meshes': data.MeshField('mesh'), 224 | } 225 | 226 | ds_shapes = data.Shapes3dClassDataset( 227 | path_meshes, fields, split=None, no_except=True, 228 | ) 229 | 230 | data_loader = torch.utils.data.DataLoader( 231 | ds_shapes, batch_size=batch_size, num_workers=12, shuffle=True) 232 | # collate_fn=data.collate_remove_none) 233 | 234 | return data_loader 235 | 236 | 237 | def get_trainer(models, optimizers, cfg, device=None): 238 | out_dir = cfg['training']['out_dir'] 239 | 240 | print_every = cfg['training']['print_every'] 241 | visualize_every = cfg['training']['visualize_every'] 242 | checkpoint_every = cfg['training']['checkpoint_every'] 243 | validate_every = cfg['training']['validate_every'] 244 | backup_every = cfg['training']['backup_every'] 245 | 246 | model_selection_metric = cfg['training']['model_selection_metric'] 247 | model_selection_mode = cfg['training']['model_selection_mode'] 248 | 249 | ma_beta = cfg['training']['moving_average_beta'] 250 | multi_gpu = cfg['training']['multi_gpu'] 251 | gp_reg = cfg['training']['gradient_penalties_reg'] 252 | w_pix = cfg['training']['weight_pixelloss'] 253 | w_gan = cfg['training']['weight_ganloss'] 254 | w_vae = cfg['training']['weight_vaeloss'] 255 | experiment = cfg['training']['experiment'] 256 | model_url = cfg['model']['model_url'] 257 | trainer = training.Trainer( 258 | models['generator'], models['discriminator'], 259 | optimizers['generator'], optimizers['discriminator'], 260 | ma_beta=ma_beta, 261 | gp_reg=gp_reg, 262 | w_pix=w_pix, w_gan=w_gan, w_vae=w_vae, 263 | multi_gpu=multi_gpu, 264 | experiment=experiment, 265 | out_dir=out_dir, 266 | model_selection_metric=model_selection_metric, 267 | model_selection_mode=model_selection_mode, 268 | print_every=print_every, 269 | visualize_every=visualize_every, 270 | checkpoint_every=checkpoint_every, 271 | backup_every=backup_every, 272 | validate_every=validate_every, 273 | device=device, 274 | model_url=model_url 275 | ) 276 | 277 | return trainer 278 | 279 | 280 | def get_generator(model, cfg, device, **kwargs): 281 | 282 | generator = generation.Generator3D( 283 | model, 284 | device=device, 285 | ) 286 | return generator 287 | 288 | 289 | def get_prior_z(cfg, device, **kwargs): 290 | z_dim = cfg['model']['z_dim'] 291 | p0_z = dist.Normal( 292 | torch.zeros(z_dim, device=device), 293 | torch.ones(z_dim, device=device) 294 | ) 295 | 296 | return p0_z 297 | -------------------------------------------------------------------------------- /mesh2tex/texnet/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from trimesh.util import 
array_to_string 5 | from mesh2tex import geometry 6 | from torchvision.utils import save_image 7 | from torch.nn.functional import interpolate 8 | 9 | #TODO comment the generation functions 10 | 11 | 12 | class Generator3D(object): 13 | def __init__(self, model, device=None): 14 | 15 | self.model = model 16 | self.device = device 17 | 18 | def save_mesh(self, mesh, out_file, digits=10): 19 | ''' 20 | Saving meshes to OFF file 21 | 22 | ''' 23 | digits = int(digits) 24 | # prepend a 3 (face count) to each face 25 | if mesh.visual.face_colors is None: 26 | faces_stacked = np.column_stack(( 27 | np.ones(len(mesh.faces)) * 3, 28 | mesh.faces)).astype(np.int64) 29 | else: 30 | assert(mesh.visual.face_colors.shape[0] == mesh.faces.shape[0]) 31 | faces_stacked = np.column_stack(( 32 | np.ones(len(mesh.faces)) * 3, 33 | mesh.faces, mesh.visual.face_colors[:, :3])).astype(np.int64) 34 | export = 'OFF\n' 35 | # the header is vertex count, face count, edge number 36 | export += str(len(mesh.vertices)) + ' ' + str(len(mesh.faces)) + ' 0\n' 37 | export += array_to_string( 38 | mesh.vertices, col_delim=' ', row_delim='\n', digits=digits) + '\n' 39 | export += array_to_string(faces_stacked, col_delim=' ', row_delim='\n') 40 | 41 | with open(out_file, 'w') as f: 42 | f.write(export) 43 | 44 | return mesh 45 | 46 | def generate_images_4eval_condi(self, batch, out_dir, model_names): 47 | ''' 48 | Generate textures in the conditional setting (given image) 49 | 50 | ''' 51 | 52 | # Extract depth, gt, camera info, shape pc and condition 53 | depth = batch['2d.depth'].to(self.device) 54 | img_real = batch['2d.img'].to(self.device) 55 | cam_K = batch['2d.camera_mat'].to(self.device) 56 | cam_W = batch['2d.world_mat'].to(self.device) 57 | mesh_repr = geometry.get_representation(batch, self.device) 58 | mesh_points = mesh_repr['points'] 59 | mesh_normals = mesh_repr['normals'] 60 | condition = batch['condition'].to(self.device) 61 | 62 | # Determine constants and check 63 | batch_size = depth.size(0) 64 | num_views = depth.size(1) 65 | 66 | # Define Output folders 67 | out_dir_real = out_dir + "/real/" 68 | out_dir_fake = out_dir + "/fake/" 69 | out_dir_condition = out_dir + "/condition/" 70 | if not os.path.exists(out_dir_real): 71 | os.makedirs(out_dir_real) 72 | if not os.path.exists(out_dir_fake): 73 | os.makedirs(out_dir_fake) 74 | if not os.path.exists(out_dir_condition): 75 | os.makedirs(out_dir_condition) 76 | 77 | # Batch loop 78 | for j in range(batch_size): 79 | 80 | # Expand shape info to tensors 81 | # for all views of the same objects 82 | geom_repr = { 83 | 'points': mesh_points[j][:num_views].expand( 84 | num_views, mesh_points.size(1), 85 | mesh_points.size(2)), 86 | 'normals': mesh_normals[j][:num_views].expand( 87 | num_views, mesh_normals.size(1), 88 | mesh_normals.size(2)), 89 | } 90 | 91 | depth_ = depth[j][:num_views] 92 | img_real_ = img_real[j][:num_views] 93 | condition_ = condition[j][:num_views].expand( 94 | num_views, condition.size(1), 95 | condition.size(2), condition.size(3)) 96 | cam_K_ = cam_K[j][:num_views] 97 | cam_W_ = cam_W[j][:num_views] 98 | 99 | # Generate images and save 100 | self.model.eval() 101 | with torch.no_grad(): 102 | img_fake = self.model(depth_, cam_K_, cam_W_, 103 | geom_repr, condition_) 104 | 105 | save_image( 106 | condition[j].cpu(), 107 | os.path.join(out_dir_condition, 108 | '%s.png' % (model_names[j]))) 109 | 110 | for v in range(num_views): 111 | save_image( 112 | img_real_[v], 113 | os.path.join(out_dir_real, 114 | '%s%03d.png' % (model_names[j], v))) 
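                # The matching synthesized view is written under the same name
                # in the fake/ directory, so real/ and fake/ form aligned pairs
                # for the paired image metrics (FID / FeatL1 / SSIM) that are
                # presumably consumed by the evaluation scripts.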
115 | save_image( 116 | img_fake[v].cpu(), 117 | os.path.join(out_dir_fake, 118 | '%s%03d.png' % (model_names[j], v))) 119 | 120 | def generate_images_4eval_condi_hd(self, batch, out_dir, model_names): 121 | ''' 122 | Generate textures in hd images given condition 123 | 124 | ''' 125 | 126 | # Extract depth, gt, camera info, shape pc and condition 127 | depth = batch['2d.depth'] 128 | img_real = batch['2d.img'] 129 | cam_K = batch['2d.camera_mat'] 130 | cam_W = batch['2d.world_mat'] 131 | mesh_repr = geometry.get_representation(batch, self.device) 132 | mesh_points = mesh_repr['points'] 133 | mesh_normals = mesh_repr['normals'] 134 | condition = batch['condition'] 135 | 136 | # Determine constants and check 137 | batch_size = depth.size(0) 138 | num_views = depth.size(1) 139 | 140 | # Define Output folders 141 | out_dir_real = out_dir + "/real/" 142 | out_dir_fake = out_dir + "/fake/" 143 | out_dir_condition = out_dir + "/condition/" 144 | if not os.path.exists(out_dir_real): 145 | os.makedirs(out_dir_real) 146 | if not os.path.exists(out_dir_fake): 147 | os.makedirs(out_dir_fake) 148 | if not os.path.exists(out_dir_condition): 149 | os.makedirs(out_dir_condition) 150 | 151 | # Loop through batch and views, because of memory requirement 152 | viewbatchsize = 1 153 | viewbatchnum = int(num_views / viewbatchsize) 154 | for j in range(batch_size): 155 | for vidx in range(viewbatchnum): 156 | lower = vidx * viewbatchsize 157 | upper = (vidx + 1) * viewbatchsize 158 | 159 | # Expand shape info to tensors 160 | # for all views of the same objects 161 | geom_repr = { 162 | 'points': mesh_points[j][:4].expand( 163 | viewbatchsize, mesh_points.size(1), 164 | mesh_points.size(2)), 165 | 'normals': mesh_normals[j][:4].expand( 166 | viewbatchsize, mesh_normals.size(1), 167 | mesh_normals.size(2)), 168 | } 169 | 170 | depth_ = depth[j][lower:upper].to(self.device) 171 | img_real_ = img_real[j][lower:upper] 172 | if len(condition.size()) == 1: 173 | condition_ = condition[j:j+1].expand( 174 | viewbatchsize) 175 | else: 176 | condition_ = condition[j:j+1][:4].expand( 177 | viewbatchsize, condition.size(1), 178 | condition.size(2), condition.size(3)).to(self.device) 179 | cam_K_ = cam_K[j][lower:upper].to(self.device) 180 | cam_W_ = cam_W[j][lower:upper].to(self.device) 181 | 182 | # Generate images and save 183 | self.model.eval() 184 | with torch.no_grad(): 185 | img_fake = self.model(depth_, cam_K_, cam_W_, 186 | geom_repr, condition_) 187 | if len(condition.size()) != 1: 188 | save_image( 189 | condition[j].cpu(), 190 | os.path.join(out_dir_condition, 191 | '%s.png' % (model_names[j]))) 192 | 193 | for v in range(viewbatchsize): 194 | save_image( 195 | img_real_[v], 196 | os.path.join( 197 | out_dir_real, 198 | '%s%03d.png' % (model_names[j], 199 | vidx * viewbatchsize + v))) 200 | save_image( 201 | img_fake[v].cpu(), 202 | os.path.join( 203 | out_dir_fake, 204 | '%s%03d.png' % (model_names[j], 205 | vidx * viewbatchsize + v))) 206 | 207 | def generate_images_4eval_vae(self, batch, out_dir, model_names): 208 | ''' 209 | Generate texture using the VAE 210 | 211 | ''' 212 | # Extract depth, gt, camera info, shape pc and condition 213 | depth = batch['2d.depth'].to(self.device) 214 | img_real = batch['2d.img'].to(self.device) 215 | cam_K = batch['2d.camera_mat'].to(self.device) 216 | cam_W = batch['2d.world_mat'].to(self.device) 217 | mesh_repr = geometry.get_representation(batch, self.device) 218 | mesh_points = mesh_repr['points'] 219 | mesh_normals = mesh_repr['normals'] 220 | 221 | # Determine 
constants and check 222 | batch_size = depth.size(0) 223 | num_views = depth.size(1) 224 | if depth.size(1) >= 10: 225 | num_views = 10 226 | 227 | # Define Output folders 228 | out_dir_real = out_dir + "/real/" 229 | out_dir_fake = out_dir + "/fake/" 230 | if not os.path.exists(out_dir_real): 231 | os.makedirs(out_dir_real) 232 | if not os.path.exists(out_dir_fake): 233 | os.makedirs(out_dir_fake) 234 | 235 | # batch loop 236 | for j in range(batch_size): 237 | geom_repr = { 238 | 'points': mesh_points[j][:num_views].expand( 239 | num_views, mesh_points.size(1), mesh_points.size(2)), 240 | 'normals': mesh_normals[j][:num_views].expand( 241 | num_views, mesh_normals.size(1), mesh_normals.size(2)), 242 | } 243 | depth_ = depth[j][:num_views] 244 | img_real_ = img_real[j][:num_views] 245 | cam_K_ = cam_K[j][:num_views] 246 | cam_W_ = cam_W[j][:num_views] 247 | 248 | # Sample latent code 249 | z_ = np.random.normal(0, 1, 512) 250 | inter = torch.from_numpy(z_).float().to(self.device) 251 | z = inter.expand(num_views, 512) 252 | 253 | # Generate images and save 254 | self.model.eval() 255 | with torch.no_grad(): 256 | img_fake = self.model(depth_, cam_K_, cam_W_, 257 | geom_repr, z=z, sample=False) 258 | 259 | for v in range(num_views): 260 | save_image( 261 | img_real_[v], 262 | os.path.join(out_dir_real, '%s%03d.png' 263 | % (model_names[j], v))) 264 | save_image( 265 | img_fake[v].cpu(), 266 | os.path.join(out_dir_fake, '%s%03d.png' 267 | % (model_names[j], v))) 268 | 269 | def generate_images_4eval_vae_interpol(self, batch, out_dir, model_names): 270 | ''' 271 | Interpolates between latent encoding 272 | of first and second element of batch 273 | 274 | ''' 275 | # Extract depth, gt, camera info, shape pc and condition 276 | depth = batch['2d.depth'].to(self.device) 277 | img_real = batch['2d.img'].to(self.device) 278 | cam_K = batch['2d.camera_mat'].to(self.device) 279 | cam_W = batch['2d.world_mat'].to(self.device) 280 | mesh_repr = geometry.get_representation(batch, self.device) 281 | mesh_points = mesh_repr['points'] 282 | mesh_normals = mesh_repr['normals'] 283 | 284 | # Determine constants and check 285 | batch_size = depth.size(0) 286 | num_views = depth.size(1) 287 | if depth.size(1) >= 10: 288 | num_views = 10 289 | 290 | # Define Output folders 291 | out_dir_real = out_dir + "/real/" 292 | out_dir_fake = out_dir + "/fake/" 293 | if not os.path.exists(out_dir_real): 294 | os.makedirs(out_dir_real) 295 | if not os.path.exists(out_dir_fake): 296 | os.makedirs(out_dir_fake) 297 | 298 | # Derive latent texture code as starting point of interpolation 299 | geom_repr = { 300 | 'points': mesh_points[:1], 301 | 'normals': mesh_normals[:1], 302 | } 303 | self.model.eval() 304 | shape_encoding = self.model.encode_geometry(geom_repr) 305 | image_input = img_real[0][:1] 306 | img = interpolate(image_input, size=[128, 128]) 307 | latent_input = self.model.infer_z_transfer(img, shape_encoding) 308 | 309 | # Derive latent texture code as end point of interpolation 310 | geom_repr2 = { 311 | 'points': mesh_points[1:2], 312 | 'normals': mesh_normals[1:2], 313 | } 314 | shape_encoding2 = self.model.encode_geometry(geom_repr2) 315 | image_input2 = img_real[1][:1] 316 | img2 = interpolate(image_input2, size=[128, 128]) 317 | latent_input2 = self.model.infer_z_transfer(img2, shape_encoding2) 318 | 319 | # Derive stepsize 320 | steps = 20 321 | step = (latent_input2-latent_input)/steps 322 | 323 | # batch loop 324 | for j in range(1, batch_size): 325 | 326 | geom_repr = { 327 | 'points': 
mesh_points[j][:num_views].expand( 328 | num_views, mesh_points.size(1), mesh_points.size(2)), 329 | 'normals': mesh_normals[j][:num_views].expand( 330 | num_views, mesh_normals.size(1), mesh_normals.size(2)), 331 | } 332 | 333 | depth_ = depth[j][:num_views] 334 | img_real_ = img_real[j][:num_views] 335 | cam_K_ = cam_K[j][:num_views] 336 | cam_W_ = cam_W[j][:num_views] 337 | 338 | self.model.eval() 339 | # steps loop 340 | for num in range(steps): 341 | inter = latent_input + step*num 342 | z = inter.expand(num_views, 512) 343 | with torch.no_grad(): 344 | img_fake = self.model(depth_, cam_K_, cam_W_, 345 | geom_repr, z=z, sample=False) 346 | for v in range(1): 347 | save_image( 348 | img_real_[v], 349 | os.path.join( 350 | out_dir_real, '%s%03d_%03d.png' 351 | % (model_names[j], v, num))) 352 | save_image( 353 | img_fake[v].cpu(), 354 | os.path.join( 355 | out_dir_fake, '%s%03d_%03d.png' 356 | % (model_names[j], v, num))) 357 | 358 | def generate_images_4eval_gan(self, batch, out_dir, model_names): 359 | ''' 360 | Generate Texture using a GAN 361 | 362 | ''' 363 | # Extract depth, gt, camera info, shape pc and condition 364 | depth = batch['2d.depth'].to(self.device) 365 | img_real = batch['2d.img'].to(self.device) 366 | cam_K = batch['2d.camera_mat'].to(self.device) 367 | cam_W = batch['2d.world_mat'].to(self.device) 368 | mesh_repr = geometry.get_representation(batch, self.device) 369 | mesh_points = mesh_repr['points'] 370 | mesh_normals = mesh_repr['normals'] 371 | 372 | # Determine constants and check 373 | batch_size = depth.size(0) 374 | num_views = depth.size(1) 375 | if depth.size(1) >= 10: 376 | num_views = 10 377 | 378 | # Define Output folders 379 | out_dir_real = out_dir + "/real/" 380 | out_dir_fake = out_dir + "/fake/" 381 | out_dir_condition = out_dir + "/condition/" 382 | if not os.path.exists(out_dir_real): 383 | os.makedirs(out_dir_real) 384 | if not os.path.exists(out_dir_fake): 385 | os.makedirs(out_dir_fake) 386 | if not os.path.exists(out_dir_condition): 387 | os.makedirs(out_dir_condition) 388 | 389 | # batch loop 390 | for j in range(batch_size): 391 | 392 | geom_repr = { 393 | 'points': mesh_points[j][:num_views].expand( 394 | num_views, mesh_points.size(1), 395 | mesh_points.size(2)), 396 | 'normals': mesh_normals[j][:num_views].expand( 397 | num_views, mesh_normals.size(1), 398 | mesh_normals.size(2)), 399 | } 400 | 401 | depth_ = depth[j][:num_views] 402 | img_real_ = img_real[j][:num_views] 403 | cam_K_ = cam_K[j][:num_views] 404 | cam_W_ = cam_W[j][:num_views] 405 | 406 | self.model.eval() 407 | with torch.no_grad(): 408 | img_fake = self.model(depth_, cam_K_, cam_W_, 409 | geom_repr, sample=False) 410 | for v in range(num_views): 411 | save_image( 412 | img_real_[v], 413 | os.path.join( 414 | out_dir_real, '%s%03d.png' % (model_names[j], v))) 415 | save_image( 416 | img_fake[v].cpu(), 417 | os.path.join( 418 | out_dir_fake, '%s%03d.png' % (model_names[j], v))) 419 | 420 | 421 | def make_3d_grid(bb_min, bb_max, shape): 422 | ''' 423 | Outputs gird points of a 3d grid 424 | 425 | ''' 426 | size = shape[0] * shape[1] * shape[2] 427 | 428 | pxs = torch.linspace(bb_min[0], bb_max[0], shape[0]) 429 | pys = torch.linspace(bb_min[1], bb_max[1], shape[1]) 430 | pzs = torch.linspace(bb_min[2], bb_max[2], shape[2]) 431 | 432 | pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size) 433 | pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size) 434 | pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size) 435 | p = torch.stack([pxs, pys, pzs], 
dim=1) 436 | 437 | return p 438 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import distributions as dist 5 | import trimesh 6 | from mesh2tex.texnet.models import ( 7 | image_encoder, decoder, discriminator, vae_encoder 8 | ) 9 | 10 | encoder_dict = { 11 | 'resnet18': image_encoder.Resnet18, 12 | } 13 | 14 | decoder_dict = { 15 | 'each_layer_c': decoder.DecoderEachLayerC, 16 | 'each_layer_c_larger': decoder.DecoderEachLayerCLarger, 17 | } 18 | 19 | discriminator_dict = { 20 | 'resnet_conditional': discriminator.Resnet_Conditional, 21 | } 22 | 23 | vae_encoder_dict = { 24 | 'resnet': vae_encoder.Resnet, 25 | } 26 | 27 | 28 | class TextureNetwork(nn.Module): 29 | def __init__(self, decoder, geometry_encoder, encoder=None, 30 | vae_encoder=None, p0_z=None, white_bg=True): 31 | super().__init__() 32 | 33 | if p0_z is None: 34 | p0_z = dist.Normal(torch.tensor([]), torch.tensor([])) 35 | 36 | self.decoder = decoder 37 | self.encoder = encoder 38 | self.geometry_encoder = geometry_encoder 39 | self.vae_encoder = vae_encoder 40 | self.p0_z = p0_z 41 | self.white_bg = white_bg 42 | 43 | def forward(self, depth, cam_K, cam_W, geometry, 44 | condition=None, z=None, sample=True): 45 | """Generate an image. 46 | 47 | Args: 48 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 49 | representing depth at each pixel 50 | cam_K (torch.FloatTensor): tensor of size B x 3 x 4 representing 51 | camera projection matrix 52 | cam_W (torch.FloatTensor): tensor of size B x 3 x 4 representing 53 | camera world matrix 54 | geometry (dict): representation of geometry 55 | condition (torch.FloatTensor): conditioning image (optional) 56 | z (torch.FloatTensor): latent texture code (optional) 57 | sample (Boolean): whether to sample the latent code or take the MAP 58 | Returns: 59 | img (torch.FloatTensor): tensor of size B x 3 x N x M representing 60 | the output image 61 | """ 62 | batch_size, _, N, M = depth.size() 63 | assert(depth.size(1) == 1) 64 | assert(cam_K.size() == (batch_size, 3, 4)) 65 | assert(cam_W.size() == (batch_size, 3, 4)) 66 | 67 | loc3d, mask = self.depth_map_to_3d(depth, cam_K, cam_W) 68 | geom_descr = self.encode_geometry(geometry) 69 | 70 | if self.encoder is not None: 71 | z = self.encode(condition) 72 | z = z.cuda() 73 | elif z is None: 74 | z = self.get_z_from_prior((batch_size,), sample=sample) 75 | 76 | loc3d = loc3d.view(batch_size, 3, N * M) 77 | x = self.decode(loc3d, geom_descr, z) 78 | x = x.view(batch_size, 3, N, M) 79 | 80 | if self.white_bg is False: 81 | x_bg = torch.zeros_like(x) 82 | else: 83 | x_bg = torch.ones_like(x) 84 | 85 | img = (mask * x).permute(0, 1, 3, 2) + (1 - mask.permute(0, 1, 3, 2)) * x_bg 86 | 87 | return img 88 | 89 | def load_mesh2facecenter(in_path): 90 | mesh = trimesh.load(in_path, process=False) 91 | faces_center = mesh.triangles_center 92 | return mesh, faces_center 93 | 94 | def elbo(self, image_real, depth, cam_K, cam_W, geometry): 95 | batch_size, _, N, M = depth.size() 96 | 97 | assert(depth.size(1) == 1) 98 | assert(cam_K.size() == (batch_size, 3, 4)) 99 | assert(cam_W.size() == (batch_size, 3, 4)) 100 | 101 | loc3d, mask = self.depth_map_to_3d(depth, cam_K, cam_W) 102 | geom_descr = self.encode_geometry(geometry) 103 | 104 | q_z = self.infer_z(image_real, geom_descr) 105 | z = q_z.rsample() 106 | 107 | loc3d = loc3d.view(batch_size, 3, N * M) 108 | x = self.decode(loc3d, geom_descr, z) 109 | x = 
x.view(batch_size, 3, N, M) 110 | 111 | if self.white_bg is False: 112 | x_bg = torch.zeros_like(x) 113 | else: 114 | x_bg = torch.ones_like(x) 115 | 116 | image_fake = (mask * x).permute(0, 1, 3, 2) + (1 - mask.permute(0, 1, 3, 2)) * x_bg 117 | 118 | recon_loss = F.mse_loss(image_fake, image_real).sum(dim=-1) 119 | kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1) 120 | elbo = recon_loss.mean() + kl.mean()/float(N*M*3) 121 | return elbo, recon_loss.mean(), kl.mean()/float(N*M*3), image_fake 122 | 123 | def encode(self, cond): 124 | """Encode the conditioning image into a latent code. 125 | 126 | Args: 127 | cond (torch.FloatTensor): tensor of size B x 3 x N x M 128 | with the conditioning input image 129 | 130 | Returns: 131 | z (torch.FloatTensor): tensor of size B x Z with the encoding of 132 | the input image 133 | """ 134 | z = self.encoder(cond) 135 | return z 136 | 137 | def encode_geometry(self, geometry): 138 | """Encode mesh using sampled 3D locations on the mesh. 139 | 140 | Args: 141 | geometry (dict): representation of geometry 142 | Returns: 143 | geom_descr (dict): geometry descriptor 144 | 145 | """ 146 | geom_descr = self.geometry_encoder(geometry) 147 | return geom_descr 148 | 149 | def decode(self, loc3d, c, z): 150 | """Decode image from 3D locations, conditional encoding and latent 151 | encoding. 152 | 153 | Args: 154 | loc3d (torch.FloatTensor): tensor of size B x 3 x K 155 | with 3D locations of the query 156 | c (torch.FloatTensor): tensor of size B x C with the encoding of 157 | the 3D meshes 158 | z (torch.FloatTensor): tensor of size B x Z with latent codes 159 | 160 | Returns: 161 | rgb (torch.FloatTensor): tensor of size B x 3 x K representing 162 | the color at the given 3D locations 163 | """ 164 | rgb = self.decoder(loc3d, c, z) 165 | return rgb 166 | 167 | def depth_map_to_3d(self, depth, cam_K, cam_W): 168 | """Derive 3D locations of each pixel of a depth map. 169 | 170 | Args: 171 | depth (torch.FloatTensor): tensor of size B x 1 x N x M 172 | with depth at every pixel 173 | cam_K (torch.FloatTensor): tensor of size B x 3 x 4 representing 174 | camera matrices 175 | cam_W (torch.FloatTensor): tensor of size B x 3 x 4 representing 176 | world matrices 177 | Returns: 178 | loc3d (torch.FloatTensor): tensor of size B x 3 x N x M 179 | with the 3D location corresponding to each pixel 180 | mask (torch.FloatTensor): tensor of size B x 1 x N x M with 181 | a binary mask indicating whether the given pixel is present or not 182 | """ 183 | 184 | assert(depth.size(1) == 1) 185 | batch_size, _, N, M = depth.size() 186 | device = depth.device 187 | # Turn depth around. 
This also avoids problems with inplace operations 188 | depth = -depth.permute(0, 1, 3, 2) 189 | 190 | zero_one_row = torch.tensor([[0., 0., 0., 1.]]) 191 | zero_one_row = zero_one_row.expand(batch_size, 1, 4).to(device) 192 | 193 | # add row to world mat 194 | cam_W = torch.cat((cam_W, zero_one_row), dim=1) 195 | 196 | # clean depth image for mask 197 | mask = (depth.abs() != float("Inf")).float() 198 | depth[depth == float("Inf")] = 0 199 | depth[depth == -1*float("Inf")] = 0 200 | 201 | # 4d array to 2d array k=N*M 202 | d = depth.reshape(batch_size, 1, N * M) 203 | 204 | # create pixel location tensor 205 | px, py = torch.meshgrid([torch.arange(0, N), torch.arange(0, M)]) 206 | px, py = px.to(device), py.to(device) 207 | 208 | p = torch.cat(( 209 | px.expand(batch_size, 1, px.size(0), px.size(1)), 210 | (M - py).expand(batch_size, 1, py.size(0), py.size(1)) 211 | ), dim=1) 212 | p = p.reshape(batch_size, 2, py.size(0) * py.size(1)) 213 | p = (p.float() / M * 2) 214 | 215 | # create terms of mapping equation x = P^-1 * d*(qp - b) 216 | P = cam_K[:, :2, :2].float().to(device) 217 | q = cam_K[:, 2:3, 2:3].float().to(device) 218 | b = cam_K[:, :2, 2:3].expand(batch_size, 2, d.size(2)).to(device) 219 | Inv_P = torch.inverse(P).to(device) 220 | 221 | rightside = (p.float() * q.float() - b.float()) * d.float() 222 | x_xy = torch.bmm(Inv_P, rightside) 223 | 224 | # add depth and ones to location in world coord system 225 | x_world = torch.cat((x_xy, d, torch.ones_like(d)), dim=1) 226 | 227 | # derive location in object coords via loc3d = W^-1 * x_world 228 | Inv_W = torch.inverse(cam_W) 229 | loc3d = torch.bmm( 230 | Inv_W.expand(batch_size, 4, 4), 231 | x_world 232 | ).reshape(batch_size, 4, N, M) 233 | 234 | loc3d = loc3d[:, :3].to(device) 235 | mask = mask.to(device) 236 | return loc3d, mask 237 | 238 | def get_z_from_prior(self, size=torch.Size([]), sample=True): 239 | """Draw latent code z from prior either using sampling or 240 | using the MAP. 241 | 242 | Args: 243 | size (torch.Size): size of sample to draw. 
244 | sample (Boolean): wether to sample or to use the MAP 245 | 246 | Return: 247 | z (torch.FloatTensor): tensor of shape *size x Z representing 248 | the latent code 249 | """ 250 | if sample: 251 | z = self.p0_z.sample(size) 252 | else: 253 | z = self.p0_z.mean 254 | z = z.expand(*size, *z.size()) 255 | 256 | return z 257 | 258 | def infer_z(self, image, c, **kwargs): 259 | if self.vae_encoder is not None: 260 | mean_z, logstd_z = self.vae_encoder(image, c, **kwargs) 261 | else: 262 | batch_size = image.size(0) 263 | mean_z = torch.empty(batch_size, 0).to(self._device) 264 | logstd_z = torch.empty(batch_size, 0).to(self._device) 265 | 266 | q_z = dist.Normal(mean_z, torch.exp(logstd_z)) 267 | return q_z 268 | 269 | def infer_z_transfer(self, image, c, **kwargs): 270 | if self.vae_encoder is not None: 271 | mean_z, logstd_z = self.vae_encoder(image, c, **kwargs) 272 | else: 273 | batch_size = image.size(0) 274 | mean_z = torch.empty(batch_size, 0).to(self._device) 275 | return mean_z 276 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mesh2tex import common 5 | from mesh2tex.layers import ( 6 | ResnetBlockPointwise, 7 | EqualizedLR 8 | ) 9 | 10 | 11 | class DecoderEachLayerC(nn.Module): 12 | def __init__(self, c_dim=128, z_dim=128, dim=3, 13 | hidden_size=128, leaky=True, 14 | resnet_leaky=True, eq_lr=False): 15 | super().__init__() 16 | self.c_dim = c_dim 17 | self.eq_lr = eq_lr 18 | 19 | # Submodules 20 | if not leaky: 21 | self.actvn = F.relu 22 | else: 23 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 24 | 25 | if not resnet_leaky: 26 | self.resnet_actvn = F.relu 27 | else: 28 | self.resnet_actvn = lambda x: F.leaky_relu(x, 0.2) 29 | 30 | self.conv_p = nn.Conv1d(dim, hidden_size, 1) 31 | 32 | self.block0 = ResnetBlockPointwise( 33 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 34 | self.block1 = ResnetBlockPointwise( 35 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 36 | self.block2 = ResnetBlockPointwise( 37 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 38 | self.block3 = ResnetBlockPointwise( 39 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 40 | self.block4 = ResnetBlockPointwise( 41 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 42 | 43 | self.fc_cz_0 = nn.Linear(c_dim + z_dim, hidden_size) 44 | self.fc_cz_1 = nn.Linear(c_dim + z_dim, hidden_size) 45 | self.fc_cz_2 = nn.Linear(c_dim + z_dim, hidden_size) 46 | self.fc_cz_3 = nn.Linear(c_dim + z_dim, hidden_size) 47 | self.fc_cz_4 = nn.Linear(c_dim + z_dim, hidden_size) 48 | 49 | self.conv_out = nn.Conv1d(hidden_size, 3, 1) 50 | 51 | if self.eq_lr: 52 | self.conv_p = EqualizedLR(self.conv_p) 53 | self.conv_out = EqualizedLR(self.conv_out) 54 | self.fc_cz_0 = EqualizedLR(self.fc_cz_0) 55 | self.fc_cz_1 = EqualizedLR(self.fc_cz_1) 56 | self.fc_cz_2 = EqualizedLR(self.fc_cz_2) 57 | self.fc_cz_3 = EqualizedLR(self.fc_cz_3) 58 | self.fc_cz_4 = EqualizedLR(self.fc_cz_4) 59 | 60 | # Initialization 61 | nn.init.zeros_(self.conv_out.weight) 62 | 63 | def forward(self, p, geom_descr, z, **kwargs): 64 | c = geom_descr['global'] 65 | batch_size, D, T = p.size() 66 | 67 | cz = torch.cat([c, z], dim=1) 68 | net = self.conv_p(p) 69 | net = net + self.fc_cz_0(cz).unsqueeze(2) 70 | net = self.block0(net) 71 | net = net + self.fc_cz_1(cz).unsqueeze(2) 72 | net = self.block1(net) 73 | 
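# Conditioning note: each fc_cz_i maps the concatenated geometry code c and
# latent code z to the hidden width, and the result is added to the per-point
# features before the next pointwise ResNet block, so the condition is
# re-injected at every layer (hence the name 'each_layer_c'). A rough sketch of
# one such step, assuming net has shape (B, hidden_size, T):
#   cz = torch.cat([c, z], dim=1)               # (B, c_dim + z_dim)
#   net = net + self.fc_cz_i(cz).unsqueeze(2)   # broadcast over the T points
#   net = self.block_i(net)                     # (B, hidden_size, T)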
net = net + self.fc_cz_2(cz).unsqueeze(2) 74 | net = self.block2(net) 75 | net = net + self.fc_cz_3(cz).unsqueeze(2) 76 | net = self.block3(net) 77 | net = net + self.fc_cz_4(cz).unsqueeze(2) 78 | net = self.block4(net) 79 | 80 | out = self.conv_out(self.actvn(net)) 81 | out = torch.sigmoid(out) 82 | 83 | return out 84 | 85 | 86 | class DecoderEachLayerCLarger(nn.Module): 87 | def __init__(self, c_dim=128, z_dim=128, dim=3, 88 | hidden_size=128, leaky=True, 89 | resnet_leaky=True, eq_lr=False): 90 | super().__init__() 91 | self.c_dim = c_dim 92 | self.eq_lr = eq_lr 93 | if not leaky: 94 | self.actvn = F.relu 95 | else: 96 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 97 | 98 | if not resnet_leaky: 99 | self.resnet_actvn = F.relu 100 | else: 101 | self.resnet_actvn = lambda x: F.leaky_relu(x, 0.2) 102 | 103 | # Submodules 104 | self.conv_p = nn.Conv1d(dim, hidden_size, 1) 105 | 106 | self.block0 = ResnetBlockPointwise( 107 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 108 | self.block1 = ResnetBlockPointwise( 109 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 110 | self.block2 = ResnetBlockPointwise( 111 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 112 | self.block3 = ResnetBlockPointwise( 113 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 114 | self.block4 = ResnetBlockPointwise( 115 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 116 | self.block5 = ResnetBlockPointwise( 117 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 118 | self.block6 = ResnetBlockPointwise( 119 | hidden_size, actvn=self.resnet_actvn, eq_lr=eq_lr) 120 | 121 | self.fc_cz_0 = nn.Linear(c_dim + z_dim, hidden_size) 122 | self.fc_cz_1 = nn.Linear(c_dim + z_dim, hidden_size) 123 | self.fc_cz_2 = nn.Linear(c_dim + z_dim, hidden_size) 124 | self.fc_cz_3 = nn.Linear(c_dim + z_dim, hidden_size) 125 | self.fc_cz_4 = nn.Linear(c_dim + z_dim, hidden_size) 126 | self.fc_cz_5 = nn.Linear(c_dim + z_dim, hidden_size) 127 | self.fc_cz_6 = nn.Linear(c_dim + z_dim, hidden_size) 128 | 129 | self.conv_out = nn.Conv1d(hidden_size, 3, 1) 130 | 131 | if self.eq_lr: 132 | self.conv_p = EqualizedLR(self.conv_p) 133 | self.conv_out = EqualizedLR(self.conv_out) 134 | self.fc_cz_0 = EqualizedLR(self.fc_cz_0) 135 | self.fc_cz_1 = EqualizedLR(self.fc_cz_1) 136 | self.fc_cz_2 = EqualizedLR(self.fc_cz_2) 137 | self.fc_cz_3 = EqualizedLR(self.fc_cz_3) 138 | self.fc_cz_4 = EqualizedLR(self.fc_cz_4) 139 | self.fc_cz_5 = EqualizedLR(self.fc_cz_5) 140 | self.fc_cz_6 = EqualizedLR(self.fc_cz_6) 141 | 142 | # Initialization 143 | nn.init.zeros_(self.conv_out.weight) 144 | 145 | def forward(self, p, geom_descr, z, **kwargs): 146 | c = geom_descr['global'] 147 | batch_size, D, T = p.size() 148 | 149 | cz = torch.cat([c, z], dim=1) 150 | 151 | net = self.conv_p(p) 152 | net = net + self.fc_cz_0(cz).unsqueeze(2) 153 | net = self.block0(net) 154 | net = net + self.fc_cz_1(cz).unsqueeze(2) 155 | net = self.block1(net) 156 | net = net + self.fc_cz_2(cz).unsqueeze(2) 157 | net = self.block2(net) 158 | net = net + self.fc_cz_3(cz).unsqueeze(2) 159 | net = self.block3(net) 160 | net = net + self.fc_cz_4(cz).unsqueeze(2) 161 | net = self.block4(net) 162 | net = net + self.fc_cz_5(cz).unsqueeze(2) 163 | net = self.block5(net) 164 | net = net + self.fc_cz_6(cz).unsqueeze(2) 165 | net = self.block6(net) 166 | 167 | out = self.conv_out(self.actvn(net)) 168 | out = torch.sigmoid(out) 169 | 170 | return out 171 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/discriminator.py: 
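The conditional discriminator defined below scores a rendered view together with its depth map: the RGB image and the depth channel are concatenated into a 4-channel input before the first convolution (conv_img). A minimal usage sketch with random placeholder tensors, assuming 128 x 128 views; the geometry descriptor argument is accepted by forward but not used inside it:

import torch
from mesh2tex.texnet.models.discriminator import Resnet_Conditional

disc = Resnet_Conditional(geometry_encoder=None, img_size=128, nfilter=64)
img = torch.rand(8, 3, 128, 128)    # real or generated RGB views
depth = torch.rand(8, 1, 128, 128)  # matching depth maps
logits = disc(img, depth, None)     # concatenated to 4 channels internally
print(logits.shape)                 # torch.Size([8]), one logit per view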
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from mesh2tex.layers import ResnetBlockConv2d, EqualizedLR, pixel_norm 6 | 7 | 8 | class Resnet_Conditional(nn.Module): 9 | def __init__(self, geometry_encoder, img_size, c_dim=128, embed_size=256, 10 | nfilter=64, nfilter_max=1024, 11 | leaky=True, eq_lr=False, pixel_norm=False, 12 | factor=1.): 13 | super().__init__() 14 | self.embed_size = embed_size 15 | s0 = self.s0 = 4 16 | nf = self.nf = nfilter 17 | nf_max = self.nf_max = nfilter_max 18 | self.eq_lr = eq_lr 19 | self.use_pixel_norm = pixel_norm 20 | 21 | # Activation function 22 | if not leaky: 23 | self.actvn = F.relu 24 | else: 25 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 26 | 27 | # Submodules 28 | nlayers = int(np.log2(img_size / s0)) 29 | self.nf0 = min(nf_max, nf * 2**nlayers) 30 | 31 | blocks = [ 32 | ResnetBlockConv2d( 33 | nf, nf, actvn=self.actvn, 34 | eq_lr=eq_lr, 35 | factor=factor, 36 | pixel_norm=pixel_norm) 37 | ] 38 | 39 | for i in range(nlayers): 40 | nf0 = min(nf * 2**i, nf_max) 41 | nf1 = min(nf * 2**(i+1), nf_max) 42 | blocks += [ 43 | nn.AvgPool2d(3, stride=2, padding=1), 44 | ResnetBlockConv2d( 45 | nf0, nf1, actvn=self.actvn, eq_lr=eq_lr, 46 | factor=factor, 47 | pixel_norm=pixel_norm), 48 | ] 49 | 50 | self.conv_img = nn.Conv2d(4, 1*nf, 3, padding=1) 51 | self.resnet = nn.Sequential(*blocks) 52 | self.fc = nn.Linear(self.nf0*s0*s0, 1) 53 | 54 | if self.eq_lr: 55 | self.conv_img = EqualizedLR(self.conv_img) 56 | self.fc = EqualizedLR(self.fc) 57 | 58 | # Initialization 59 | nn.init.zeros_(self.fc.weight) 60 | 61 | def forward(self, x, depth, geom_descr): 62 | batch_size = x.size(0) 63 | 64 | depth = depth.clone() 65 | depth[depth == float("Inf")] = 0 66 | depth[depth == -1*float("Inf")] = 0 67 | 68 | x_and_depth = torch.cat([x, depth], dim=1) 69 | 70 | out = self.conv_img(x_and_depth) 71 | out = self.resnet(out) 72 | 73 | if self.use_pixel_norm: 74 | out = pixel_norm(out) 75 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 76 | out = self.fc(self.actvn(out)) 77 | out = out.squeeze() 78 | return out 79 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/image_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class Resnet18(nn.Module): 6 | ''' ResNet-18 conditioning network. 7 | ''' 8 | def __init__(self, c_dim=128, normalize=True, use_linear=True): 9 | ''' Initialization. 
10 | 11 | Args: 12 | c_dim (int): output dimension of the latent embedding 13 | normalize (bool): whether the input images should be normalized 14 | use_linear (bool): whether a final linear layer should be used 15 | ''' 16 | super().__init__() 17 | self.normalize = normalize 18 | self.use_linear = use_linear 19 | self.features = models.resnet18(pretrained=True) 20 | self.features.fc = nn.Sequential() 21 | if use_linear: 22 | self.fc = nn.Linear(512, c_dim) 23 | elif c_dim == 512: 24 | self.fc = nn.Sequential() 25 | else: 26 | raise ValueError('c_dim must be 512 if use_linear is False') 27 | 28 | def forward(self, x): 29 | if self.normalize: 30 | x = normalize_imagenet(x) 31 | net = self.features(x) 32 | out = self.fc(net) 33 | return out 34 | 35 | 36 | def normalize_imagenet(x): 37 | x = x.clone() 38 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 39 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 40 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 41 | return x 42 | -------------------------------------------------------------------------------- /mesh2tex/texnet/models/vae_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from mesh2tex.layers import ResnetBlockConv2d, EqualizedLR 6 | 7 | 8 | class Resnet(nn.Module): 9 | def __init__(self, img_size, z_dim=128, c_dim=128, embed_size=256, 10 | nfilter=32, nfilter_max=1024, leaky=True, eq_lr=False): 11 | super().__init__() 12 | self.embed_size = embed_size 13 | s0 = self.s0 = 4 14 | nf = self.nf = nfilter 15 | nf_max = self.nf_max = nfilter_max 16 | self.eq_lr = eq_lr 17 | self.c_dim = c_dim 18 | 19 | # Activation function 20 | if not leaky: 21 | self.actvn = F.relu 22 | else: 23 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 24 | 25 | # Submodules 26 | nlayers = int(np.log2(img_size / s0)) 27 | self.nf0 = min(nf_max, nf * 2**nlayers) 28 | 29 | blocks = [ 30 | ResnetBlockConv2d( 31 | nf, nf, actvn=self.actvn, eq_lr=eq_lr) 32 | ] 33 | 34 | for i in range(nlayers): 35 | nf0 = min(nf * 2**i, nf_max) 36 | nf1 = min(nf * 2**(i+1), nf_max) 37 | blocks += [ 38 | nn.AvgPool2d(3, stride=2, padding=1), 39 | ResnetBlockConv2d( 40 | nf0, nf1, actvn=self.actvn, eq_lr=eq_lr), 41 | ] 42 | 43 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1) 44 | self.resnet = nn.Sequential(*blocks) 45 | self.fc_mean = nn.Linear(self.nf0*s0*s0, z_dim) 46 | self.fc_logstd = nn.Linear(self.nf0*s0*s0, z_dim) 47 | self.fc_inject_c = nn.Linear(self.c_dim, 1*nf) 48 | if self.eq_lr: 49 | self.conv_img = EqualizedLR(self.conv_img) 50 | self.fc = EqualizedLR(self.fc) 51 | 52 | def forward(self, x, geom_descr): 53 | c = geom_descr['global'] 54 | batch_size = x.size(0) 55 | 56 | out = self.conv_img(x) 57 | add = self.fc_inject_c(c).view(out.size(0), self.nf, 1, 1) 58 | out = out + add 59 | out = self.resnet(out) 60 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 61 | 62 | mean = self.fc_mean(self.actvn(out)) 63 | logstd = self.fc_logstd(self.actvn(out)) 64 | return mean, logstd 65 | -------------------------------------------------------------------------------- /mesh2tex/texnet/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.autograd as autograd 8 | from torchvision.utils import save_image 9 | from mesh2tex import geometry 10 | from mesh2tex.training import BaseTrainer 
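# Note on the VAE path used by this trainer: TextureNetwork.infer_z wraps the
# (mean, logstd) produced by vae_encoder.Resnet in dist.Normal(mean, exp(logstd)),
# so q_z.rsample() in elbo() is the reparameterized sample
# z = mean + exp(logstd) * eps with eps ~ N(0, I), which keeps gradients flowing
# into the encoder. The resulting ELBO terms are weighted with w_vae below and
# can be combined with the adversarial loss via w_gan.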
11 | import mesh2tex.utils.FID.feature_l1 as feature_l1 12 | import mesh2tex.utils.SSIM_L1.ssim_l1_score as SSIM 13 | 14 | 15 | class Trainer(BaseTrainer): 16 | ''' 17 | Subclass of Basetrainer for defining train_step, eval_step and visualize 18 | ''' 19 | def __init__(self, model_g, model_d, 20 | optimizer_g, optimizer_d, 21 | ma_beta=0.99, gp_reg=10., 22 | w_pix=0., w_gan=0., w_vae=0., 23 | experiment='conditional', 24 | gan_type='standard', 25 | loss_type='L1', 26 | multi_gpu=False, 27 | **kwargs): 28 | 29 | # Initialize base trainer 30 | super().__init__(**kwargs) 31 | 32 | # Models and optimizers 33 | self.model_g = model_g 34 | self.model_d = model_d 35 | 36 | self.model_g_ma = copy.deepcopy(model_g) 37 | 38 | for p in self.model_g_ma.parameters(): 39 | p.requires_grad = False 40 | self.model_g_ma.eval() 41 | 42 | self.optimizer_g = optimizer_g 43 | self.optimizer_d = optimizer_d 44 | self.loss_type = loss_type 45 | self.experiment = experiment 46 | # Attributes 47 | self.gp_reg = gp_reg 48 | self.ma_beta = ma_beta 49 | self.gan_type = gan_type 50 | self.multi_gpu = multi_gpu 51 | self.w_pix = w_pix 52 | self.w_vae = w_vae 53 | self.w_gan = w_gan 54 | self.pix_loss = w_pix != 0 55 | self.vae_loss = w_vae != 0 56 | self.gan_loss = w_gan != 0 57 | if self.vae_loss and self.pix_loss: 58 | print('Not possible to combine pix and vae loss') 59 | # Checkpointer 60 | if self.gan_loss is True: 61 | self.checkpoint_io.register_modules( 62 | model_g=self.model_g, model_d=self.model_d, 63 | model_g_ma=self.model_g_ma, 64 | optimizer_g=self.optimizer_g, 65 | optimizer_d=self.optimizer_d, 66 | ) 67 | else: 68 | self.checkpoint_io.register_modules( 69 | model_g=self.model_g, model_d=self.model_d, 70 | model_g_ma=self.model_g_ma, 71 | optimizer_g=self.optimizer_g, 72 | ) 73 | 74 | print('w_pix: %f w_gan: %f w_vae: %f' 75 | % (self.w_pix, self.w_gan, self.w_vae)) 76 | 77 | def train_step(self, batch, epoch_it, it): 78 | ''' 79 | A single training step for the conditional or generative experiment 80 | Output: 81 | Losses 82 | ''' 83 | batch_model0, batch_model1 = batch 84 | if self.experiment == 'conditional': 85 | loss_con = self.train_step_cond(batch_model0) 86 | if self.gan_loss is True: 87 | loss_d = self.train_step_d(batch_model1) 88 | else: 89 | loss_d = 0 90 | 91 | losses = { 92 | 'loss_con': loss_con, 93 | 'loss_d': loss_d, 94 | } 95 | 96 | elif self.experiment == 'generative': 97 | loss_g = self.train_step_g(batch_model0) 98 | if self.gan_loss is True: 99 | loss_d = self.train_step_d(batch_model1) 100 | else: 101 | loss_d = 0 102 | losses = { 103 | 'loss_g': loss_g, 104 | 'loss_d': loss_d, 105 | } 106 | 107 | return losses 108 | 109 | def train_step_d(self, batch): 110 | ''' 111 | A single train step of the discriminator 112 | ''' 113 | model_d = self.model_d 114 | model_g = self.model_g 115 | 116 | model_d.train() 117 | model_g.train() 118 | 119 | if self.multi_gpu: 120 | model_d = nn.DataParallel(model_d) 121 | model_g = nn.DataParallel(model_g) 122 | 123 | self.optimizer_d.zero_grad() 124 | 125 | # Get data 126 | depth = batch['2d.depth'].to(self.device) 127 | img_real = batch['2d.img'].to(self.device) 128 | cam_K = batch['2d.camera_mat'].to(self.device) 129 | cam_W = batch['2d.world_mat'].to(self.device) 130 | mesh_repr = geometry.get_representation(batch, self.device) 131 | condition = batch['condition'].to(self.device) 132 | 133 | # Loss on real 134 | img_real.requires_grad_() 135 | d_real = model_d(img_real, depth, mesh_repr) 136 | 137 | dloss_real = self.compute_gan_loss(d_real, 1) 
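# The two backward passes that follow implement the discriminator update with
# an R1 gradient penalty: compute_grad2 (defined at the bottom of this file)
# returns ||grad_x D(x)||^2 per real image, so the regularizer added here is
# gp_reg * mean_x ||grad_x D(x)||^2, evaluated on real samples only.
# retain_graph=True on the real loss keeps the graph alive for this penalty.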
138 | dloss_real.backward(retain_graph=True) 139 | 140 | # R1 Regularizer 141 | reg = self.gp_reg * compute_grad2(d_real, img_real).mean() 142 | reg.backward() 143 | 144 | # Loss on fake 145 | with torch.no_grad(): 146 | if self.vae_loss is True: 147 | loss_vae, re, kl, img_fake = model_g.elbo(img_real, depth, 148 | cam_K, cam_W, 149 | mesh_repr) 150 | elif self.gan_loss is True: 151 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr, condition) 152 | 153 | d_fake = model_d(img_fake, depth, mesh_repr) 154 | 155 | dloss_fake = self.compute_gan_loss(d_fake, 0) 156 | dloss_fake.backward() 157 | 158 | # Gradient step 159 | self.optimizer_d.step() 160 | 161 | return self.w_gan * (dloss_fake.item() + dloss_real.item()) 162 | 163 | def train_step_g(self, batch): 164 | ''' 165 | A single train step of the generator part of generative model 166 | (VAE: Encoder+Decoder and GAN: Generator) 167 | ''' 168 | model_d = self.model_d 169 | model_g = self.model_g 170 | 171 | model_d.train() 172 | model_g.train() 173 | 174 | if self.multi_gpu: 175 | model_d = nn.DataParallel(model_d) 176 | model_g = nn.DataParallel(model_g) 177 | 178 | self.optimizer_g.zero_grad() 179 | 180 | # Get data 181 | depth = batch['2d.depth'].to(self.device) 182 | img_real = batch['2d.img'].to(self.device) 183 | cam_K = batch['2d.camera_mat'].to(self.device) 184 | cam_W = batch['2d.world_mat'].to(self.device) 185 | mesh_repr = geometry.get_representation(batch, self.device) 186 | 187 | # Loss on fake 188 | loss_vae = 0 189 | loss_gan = 0 190 | 191 | # Forward part and loss derivation for given experiment 192 | if self.vae_loss is True: 193 | loss_vae, re, kl, img_fake = model_g.elbo(img_real, depth, 194 | cam_K, cam_W, 195 | mesh_repr) 196 | if self.gan_loss is True: 197 | d_fake = model_d(img_fake, depth, mesh_repr) 198 | loss_gan = self.compute_gan_loss(d_fake, 1) 199 | elif self.gan_loss is True: 200 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr) 201 | d_fake = model_d(img_fake, depth, mesh_repr) 202 | loss_gan = self.compute_gan_loss(d_fake, 1) 203 | 204 | # weighting of losses (128*128*3=49152) 205 | loss = self.w_vae * loss_vae + self.w_gan * loss_gan 206 | loss.backward() 207 | 208 | # Gradient step 209 | self.optimizer_g.step() 210 | 211 | # Update moving average 212 | # self.update_moving_average() 213 | 214 | return loss.item() 215 | 216 | def train_step_cond(self, batch): 217 | ''' 218 | A single train step of the conditional model 219 | with or w/o generator part of adversarial loss 220 | ''' 221 | model_d = self.model_d 222 | model_g = self.model_g 223 | 224 | model_d.train() 225 | model_g.train() 226 | 227 | if self.multi_gpu: 228 | model_d = nn.DataParallel(model_d) 229 | model_g = nn.DataParallel(model_g) 230 | 231 | self.optimizer_g.zero_grad() 232 | 233 | # Get data 234 | depth = batch['2d.depth'].to(self.device) 235 | img_real = batch['2d.img'].to(self.device) 236 | cam_K = batch['2d.camera_mat'].to(self.device) 237 | cam_W = batch['2d.world_mat'].to(self.device) 238 | mesh_repr = geometry.get_representation(batch, self.device) 239 | condition = batch['condition'].to(self.device) 240 | 241 | # Loss on fake 242 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr, condition) 243 | loss_pix = 0 244 | loss_gan = 0 245 | 246 | if self.pix_loss is True: 247 | loss_pix = self.compute_loss(img_fake, img_real) 248 | if self.gan_loss is True: 249 | d_fake = model_d(img_fake, depth, mesh_repr) 250 | loss_gan = self.compute_gan_loss(d_fake, 1) 251 | 252 | # weighting 253 | loss = self.w_pix * loss_pix + self.w_gan * 
loss_gan 254 | loss.backward() 255 | 256 | # Gradient step 257 | self.optimizer_g.step() 258 | 259 | # Update moving average 260 | # self.update_moving_average() 261 | 262 | return loss.item() 263 | 264 | def compute_loss(self, img_fake, img_real): 265 | ''' 266 | Compute Pixelloss 267 | ''' 268 | if self.loss_type == 'L2': 269 | loss = F.mse_loss(img_fake, img_real) 270 | elif self.loss_type == 'L1': 271 | loss = F.l1_loss(img_fake, img_real) 272 | else: 273 | raise NotImplementedError 274 | 275 | return loss 276 | 277 | def compute_gan_loss(self, d_out, target): 278 | ''' 279 | Compute GAN loss (standart cross entropy or wasserstein distance) 280 | !!! Without Regularizer 281 | ''' 282 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 283 | 284 | if self.gan_type == 'standard': 285 | loss = F.binary_cross_entropy_with_logits(d_out, targets) 286 | elif self.gan_type == 'wgan': 287 | loss = (2*target - 1) * d_out.mean() 288 | else: 289 | raise NotImplementedError 290 | 291 | return loss 292 | 293 | def eval_step(self, batch): 294 | ''' 295 | Evaluation step with L1, SSIM, featl1 metrics 296 | ''' 297 | depth = batch['2d.depth'].to(self.device) 298 | img_real = batch['2d.img'].to(self.device) 299 | cam_K = batch['2d.camera_mat'].to(self.device) 300 | cam_W = batch['2d.world_mat'].to(self.device) 301 | mesh_repr = geometry.get_representation(batch, self.device) 302 | condition = batch['condition'].to(self.device) 303 | 304 | # Get model 305 | model_g = self.model_g 306 | model_g.eval() 307 | 308 | if self.multi_gpu: 309 | model_g = nn.DataParallel(model_g) 310 | 311 | # Predict 312 | with torch.no_grad(): 313 | img_fake = model_g(depth, cam_K, cam_W, mesh_repr, condition) 314 | 315 | # Derive metrics 316 | loss_val = self.compute_loss(img_fake, img_real) 317 | ssim, l1 = SSIM.calculate_ssim_l1_given_tensor(img_fake, img_real) 318 | featl1 = feature_l1.calculate_feature_l1_given_tensors( 319 | img_fake, img_real, img_real.size(0), True, 2048) 320 | 321 | loss_val_dict = {'loss_val': loss_val.item(), 'SSIM': ssim, 322 | 'featl1': featl1} 323 | return loss_val_dict 324 | 325 | def visualize(self, batch): 326 | ''' 327 | Visualization step 328 | ''' 329 | depth = batch['2d.depth'].to(self.device) 330 | img_real = batch['2d.img'].to(self.device) 331 | cam_K = batch['2d.camera_mat'].to(self.device) 332 | cam_W = batch['2d.world_mat'].to(self.device) 333 | mesh_repr = geometry.get_representation(batch, self.device) 334 | condition = batch['condition'].to(self.device) 335 | 336 | # determine constants 337 | batch_size = depth.size(0) 338 | num_views = depth.size(1) 339 | assert(num_views == 5) 340 | 341 | mesh_points = mesh_repr['points'] 342 | mesh_normals = mesh_repr['normals'] 343 | 344 | # batch loop 345 | for j in range(batch_size): 346 | 347 | # Gather input information on shape, depth, camera and condition 348 | geom_repr = { 349 | 'points': mesh_points[j].expand(num_views, 350 | mesh_points.size(1), 351 | mesh_points.size(2)), 352 | 'normals': mesh_normals[j].expand(num_views, 353 | mesh_normals.size(1), 354 | mesh_normals.size(2)), 355 | } 356 | 357 | depth_ = depth[j] 358 | img_real_ = img_real[j] 359 | if len(condition.size()) == 1: 360 | condition_ = condition[j].expand(num_views) 361 | else: 362 | condition_ = condition[j].expand(num_views, 363 | condition.size(1), 364 | condition.size(2), 365 | condition.size(3)) 366 | cam_K_ = cam_K[j] 367 | cam_W_ = cam_W[j] 368 | 369 | # save real images 370 | save_image(img_real_, 371 | os.path.join(self.vis_dir, 'real_%i.png' % j)) 
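# save_image receives a 4D tensor (num_views x 3 x H x W), so torchvision
# writes all five views of shape j into a single image grid per file.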
372 | 373 | # predict fake images and save 374 | self.model_g.eval() 375 | with torch.no_grad(): 376 | img_fake = self.model_g(depth_, cam_K_, cam_W_, 377 | geom_repr, condition_) 378 | save_image(img_fake.cpu(), 379 | os.path.join(self.vis_dir, 'fake_%i.png' % j)) 380 | 381 | # save condition images 382 | if len(condition.size()) != 1: 383 | save_image(condition[j].cpu(), 384 | os.path.join(self.vis_dir, 'condition_%i.png' % j)) 385 | else: 386 | np.savetxt(os.path.join(self.vis_dir, 'condition_%i.txt' % j), 387 | condition_.cpu()) 388 | 389 | def update_moving_average(self): 390 | ''' 391 | Update moving average 392 | ''' 393 | param_dict_src = dict(self.model_g.named_parameters()) 394 | beta = self.ma_beta 395 | for p_name, p_tgt in self.model_g_ma.named_parameters(): 396 | p_src = param_dict_src[p_name] 397 | assert(p_src is not p_tgt) 398 | with torch.no_grad(): 399 | p_ma = beta * p_tgt + (1. - beta) * p_src 400 | p_tgt.copy_(p_ma) 401 | 402 | 403 | def compute_grad2(d_out, x_in): 404 | ''' 405 | Derive L2-Gradient penalty for regularizing the GAN 406 | ''' 407 | batch_size = x_in.size(0) 408 | grad_dout = autograd.grad( 409 | outputs=d_out.sum(), inputs=x_in, 410 | create_graph=True, retain_graph=True, only_inputs=True 411 | )[0] 412 | grad_dout2 = grad_dout.pow(2) 413 | assert(grad_dout2.size() == x_in.size()) 414 | reg = grad_dout2.view(batch_size, -1).sum(1) 415 | return reg 416 | -------------------------------------------------------------------------------- /mesh2tex/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import time 4 | import logging 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | from tqdm import tqdm 8 | from mesh2tex.checkpoints import CheckpointIO 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | 13 | class BaseTrainer(object): 14 | def __init__(self, 15 | out_dir, 16 | model_selection_metric, model_selection_mode, 17 | print_every, visualize_every, checkpoint_every, 18 | backup_every, validate_every, device=None, model_url=None): 19 | # Directories 20 | self.out_dir = out_dir 21 | self.vis_dir = os.path.join(out_dir, 'vis') 22 | self.log_dir = os.path.join(out_dir, 'log') 23 | self.model_url = model_url 24 | 25 | self.model_selection_metric = model_selection_metric 26 | if model_selection_mode == 'maximize': 27 | self.model_selection_sign = 1 28 | elif model_selection_mode == 'minimize': 29 | self.model_selection_sign = -1 30 | else: 31 | raise ValueError('model_selection_mode must be ' 32 | 'either maximize or minimize.') 33 | 34 | self.print_every = print_every 35 | self.visualize_every = visualize_every 36 | self.checkpoint_every = checkpoint_every 37 | self.backup_every = backup_every 38 | self.validate_every = validate_every 39 | self.device = device 40 | 41 | # Checkpointer 42 | self.checkpoint_io = CheckpointIO(out_dir) 43 | 44 | # Create directories 45 | all_dirs = [self.out_dir, self.vis_dir, self.log_dir] 46 | for directory in all_dirs: 47 | if not os.path.exists(directory): 48 | os.makedirs(directory) 49 | 50 | def train(self, train_loader, val_loader, vis_loader, 51 | exit_after=None, n_epochs=None): 52 | """ 53 | Main training method with epoch loop, validation and model selection 54 | 55 | args: 56 | train_loader 57 | val_loader (Validation) 58 | vis_loader (Visualsation during training) 59 | """ 60 | # Load if checkpoint exist 61 | epoch_it, it, metric_val_best = self.init_training() 62 | print('Current best 
validation metric (%s): %.8f' 63 | % (self.model_selection_metric, metric_val_best)) 64 | 65 | # for tensorboard 66 | summary_writer = SummaryWriter(os.path.join(self.log_dir)) 67 | 68 | if self.visualize_every > 0: 69 | if vis_loader is None: 70 | data_vis = next(iter(val_loader)) 71 | else: 72 | data_vis = next(iter(vis_loader)) 73 | 74 | # Main training loop 75 | t0 = time.time() 76 | while (n_epochs is None) or (epoch_it < n_epochs): 77 | epoch_it += 1 78 | 79 | for batch in train_loader: 80 | it += 1 81 | 82 | losses = self.train_step(batch, epoch_it=epoch_it, it=it) 83 | 84 | if isinstance(losses, dict): 85 | loss_str = [] 86 | for k, v in losses.items(): 87 | summary_writer.add_scalar('train/%s' % k, v, it) 88 | loss_str.append('%s=%.4f' % (k, v)) 89 | loss_str = ' '.join(loss_str) 90 | else: 91 | summary_writer.add_scalar('train/loss', losses, it) 92 | loss_str = ('loss=%.4f' % losses) 93 | 94 | # Print output 95 | if self.print_every > 0 and (it % self.print_every) == 0: 96 | print('[Epoch %02d] it=%03d, %s' 97 | % (epoch_it, it, loss_str)) 98 | 99 | # Visualize output 100 | if (self.visualize_every > 0 101 | and (it % self.visualize_every) == 0): 102 | print('Visualizing') 103 | try: 104 | self.visualize(data_vis) 105 | except NotImplementedError: 106 | LOGGER.warn('Visualizing method not implemented.') 107 | 108 | # Save checkpoint 109 | if (self.checkpoint_every > 0 110 | and (it % self.checkpoint_every) == 0): 111 | print('Saving checkpoint') 112 | self.checkpoint_io.save( 113 | 'model.pt', epoch_it=epoch_it, it=it, 114 | loss_val_best=metric_val_best) 115 | 116 | # Backup if necessary 117 | if (self.backup_every > 0 118 | and (it % self.backup_every) == 0): 119 | print('Backup checkpoint') 120 | self.checkpoint_io.save( 121 | 'model_%d.pt' % it, epoch_it=epoch_it, it=it, 122 | loss_val_best=metric_val_best) 123 | 124 | # Run validation and select if better 125 | if self.validate_every > 0 and (it % self.validate_every) == 0: 126 | try: 127 | eval_dict = self.evaluate(val_loader) 128 | print(eval_dict) 129 | except NotImplementedError: 130 | LOGGER.warn('Evaluation method not implemented.') 131 | eval_dict = {} 132 | 133 | for k, v in eval_dict.items(): 134 | summary_writer.add_scalar('val/%s' % k, v, it) 135 | 136 | if self.model_selection_metric is not None: 137 | metric_val = eval_dict[self.model_selection_metric] 138 | print( 139 | 'Validation metric (%s): %.4f' 140 | % (self.model_selection_metric, metric_val)) 141 | 142 | improvement = ( 143 | self.model_selection_sign 144 | * (metric_val - metric_val_best) 145 | ) 146 | if improvement > 0: 147 | metric_val_best = metric_val 148 | print('New best model (loss %.4f)' 149 | % metric_val_best) 150 | self.checkpoint_io.save( 151 | 'model_best.pt', epoch_it=epoch_it, it=it, 152 | loss_val_best=metric_val_best) 153 | 154 | # Exit if necessary 155 | if exit_after > 0 and (time.time() - t0) >= exit_after: 156 | print('Time limit reached. Exiting.') 157 | self.checkpoint_io.save( 158 | 'model.pt', epoch_it=epoch_it, it=it, 159 | loss_val_best=metric_val_best) 160 | exit(3) 161 | 162 | print('Maximum number of epochs reached. 
Exiting.') 163 | self.checkpoint_io.save( 164 | 'model.pt', epoch_it=epoch_it, it=it, 165 | loss_val_best=metric_val_best) 166 | 167 | def evaluate(self, val_loader): 168 | ''' 169 | Evaluate model with validation data using eval_step 170 | 171 | args: 172 | data loader 173 | ''' 174 | 175 | eval_list = defaultdict(list) 176 | 177 | for data in tqdm(val_loader): 178 | eval_step_dict = self.eval_step(data) 179 | 180 | for k, v in eval_step_dict.items(): 181 | eval_list[k].append(v) 182 | 183 | eval_dict = {k: np.mean(v) for k, v in eval_list.items()} 184 | return eval_dict 185 | 186 | def init_training(self): 187 | ''' 188 | Init training by loading the latest checkpoint 189 | ''' 190 | try: 191 | if self.model_url is not None: 192 | load_dict = self.checkpoint_io.load(self.model_url) 193 | else: 194 | load_dict = self.checkpoint_io.load('model.pt') 195 | except FileExistsError: 196 | load_dict = dict() 197 | epoch_it = load_dict.get('epoch_it', -1) 198 | it = load_dict.get('it', -1) 199 | metric_val_best = load_dict.get( 200 | 'loss_val_best', -self.model_selection_sign * np.inf) 201 | 202 | return epoch_it, it, metric_val_best 203 | 204 | def train_step(self, *args, **kwargs): 205 | raise NotImplementedError 206 | 207 | def eval_step(self, *args, **kwargs): 208 | raise NotImplementedError 209 | 210 | def visualize(self, *args, **kwargs): 211 | raise NotImplementedError 212 | -------------------------------------------------------------------------------- /mesh2tex/utils/FID/feature_l1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | import numpy as np 5 | from imageio import imread 6 | from torch.autograd import Variable 7 | from torch.nn.functional import adaptive_avg_pool2d 8 | 9 | from mesh2tex.utils.FID.inception import InceptionV3 10 | 11 | 12 | def get_activations(images, model, batch_size=64, dims=2048, 13 | cuda=False, verbose=False): 14 | """Calculates the activations of the pool_3 layer for all images. 15 | 16 | Params: 17 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 18 | must lie between 0 and 1. 19 | -- model : Instance of inception model 20 | -- batch_size : the images numpy array is split into batches with 21 | batch size batch_size. A reasonable batch size depends 22 | on the hardware. 23 | -- dims : Dimensionality of features returned by Inception 24 | -- cuda : If set to True, use GPU 25 | -- verbose : If set to True and parameter out_step is given, the number 26 | of calculated batches is reported. 27 | Returns: 28 | -- A numpy array of dimension (num images, dims) that contains the 29 | activations of the given tensor when feeding inception with the 30 | query tensor. 31 | """ 32 | model.eval() 33 | 34 | d0 = images.shape[0] 35 | if batch_size > d0: 36 | print(('Warning: batch size is bigger than the data size. 
' 37 | 'Setting batch size to data size')) 38 | batch_size = d0 39 | 40 | n_batches = d0 // batch_size 41 | n_used_imgs = n_batches * batch_size 42 | 43 | pred_arr = np.empty((n_used_imgs, dims)) 44 | for i in range(n_batches): 45 | if verbose: 46 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 47 | end='', flush=True) 48 | start = i * batch_size 49 | end = start + batch_size 50 | 51 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 52 | batch = Variable(batch, volatile=True) 53 | if cuda: 54 | batch = batch.cuda() 55 | 56 | pred = model(batch)[0] 57 | 58 | # If model output is not scalar, apply global spatial average pooling. 59 | # This happens if you choose a dimensionality not equal 2048. 60 | if pred.shape[2] != 1 or pred.shape[3] != 1: 61 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 62 | 63 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 64 | 65 | if verbose: 66 | print(' done') 67 | 68 | return pred_arr 69 | 70 | 71 | def _compute_statistics_of_path(path0, path1, model, batch_size, dims, cuda): 72 | path0 = pathlib.Path(path0) 73 | path1 = pathlib.Path(path1) 74 | 75 | files_list = os.listdir(path0) 76 | files0 = [os.path.join(path0, f) for f in files_list] 77 | files1 = [os.path.join(path1, f) for f in files_list] 78 | assert(len(files0) == len(files1)) 79 | 80 | # First set of images 81 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files0]) 82 | # Bring images to shape (B, 3, H, W) 83 | imgs = imgs.transpose((0, 3, 1, 2))[:, :3] 84 | # Rescale images to be between 0 and 1 85 | imgs /= 255 86 | feat0 = get_activations(imgs, model, batch_size, dims, cuda, False) 87 | 88 | # Second set of images 89 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files1]) 90 | # Bring images to shape (B, 3, H, W) 91 | imgs = imgs.transpose((0, 3, 1, 2))[:, :3] 92 | # Rescale images to be between 0 and 1 93 | imgs /= 255 94 | feat1 = get_activations(imgs, model, batch_size, dims, cuda, False) 95 | 96 | feature_l1 = np.mean(np.abs(feat0 - feat1)) 97 | return feature_l1 98 | 99 | 100 | def _compute_statistics_of_tensors(images_fake, images_real, model, batch_size, 101 | dims, cuda): 102 | 103 | # First set of images 104 | imgs = images_fake.cpu().numpy().astype(np.float32) 105 | feat0 = get_activations(imgs, model, batch_size, dims, cuda, False) 106 | 107 | # Second set of images 108 | imgs = images_real.cpu().numpy().astype(np.float32) 109 | feat1 = get_activations(imgs, model, batch_size, dims, cuda, False) 110 | 111 | feature_l1 = np.mean(np.abs(feat0 - feat1)) 112 | return feature_l1 113 | 114 | 115 | def calculate_feature_l1_given_paths(paths, batch_size, cuda, dims): 116 | """Calculates the FID of two paths""" 117 | for p in paths: 118 | if not os.path.exists(p): 119 | raise RuntimeError('Invalid path: %s' % p) 120 | 121 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 122 | 123 | model = InceptionV3([block_idx]) 124 | if cuda: 125 | model.cuda() 126 | 127 | feature_l1 = _compute_statistics_of_path( 128 | paths[0], paths[1], model, batch_size, dims, cuda) 129 | 130 | return feature_l1 131 | 132 | 133 | def calculate_feature_l1_given_tensors(tensor1, tensor2, 134 | batch_size, cuda, dims): 135 | """Calculates the FID of two image tensors""" 136 | 137 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 138 | 139 | model = InceptionV3([block_idx]) 140 | if cuda: 141 | model.cuda() 142 | 143 | feature_l1 = _compute_statistics_of_tensors( 144 | tensor1, tensor2, model, batch_size, dims, cuda) 145 | 146 | 
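# Example call with hypothetical tensors, mirroring eval_step in
# mesh2tex/texnet/training.py:
#   fake = torch.rand(10, 3, 128, 128)
#   real = torch.rand(10, 3, 128, 128)
#   score = calculate_feature_l1_given_tensors(fake, real, fake.size(0),
#                                              cuda=True, dims=2048)
# The score is the mean absolute difference between the InceptionV3 pool3
# activations of the two image sets (lower means more similar).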
return feature_l1 147 | -------------------------------------------------------------------------------- /mesh2tex/utils/FID/fid_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 4 | 5 | import torch 6 | import numpy as np 7 | from scipy.misc import imread 8 | from scipy import linalg 9 | from torch.autograd import Variable 10 | from torch.nn.functional import adaptive_avg_pool2d 11 | 12 | from mesh2tex.utils.FID.inception import InceptionV3 13 | 14 | 15 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 16 | parser.add_argument('path', type=str, nargs=2, 17 | help=('Path to the generated images or ' 18 | 'to .npz statistic files')) 19 | parser.add_argument('--batch-size', type=int, default=64, 20 | help='Batch size to use') 21 | parser.add_argument('--dims', type=int, default=2048, 22 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 23 | help=('Dimensionality of Inception features to use. ' 24 | 'By default, uses pool3 features')) 25 | parser.add_argument('-c', '--gpu', default='', type=str, 26 | help='GPU to use (leave blank for CPU only)') 27 | 28 | 29 | def get_activations(images, model, batch_size=64, dims=2048, 30 | cuda=False, verbose=False): 31 | """Calculates the activations of the pool_3 layer for all images. 32 | 33 | Params: 34 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 35 | must lie between 0 and 1. 36 | -- model : Instance of inception model 37 | -- batch_size : the images numpy array is split into batches with 38 | batch size batch_size. A reasonable batch size depends 39 | on the hardware. 40 | -- dims : Dimensionality of features returned by Inception 41 | -- cuda : If set to True, use GPU 42 | -- verbose : If set to True and parameter out_step is given, the number 43 | of calculated batches is reported. 44 | Returns: 45 | -- A numpy array of dimension (num images, dims) that contains the 46 | activations of the given tensor when feeding inception with the 47 | query tensor. 48 | """ 49 | model.eval() 50 | 51 | d0 = images.shape[0] 52 | if batch_size > d0: 53 | print(('Warning: batch size is bigger than the data size. ' 54 | 'Setting batch size to data size')) 55 | batch_size = d0 56 | 57 | n_batches = d0 // batch_size 58 | n_used_imgs = n_batches * batch_size 59 | 60 | pred_arr = np.empty((n_used_imgs, dims)) 61 | for i in range(n_batches): 62 | if verbose: 63 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 64 | end='', flush=True) 65 | start = i * batch_size 66 | end = start + batch_size 67 | 68 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 69 | batch = Variable(batch, volatile=True) 70 | if cuda: 71 | batch = batch.cuda() 72 | 73 | pred = model(batch)[0] 74 | 75 | # If model output is not scalar, apply global spatial average pooling. 76 | # This happens if you choose a dimensionality not equal 2048. 77 | if pred.shape[2] != 1 or pred.shape[3] != 1: 78 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 79 | 80 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 81 | 82 | if verbose: 83 | print(' done') 84 | 85 | return pred_arr 86 | 87 | 88 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 89 | """Numpy implementation of the Frechet Distance. 
90 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 91 | and X_2 ~ N(mu_2, C_2) is 92 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 93 | 94 | Stable version by Dougal J. Sutherland. 95 | 96 | Params: 97 | -- mu1 : Numpy array containing the activations of a layer of the 98 | inception net (like returned by the function 'get_predictions') 99 | for generated samples. 100 | -- mu2 : The sample mean over activations, precalculated on an 101 | representative data set. 102 | -- sigma1: The covariance matrix over activations for generated samples. 103 | -- sigma2: The covariance matrix over activations, precalculated on an 104 | representative data set. 105 | 106 | Returns: 107 | -- : The Frechet Distance. 108 | """ 109 | 110 | mu1 = np.atleast_1d(mu1) 111 | mu2 = np.atleast_1d(mu2) 112 | 113 | sigma1 = np.atleast_2d(sigma1) 114 | sigma2 = np.atleast_2d(sigma2) 115 | 116 | assert mu1.shape == mu2.shape, \ 117 | 'Training and test mean vectors have different lengths' 118 | assert sigma1.shape == sigma2.shape, \ 119 | 'Training and test covariances have different dimensions' 120 | 121 | diff = mu1 - mu2 122 | 123 | # Product might be almost singular 124 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 125 | if not np.isfinite(covmean).all(): 126 | msg = ('fid calculation produces singular product; ' 127 | 'adding %s to diagonal of cov estimates') % eps 128 | print(msg) 129 | offset = np.eye(sigma1.shape[0]) * eps 130 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 131 | 132 | # Numerical error might give slight imaginary component 133 | if np.iscomplexobj(covmean): 134 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 135 | m = np.max(np.abs(covmean.imag)) 136 | raise ValueError('Imaginary component {}'.format(m)) 137 | covmean = covmean.real 138 | 139 | tr_covmean = np.trace(covmean) 140 | 141 | return (diff.dot(diff) + np.trace(sigma1) + 142 | np.trace(sigma2) - 2 * tr_covmean) 143 | 144 | 145 | def calculate_activation_statistics(images, model, batch_size=64, 146 | dims=2048, cuda=False, verbose=False): 147 | """Calculation of the statistics used by the FID. 148 | Params: 149 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 150 | must lie between 0 and 1. 151 | -- model : Instance of inception model 152 | -- batch_size : The images numpy array is split into batches with 153 | batch size batch_size. A reasonable batch size 154 | depends on the hardware. 155 | -- dims : Dimensionality of features returned by Inception 156 | -- cuda : If set to True, use GPU 157 | -- verbose : If set to True and parameter out_step is given, the 158 | number of calculated batches is reported. 159 | Returns: 160 | -- mu : The mean over samples of the activations of the pool_3 layer of 161 | the inception model. 162 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 163 | the inception model. 
164 | """ 165 | act = get_activations(images, model, batch_size, dims, cuda, verbose) 166 | mu = np.mean(act, axis=0) 167 | sigma = np.cov(act, rowvar=False) 168 | return mu, sigma 169 | 170 | 171 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 172 | if path.endswith('.npz'): 173 | f = np.load(path) 174 | m, s = f['mu'][:], f['sigma'][:] 175 | f.close() 176 | else: 177 | path = pathlib.Path(path) 178 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 179 | 180 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 181 | 182 | # Bring images to shape (B, 3, H, W) 183 | imgs = imgs.transpose((0, 3, 1, 2))[:, :3] 184 | 185 | # Rescale images to be between 0 and 1 186 | imgs /= 255 187 | 188 | m, s = calculate_activation_statistics(imgs, model, batch_size, 189 | dims, cuda) 190 | 191 | return m, s 192 | 193 | 194 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 195 | """Calculates the FID of two paths""" 196 | for p in paths: 197 | if not os.path.exists(p): 198 | raise RuntimeError('Invalid path: %s' % p) 199 | 200 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 201 | 202 | model = InceptionV3([block_idx]) 203 | if cuda: 204 | model.cuda() 205 | 206 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 207 | dims, cuda) 208 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 209 | dims, cuda) 210 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 211 | 212 | return fid_value 213 | 214 | 215 | if __name__ == '__main__': 216 | args = parser.parse_args() 217 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 218 | 219 | fid_value = calculate_fid_given_paths(args.path, 220 | args.batch_size, 221 | args.gpu != '', 222 | args.dims) 223 | print('FID: ', fid_value) 224 | -------------------------------------------------------------------------------- /mesh2tex/utils/FID/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. 
As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, normalizes the input to the statistics the pretrained 43 | Inception network expects 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | nn.MaxPool2d(kernel_size=3, stride=2) 68 | ] 69 | self.blocks.append(nn.Sequential(*block0)) 70 | 71 | # Block 1: maxpool1 to maxpool2 72 | if self.last_needed_block >= 1: 73 | block1 = [ 74 | inception.Conv2d_3b_1x1, 75 | inception.Conv2d_4a_3x3, 76 | nn.MaxPool2d(kernel_size=3, stride=2) 77 | ] 78 | self.blocks.append(nn.Sequential(*block1)) 79 | 80 | # Block 2: maxpool2 to aux classifier 81 | if self.last_needed_block >= 2: 82 | block2 = [ 83 | inception.Mixed_5b, 84 | inception.Mixed_5c, 85 | inception.Mixed_5d, 86 | inception.Mixed_6a, 87 | inception.Mixed_6b, 88 | inception.Mixed_6c, 89 | inception.Mixed_6d, 90 | inception.Mixed_6e, 91 | ] 92 | self.blocks.append(nn.Sequential(*block2)) 93 | 94 | # Block 3: aux classifier to final avgpool 95 | if self.last_needed_block >= 3: 96 | block3 = [ 97 | inception.Mixed_7a, 98 | inception.Mixed_7b, 99 | inception.Mixed_7c, 100 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 101 | ] 102 | self.blocks.append(nn.Sequential(*block3)) 103 | 104 | for param in self.parameters(): 105 | param.requires_grad = requires_grad 106 | 107 | def forward(self, inp): 108 | """Get Inception feature maps 109 | 110 | Parameters 111 | ---------- 112 | inp : torch.autograd.Variable 113 | Input tensor of shape Bx3xHxW. 
--------------------------------------------------------------------------------
/mesh2tex/utils/SSIM_L1/ssim_l1_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.measure import compare_ssim as ssim
3 | import imageio
4 | import os
5 | 
6 | 
7 | def calculate_ssim_l1_given_paths(paths):
8 |     file_list = os.listdir(paths[0])
9 |     ssim_value = 0
10 |     l1_value = 0
11 |     for f in file_list:
12 |         # assert(i[0] == i[1])
13 |         fake = load_img(paths[0] + f)
14 |         real = load_img(paths[1] + f)
15 |         ssim_value += np.mean(
16 |             ssim(fake, real, multichannel=True))
17 |         l1_value += np.mean(abs(fake - real))
18 | 
19 |     ssim_value = ssim_value/float(len(file_list))
20 |     l1_value = l1_value/float(len(file_list))
21 | 
22 |     return ssim_value, l1_value
23 | 
24 | 
25 | def calculate_ssim_l1_given_tensor(images_fake, images_real):
26 |     bs = images_fake.size(0)
27 |     images_fake = images_fake.permute(0, 2, 3, 1).cpu().numpy()
28 |     images_real = images_real.permute(0, 2, 3, 1).cpu().numpy()
29 | 
30 |     ssim_value = 0
31 |     l1_value = 0
32 |     for i in range(bs):
33 |         # assert(i[0] == i[1])
34 |         fake = images_fake[i]
35 |         real = images_real[i]
36 |         ssim_value += np.mean(
37 |             ssim(fake, real, multichannel=True))
38 |         l1_value += np.mean(abs(fake - real))
39 |     ssim_value = ssim_value/float(bs)
40 |     l1_value = l1_value/float(bs)
41 | 
42 |     return ssim_value, l1_value
43 | 
44 | 
45 | def load_img(path):
46 |     img = imageio.imread(path)
47 |     img = img.astype(np.float64) / 255
48 |     if img.ndim == 2:
49 |         img = np.stack([img, img, img], axis=-1)
50 |     elif img.shape[2] == 1:
51 |         img = np.concatenate([img, img, img], axis=-1)
52 |     elif img.shape[2] == 4:
53 |         img = img[:, :, :3]
54 | 
55 |     return img
56 | 
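Usage note (editorial, not part of the repository sources): a short sketch of the two helpers above. calculate_ssim_l1_given_paths joins directory and file name by plain string concatenation (paths[0] + f), so the placeholder directories below need a trailing slash and must contain identically named images:

    # Hypothetical directory names; files are matched by name across the folders.
    ssim_val, l1_val = calculate_ssim_l1_given_paths(
        ['path/to/fake_renders/', 'path/to/real_views/'])
    print('SSIM: %.4f, L1: %.4f' % (ssim_val, l1_val))

    # Or directly on (B, 3, H, W) torch tensors with values in [0, 1]:
    # ssim_val, l1_val = calculate_ssim_l1_given_tensor(images_fake, images_real)
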
--------------------------------------------------------------------------------
/mesh2tex/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/texture_fields/fe92e8dec3e6285259c4b61ec0167f52a7669ed0/mesh2tex/utils/__init__.py
--------------------------------------------------------------------------------
/mesh2tex/utils/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | from plyfile import PlyElement, PlyData
3 | import numpy as np
4 | 
5 | 
6 | def export_pointcloud(vertices, out_file, as_text=True):
7 |     assert(vertices.shape[1] == 3)
8 |     vertices = vertices.astype(np.float32)
9 |     vector_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
10 |     vertices = vertices.view(dtype=vector_dtype).flatten()
11 |     plyel = PlyElement.describe(vertices, 'vertex')
12 |     plydata = PlyData([plyel], text=as_text)
13 |     plydata.write(out_file)
14 | 
15 | 
16 | def load_pointcloud(in_file):
17 |     plydata = PlyData.read(in_file)
18 |     vertices = np.stack([
19 |         plydata['vertex']['x'],
20 |         plydata['vertex']['y'],
21 |         plydata['vertex']['z']
22 |     ], axis=1)
23 |     return vertices
24 | 
25 | 
26 | def read_off(file):
27 |     """
28 |     Reads vertices and faces from an off file.
29 | 
30 |     :param file: path to file to read
31 |     :type file: str
32 |     :return: vertices and faces as lists of tuples
33 |     :rtype: [(float)], [(int)]
34 |     """
35 | 
36 |     assert os.path.exists(file), 'file %s not found' % file
37 | 
38 |     with open(file, 'r') as fp:
39 |         lines = fp.readlines()
40 |         lines = [line.strip() for line in lines]
41 | 
42 |         # Fix for ModelNet bug where 'OFF' and the number of vertices and faces
43 |         # are all in the first line.
44 |         if len(lines[0]) > 3:
45 |             assert lines[0][:3] == 'OFF' or lines[0][:3] == 'off', \
46 |                 'invalid OFF file %s' % file
47 | 
48 |             parts = lines[0][3:].split(' ')
49 |             assert len(parts) == 3
50 | 
51 |             num_vertices = int(parts[0])
52 |             assert num_vertices > 0
53 | 
54 |             num_faces = int(parts[1])
55 |             assert num_faces > 0
56 | 
57 |             start_index = 1
58 |         # This is the regular case!
59 |         else:
60 |             assert lines[0] == 'OFF' or lines[0] == 'off', \
61 |                 'invalid OFF file %s' % file
62 | 
63 |             parts = lines[1].split(' ')
64 |             assert len(parts) == 3
65 | 
66 |             num_vertices = int(parts[0])
67 |             assert num_vertices > 0
68 | 
69 |             num_faces = int(parts[1])
70 |             assert num_faces > 0
71 | 
72 |             start_index = 2
73 | 
74 |         vertices = []
75 |         for i in range(num_vertices):
76 |             vertex = lines[start_index + i].split(' ')
77 |             vertex = [float(point.strip()) for point in vertex if point != '']
78 |             assert len(vertex) == 3
79 | 
80 |             vertices.append(vertex)
81 | 
82 |         faces = []
83 |         for i in range(num_faces):
84 |             face = lines[start_index + num_vertices + i].split(' ')
85 |             face = [index.strip() for index in face if index != '']
86 | 
87 |             # check to be sure
88 |             for index in face:
89 |                 assert index != '', \
90 |                     'found empty vertex index: %s (%s)' \
91 |                     % (lines[start_index + num_vertices + i], file)
92 | 
93 |             face = [int(index) for index in face]
94 | 
95 |             assert face[0] == len(face) - 1, \
96 |                 'face should have %d vertices but has %d (%s)' \
97 |                 % (face[0], len(face) - 1, file)
98 |             assert face[0] == 3, \
99 |                 'only triangular meshes supported (%s)' % file
100 |             for index in face:
101 |                 assert index >= 0 and index < num_vertices, \
102 |                     'vertex %d (of %d vertices) does not exist (%s)' \
103 |                     % (index, num_vertices, file)
104 | 
105 |             assert len(face) > 1
106 | 
107 |             faces.append(face)
108 | 
109 |         return vertices, faces
110 | 
111 |     assert False, 'could not open %s' % file
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Texture Fields
2 | 