├── soft_intro_vae_3d ├── datasets │ ├── __init__.py │ ├── transforms.py │ ├── shapenet.py │ └── modelnet40.py ├── metrics │ ├── __init__.py │ └── jsd.py ├── utils │ ├── __init__.py │ ├── util.py │ ├── data.py │ └── pcutil.py ├── requirements.txt ├── render │ ├── LICENSE │ ├── README.md │ └── render_mitsuba2_pc.py ├── losses │ └── chamfer_loss.py ├── config │ └── soft_intro_vae_hp.json ├── test_model.py ├── evaluation │ ├── generate_data_for_metrics.py │ └── find_best_epoch_on_validation_soft.py ├── generate_for_rendering.py └── README.md ├── style_soft_intro_vae ├── training_artifacts │ ├── ffhq │ │ └── last_checkpoint │ └── celeba-hq256 │ │ └── last_checkpoint ├── requirements.txt ├── registry.py ├── split_train_test_dirs.py ├── configs │ ├── celeba-hq256.yaml │ └── ffhq256.yaml ├── defaults.py ├── environment.yml ├── utils.py ├── scheduler.py ├── custom_adam.py ├── make_figures │ ├── generate_samples.py │ ├── make_generation_figure.py │ ├── make_recon_figure_ffhq.py │ ├── make_recon_figure_interpolation_2_images.py │ ├── make_recon_figure_paged.py │ └── make_recon_figure_interpolation.py ├── launcher.py ├── lod_driver.py ├── tracker.py ├── checkpointer.py ├── dataset_preparation │ └── split_tfrecords_ffhq.py └── README.md ├── soft_intro_vae_tutorial └── README.md ├── soft_intro_vae_2d ├── main.py └── README.md ├── soft_intro_vae ├── main.py ├── README.md └── dataset.py ├── soft_intro_vae_bootstrap ├── main.py ├── README.md └── dataset.py ├── environment.yml └── README.md /soft_intro_vae_3d/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /style_soft_intro_vae/training_artifacts/ffhq/last_checkpoint: -------------------------------------------------------------------------------- 1 | training_artifacts/ffhq/ffhq_fid_17.55_epoch_270.pth -------------------------------------------------------------------------------- /style_soft_intro_vae/training_artifacts/celeba-hq256/last_checkpoint: -------------------------------------------------------------------------------- 1 | training_artifacts/celeba-hq256/celebahq_fid_18.63_epoch_230.pth -------------------------------------------------------------------------------- /soft_intro_vae_3d/requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | matplotlib 3 | numpy 4 | pandas 5 | git+https://github.com/szagoruyko/pyinn.git@master 6 | torch==0.4.1 -------------------------------------------------------------------------------- /style_soft_intro_vae/requirements.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | imageio 3 | numpy 4 | scipy 5 | tqdm 6 | dlutils 7 | bimpy 8 | torch >= 1.3 9 | torchvision 10 | sklearn 11 | yacs 12 | matplotlib 13 | -------------------------------------------------------------------------------- /style_soft_intro_vae/registry.py: -------------------------------------------------------------------------------- 1 | from utils import Registry 2 | 3 | MODELS = Registry() 4 | 
ENCODERS = Registry() 5 | GENERATORS = Registry() 6 | MAPPINGS = Registry() 7 | DISCRIMINATORS = Registry() 8 | -------------------------------------------------------------------------------- /style_soft_intro_vae/split_train_test_dirs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | import shutil 4 | 5 | if __name__ == '__main__': 6 | all_data_dir = str(Path.home()) + '/../../mnt/data/tal/celebhq_256' 7 | train_dir = str(Path.home()) + '/../../mnt/data/tal/celebhq_256_train' 8 | test_dir = str(Path.home()) + '/../../mnt/data/tal/celebhq_256_test' 9 | os.makedirs(train_dir, exist_ok=True) 10 | os.makedirs(test_dir, exist_ok=True) 11 | num_train = 29000 12 | num_test = 1000 13 | images = [] 14 | for i, filename in enumerate(os.listdir(all_data_dir)): 15 | if i < num_train: 16 | shutil.copyfile(os.path.join(all_data_dir, filename), os.path.join(train_dir, filename)) 17 | else: 18 | shutil.copyfile(os.path.join(all_data_dir, filename), os.path.join(test_dir, filename)) 19 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/render/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tolga Birdal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/losses/chamfer_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ChamferLoss(nn.Module): 6 | 7 | def __init__(self): 8 | super(ChamferLoss, self).__init__() 9 | self.use_cuda = torch.cuda.is_available() 10 | 11 | def forward(self, preds, gts): 12 | P = self.batch_pairwise_dist(gts, preds) 13 | mins, _ = torch.min(P, 1) 14 | loss_1 = torch.sum(mins, 1) 15 | mins, _ = torch.min(P, 2) 16 | loss_2 = torch.sum(mins, 1) 17 | return loss_1 + loss_2 18 | 19 | def batch_pairwise_dist(self, x, y): 20 | bs, num_points_x, points_dim = x.size() 21 | _, num_points_y, _ = y.size() 22 | xx = torch.bmm(x, x.transpose(2, 1)) 23 | yy = torch.bmm(y, y.transpose(2, 1)) 24 | zz = torch.bmm(x, y.transpose(2, 1)) 25 | if self.use_cuda: 26 | dtype = torch.cuda.LongTensor 27 | else: 28 | dtype = torch.LongTensor 29 | diag_ind_x = torch.arange(0, num_points_x).type(dtype) 30 | diag_ind_y = torch.arange(0, num_points_y).type(dtype) 31 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as( 32 | zz.transpose(2, 1)) 33 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) 34 | P = rx.transpose(2, 1) + ry - 2 * zz 35 | return P 36 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/config/soft_intro_vae_hp.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "soft_intro_vae", 3 | "results_root": "./results", 4 | "clean_results_dir": false, 5 | 6 | "cuda": true, 7 | "gpu": 0, 8 | 9 | "reconstruction_loss": "chamfer", 10 | 11 | "metrics": [ 12 | ], 13 | 14 | "dataset": "shapenet", 15 | "data_dir": "./datasets/shapenet_data", 16 | "classes": ["car", "airplane"], 17 | "shuffle": true, 18 | "transforms": [], 19 | "num_workers": 8, 20 | "n_points": 2048, 21 | 22 | "max_epochs": 2000, 23 | "batch_size": 32, 24 | "beta_rec": 20.0, 25 | "beta_kl": 1.0, 26 | "beta_neg": 256, 27 | "z_size": 128, 28 | "gamma_r": 1e-8, 29 | "num_vae": 0, 30 | "prior_std": 0.2, 31 | 32 | 33 | "seed": -1, 34 | "save_frequency": 50, 35 | "valid_frequency": 2, 36 | 37 | "arch": "vae", 38 | "model": { 39 | "D": { 40 | "use_bias": true, 41 | "relu_slope": 0.2 42 | }, 43 | "E": { 44 | "use_bias": true, 45 | "relu_slope": 0.2 46 | } 47 | }, 48 | "optimizer": { 49 | "D": { 50 | "type": "Adam", 51 | "hyperparams": { 52 | "lr": 0.0005, 53 | "weight_decay": 0, 54 | "betas": [0.9, 0.999], 55 | "amsgrad": false 56 | } 57 | }, 58 | "E": { 59 | "type": "Adam", 60 | "hyperparams": { 61 | "lr": 0.0005, 62 | "weight_decay": 0, 63 | "betas": [0.9, 0.999], 64 | "amsgrad": false 65 | } 66 | } 67 | } 68 | } -------------------------------------------------------------------------------- /style_soft_intro_vae/configs/celeba-hq256.yaml: -------------------------------------------------------------------------------- 1 | # Config for training SoftIntroVAE on CelebA-HQ at resolution 256x256 2 | 3 | NAME: celeba-hq256 4 | DATASET: 5 | PART_COUNT: 16 6 | SIZE: 29000 7 | SIZE_TEST: 1000 8 | PATH: /mnt/data/tal/celebhq_256_tfrecords/celeba-r%02d.tfrecords.%03d 9 | PATH_TEST: /mnt/data/tal/celebhq_256_test_tfrecords/celeba-r%02d.tfrecords.%03d 10 | MAX_RESOLUTION_LEVEL: 8 11 | SAMPLES_PATH: /mnt/data/tal/celebhq_256_test/ 12 | STYLE_MIX_PATH: /home/tal/tmp/SoftIntroVAE/style_mixing/test_images/set_ffhq 13 | MODEL: 14 | LATENT_SPACE_SIZE: 512 15 | 
LAYER_COUNT: 7 16 | MAX_CHANNEL_COUNT: 512 17 | START_CHANNEL_COUNT: 64 18 | DLATENT_AVG_BETA: 0.995 19 | MAPPING_LAYERS: 8 20 | BETA_KL: 0.2 21 | BETA_REC: 0.05 22 | BETA_NEG: [2048, 2048, 2048, 1024, 512, 512, 512, 512, 512] 23 | SCALE: 0.000005 24 | OUTPUT_DIR: /home/tal/tmp/SoftIntroVAE/training_artifacts/celeba-hq256 25 | TRAIN: 26 | BASE_LEARNING_RATE: 0.002 27 | EPOCHS_PER_LOD: 30 28 | NUM_VAE: 1 29 | LEARNING_DECAY_RATE: 0.1 30 | LEARNING_DECAY_STEPS: [] 31 | TRAIN_EPOCHS: 300 32 | # 4 8 16 32 64 128 256 512 1024 33 | LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 24] 34 | LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 16, 8, 4] 35 | LOD_2_BATCH_2GPU: [128, 128, 128, 64, 32, 16, 8] 36 | LOD_2_BATCH_1GPU: [128, 128, 128, 32, 16, 8, 4] 37 | 38 | LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.003, 0.003] 39 | -------------------------------------------------------------------------------- /style_soft_intro_vae/configs/ffhq256.yaml: -------------------------------------------------------------------------------- 1 | # Config for training SoftIntroVAE on FFHQ at resolution 256x256 2 | 3 | NAME: ffhq 4 | DATASET: 5 | PART_COUNT: 16 6 | SIZE: 60000 7 | FFHQ_SOURCE: /mnt/data/tal/ffhq_ds/ffhq-dataset/tfrecords_custom/ffhq-r%02d.tfrecords 8 | PATH: /mnt/data/tal/ffhq_ds/ffhq-dataset/tfrecords_custom/splitted/ffhq-r%02d.tfrecords.%03d 9 | 10 | PART_COUNT_TEST: 2 11 | PATH_TEST: /mnt/data/tal/ffhq_ds/ffhq-dataset/tfrecords_custom/splitted/ffhq-r%02d.tfrecords.%03d 12 | 13 | SAMPLES_PATH: /mnt/data/tal/ffhq_ds/ffhq-dataset/images1024x1024/69000 14 | STYLE_MIX_PATH: style_mixing/test_images/set_ffhq 15 | 16 | MAX_RESOLUTION_LEVEL: 8 17 | MODEL: 18 | LATENT_SPACE_SIZE: 512 19 | LAYER_COUNT: 7 20 | MAX_CHANNEL_COUNT: 512 21 | START_CHANNEL_COUNT: 64 22 | DLATENT_AVG_BETA: 0.995 23 | MAPPING_LAYERS: 8 24 | BETA_KL: 0.2 25 | BETA_REC: 0.1 26 | BETA_NEG: [2048, 2048, 2048, 1024, 512, 512, 512, 512, 512] 27 | SCALE: 0.000005 28 | OUTPUT_DIR: /home/tal/tmp/SoftIntroVAE/training_artifacts/ffhq 29 | TRAIN: 30 | BASE_LEARNING_RATE: 0.002 31 | EPOCHS_PER_LOD: 16 32 | NUM_VAE: 1 33 | LEARNING_DECAY_RATE: 0.1 34 | LEARNING_DECAY_STEPS: [] 35 | TRAIN_EPOCHS: 300 36 | # 4 8 16 32 64 128 256 37 | LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 32] # If GPU memory ~16GB reduce last number from 32 to 24 38 | LOD_2_BATCH_4GPU: [ 512, 256, 128, 64, 32, 32, 16, 8, 4 ] 39 | LOD_2_BATCH_2GPU: [ 128, 128, 128, 64, 32, 16, 8 ] 40 | LOD_2_BATCH_1GPU: [ 128, 128, 128, 32, 16, 8, 4 ] 41 | 42 | LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] 43 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/render/README.md: -------------------------------------------------------------------------------- 1 | # Rendering Beautiful Point Clouds with Mitsuba Renderer 2 | This code was adapted from [tolgabirdal](https://github.com/tolgabirdal/Mitsuba2PointCloudRenderer) 3 | and modified to work with Soft-IntroVAE output. 4 | 5 | # Installing Mitsuba 2 (on a Linux machine) 6 | Follow these steps: 7 | 1. Follow https://mitsuba2.readthedocs.io/en/latest/src/getting_started/compiling.html#linux 8 | 2. Clone mitsuba2: `git clone --recursive https://github.com/mitsuba-renderer/mitsuba2` 9 | 3. Install `openEXR`: `sudo apt-get install libopenexr-dev` 10 | 4. `pip install openexr` 11 | 5. Replace 'PATH_TO_MITSUBA2' in 'render_mitsuba2_pc.py' with the path to your local 'mitsuba' file.
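For reference, the edit in step 5 amounts to pointing a path constant at your compiled binary. A minimal sketch (the variable name follows the README; the path below is only an illustration of a typical Mitsuba 2 build location, not a value from this repository):

```python
# In render_mitsuba2_pc.py -- illustrative value only; use the location of your own compiled 'mitsuba' binary.
PATH_TO_MITSUBA2 = "/home/<user>/mitsuba2/build/dist/mitsuba"
```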
12 | 13 | 14 | # Multiple Point Cloud Renderer using Mitsuba 2 15 | 16 | Calling the script **render_mitsuba2_pc.py** automatically performs the following in order: 17 | 18 | 1. generates an XML file, which describes a 3D scene in the format used by Mitsuba. 2. calls Mitsuba2 to render the point cloud into an EXR file 19 | 3. processes the EXR into a jpg file. 20 | 4. iterates over multiple point clouds present in the tensor (.npy) 21 | 22 | It can process both .ply and .npy files. The script is heavily inspired by the [PointFlow renderer](https://github.com/zekunhao1995/PointFlowRenderer), and here is what the outputs can look like: 23 | 24 | 25 | 26 | ## Dependencies 27 | * Python >= 3.6 28 | * [Mitsuba 2](http://www.mitsuba-renderer.org/) 29 | * Python packages used by 'render_mitsuba2_pc': OpenEXR, Imath, PIL 30 | 31 | Ensure that Mitsuba 2 can be called as 'mitsuba' by following the [instructions here](https://mitsuba2.readthedocs.io/en/latest/src/getting_started/compiling.html#linux). 32 | Also make sure that 'PATH_TO_MITSUBA2' in the code is replaced by the path to your local 'mitsuba' file. 33 | 34 | ## Instructions 35 | 36 | Replace 'PATH_TO_MITSUBA2' in 'render_mitsuba2_pc.py' with the path to your local 'mitsuba' file. Then call: 37 | ```bash 38 | # Render a single or multiple JPG file(s) as: 39 | python render_mitsuba2_pc.py interpolations.npy 40 | 41 | # It can also render a ply file 42 | python render_mitsuba2_pc.py chair.ply 43 | ``` 44 | 45 | All the outputs, including the resulting JPG files, will be saved in the directory of the input point cloud. The intermediate EXR/XML files will remain in the folder and have to be removed by the user. 46 | 47 | * To animate a sequence of images in a `gif` format, it is recommended to use `imageio`. -------------------------------------------------------------------------------- /soft_intro_vae_tutorial/README.md: -------------------------------------------------------------------------------- 1 | # soft-intro-vae-pytorch-tutorials 2 | 3 | Step-by-step Jupyter Notebook tutorials for Soft-IntroVAE 4 | 5 | 6 |
*(Badges: "Open In Colab" links to the tutorial notebooks were embedded here.)*
16 | 17 | - [soft-intro-vae-pytorch-tutorials](#soft-intro-vae-pytorch-tutorials) 18 | * [Running Instructions](#running-instructions) 19 | * [Files and directories in the repository](#files-and-directories-in-the-repository) 20 | * [But Wait, There is More...](#but-wait--there-is-more) 21 | 22 | ## Running Instructions 23 | * This Jupyter Notebook can be opened locally with Anaconda, or online via Google Colab. 24 | * To run online, go to https://colab.research.google.com/ and drag-and-drop the `soft_intro_vae_code_tutorial.ipynb` file. 25 | * On Colab, note the "directory" icon on the left, figures and checkpoints are saved in this directory. 26 | * To run the training on the image dataset, it is better to have a GPU. In Google Colab select `Runtime->Change runtime type->GPU`. 27 | * You can also use NBViewer to render the notebooks: [Open with Jupyter NBviewer](https://nbviewer.jupyter.org/github/taldatech/soft-intro-vae-pytorch/tree/main/) 28 | 29 | ## Files and directories in the repository 30 | 31 | |File name | Purpose | 32 | |----------------------|------| 33 | |`soft_intro_vae_2d_code_tutorial.ipynb`| Soft-IntroVAE tutorial for 2D datasets| 34 | |`soft_intro_vae_image_code_tutorial.ipynb`| Soft-IntroVAE tutorial for image datasets| 35 | |`soft_intro_vae_bootstrap_code_tutorial.ipynb`| Bootstrap Soft-IntroVAE tutorial for image datasets (not used in the paper)| 36 | 37 | 38 | ## But Wait, There is More... 39 | * General Tutorials (Jupyter Notebooks with code) 40 | * [CS236756 - Intro to Machine Learning](https://github.com/taldatech/cs236756-intro-to-ml) 41 | * [EE046202 - Unsupervised Learning and Data Analysis](https://github.com/taldatech/ee046202-unsupervised-learning-data-analysis) 42 | * [EE046746 - Computer Vision](https://github.com/taldatech/ee046746-computer-vision) -------------------------------------------------------------------------------- /soft_intro_vae_3d/utils/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from os import listdir, makedirs 4 | from os.path import join, exists 5 | from shutil import rmtree 6 | from time import sleep 7 | 8 | import torch 9 | 10 | 11 | def setup_logging(log_dir): 12 | makedirs(log_dir, exist_ok=True) 13 | 14 | logpath = join(log_dir, 'log.txt') 15 | filemode = 'a' if exists(logpath) else 'w' 16 | 17 | # set up logging to file - see previous section for more details 18 | logging.basicConfig(level=logging.DEBUG, 19 | format='%(asctime)s %(message)s', 20 | datefmt='%m-%d %H:%M:%S', 21 | filename=logpath, 22 | filemode=filemode) 23 | # define a Handler which writes INFO messages or higher to the sys.stderr 24 | console = logging.StreamHandler() 25 | console.setLevel(logging.DEBUG) 26 | # set a format which is simpler for console use 27 | formatter = logging.Formatter('%(asctime)s: %(levelname)-8s %(message)s') 28 | # tell the handler to use this format 29 | console.setFormatter(formatter) 30 | # add the handler to the root logger 31 | logging.getLogger('').addHandler(console) 32 | 33 | 34 | def prepare_results_dir(config): 35 | output_dir = join(config['results_root'], config['arch'], 36 | config['experiment_name']) 37 | if config['clean_results_dir']: 38 | if exists(output_dir): 39 | print('Attention! 
Cleaning results directory in 10 seconds!') 40 | sleep(10) 41 | rmtree(output_dir, ignore_errors=True) 42 | makedirs(output_dir, exist_ok=True) 43 | makedirs(join(output_dir, 'weights'), exist_ok=True) 44 | makedirs(join(output_dir, 'samples'), exist_ok=True) 45 | makedirs(join(output_dir, 'results'), exist_ok=True) 46 | return output_dir 47 | 48 | 49 | def find_latest_epoch(dirpath): 50 | # Files with weights are in format ddddd_{D,E,G}.pth 51 | epoch_regex = re.compile(r'^(?P\d+)_[DEG]\.pth$') 52 | epochs_completed = [] 53 | if exists(join(dirpath, 'weights')): 54 | dirpath = join(dirpath, 'weights') 55 | for f in listdir(dirpath): 56 | m = epoch_regex.match(f) 57 | if m: 58 | epochs_completed.append(int(m.group('n_epoch'))) 59 | return max(epochs_completed) if epochs_completed else 0 60 | 61 | 62 | def cuda_setup(cuda=False, gpu_idx=0): 63 | if cuda and torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | torch.cuda.set_device(gpu_idx) 66 | else: 67 | device = torch.device('cpu') 68 | return device 69 | 70 | -------------------------------------------------------------------------------- /style_soft_intro_vae/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from yacs.config import CfgNode as CN 17 | 18 | 19 | _C = CN() 20 | 21 | _C.NAME = "" 22 | _C.PPL_CELEBA_ADJUSTMENT = False 23 | _C.OUTPUT_DIR = "results" 24 | 25 | _C.DATASET = CN() 26 | _C.DATASET.PATH = 'celeba/data_fold_%d_lod_%d.pkl' 27 | _C.DATASET.PATH_TEST = '' 28 | _C.DATASET.FFHQ_SOURCE = '/data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords' 29 | _C.DATASET.PART_COUNT = 1 30 | _C.DATASET.PART_COUNT_TEST = 1 31 | _C.DATASET.SIZE = 70000 32 | _C.DATASET.SIZE_TEST = 10000 33 | _C.DATASET.FLIP_IMAGES = True 34 | _C.DATASET.SAMPLES_PATH = 'dataset_samples/faces/realign128x128' 35 | 36 | _C.DATASET.STYLE_MIX_PATH = 'style_mixing/test_images/set_celeba/' 37 | 38 | _C.DATASET.MAX_RESOLUTION_LEVEL = 10 39 | 40 | _C.MODEL = CN() 41 | 42 | _C.MODEL.LAYER_COUNT = 6 43 | _C.MODEL.START_CHANNEL_COUNT = 64 44 | _C.MODEL.MAX_CHANNEL_COUNT = 512 45 | _C.MODEL.LATENT_SPACE_SIZE = 256 46 | _C.MODEL.DLATENT_AVG_BETA = 0.995 47 | _C.MODEL.TRUNCATIOM_PSI = 0.7 48 | _C.MODEL.TRUNCATIOM_CUTOFF = 8 49 | _C.MODEL.STYLE_MIXING_PROB = 0.9 50 | _C.MODEL.MAPPING_LAYERS = 5 51 | _C.MODEL.CHANNELS = 3 52 | _C.MODEL.GENERATOR = "GeneratorDefault" 53 | _C.MODEL.ENCODER = "EncoderDefault" 54 | _C.MODEL.MAPPING_TO_LATENT = "MappingToLatent" 55 | _C.MODEL.MAPPING_FROM_LATENT = "MappingFromLatent" 56 | _C.MODEL.Z_REGRESSION = False 57 | _C.MODEL.BETA_KL = 1.0 58 | _C.MODEL.BETA_REC = 1.0 59 | _C.MODEL.BETA_NEG = [2048, 2048, 1024, 512, 512, 128, 128, 64, 64] 60 | _C.MODEL.SCALE = 1 / (3 * 256 ** 2) 61 | 62 | _C.TRAIN = CN() 63 | 64 | _C.TRAIN.EPOCHS_PER_LOD = 15 65 | 66 | _C.TRAIN.BASE_LEARNING_RATE = 0.0015 67 | _C.TRAIN.ADAM_BETA_0 = 0.0 68 | _C.TRAIN.ADAM_BETA_1 = 0.99 69 | _C.TRAIN.LEARNING_DECAY_RATE = 0.1 70 | _C.TRAIN.LEARNING_DECAY_STEPS = [] 71 | _C.TRAIN.TRAIN_EPOCHS = 110 72 | _C.TRAIN.NUM_VAE = 1 73 | 74 | _C.TRAIN.LOD_2_BATCH_8GPU = [512, 256, 128, 64, 32, 32] 75 | _C.TRAIN.LOD_2_BATCH_4GPU = [512, 256, 128, 64, 32, 16] 76 | _C.TRAIN.LOD_2_BATCH_2GPU = [256, 256, 128, 64, 32, 16] 77 | _C.TRAIN.LOD_2_BATCH_1GPU = [128, 128, 128, 64, 32, 16] 78 | 79 | 80 | _C.TRAIN.SNAPSHOT_FREQ = [300, 300, 300, 100, 50, 30, 20, 20, 10] 81 | 82 | _C.TRAIN.REPORT_FREQ = [100, 80, 60, 30, 20, 10, 10, 5, 5] 83 | 84 | _C.TRAIN.LEARNING_RATES = [0.002] 85 | 86 | 87 | def get_cfg_defaults(): 88 | return _C.clone() 89 | -------------------------------------------------------------------------------- /soft_intro_vae_2d/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main function for arguments parsing 3 | Author: Tal Daniel 4 | """ 5 | # imports 6 | import torch 7 | import argparse 8 | from train_soft_intro_vae_2d import train_soft_intro_vae_toy 9 | 10 | if __name__ == "__main__": 11 | """ 12 | Recommended hyper-parameters: 13 | - 8Gaussians: beta_kl: 0.3, beta_rec: 0.2, beta_neg: 0.9, z_dim: 2, batch_size: 512 14 | - 2spirals: beta_kl: 0.5, beta_rec: 0.2, beta_neg: 1.0, z_dim: 2, batch_size: 512 15 | - checkerboard: beta_kl: 0.1, beta_rec: 0.2, beta_neg: 0.2, z_dim: 2, batch_size: 512 16 | - rings: beta_kl: 0.2, beta_rec: 0.2, beta_neg: 1.0, z_dim: 2, batch_size: 512 17 | """ 18 | parser = argparse.ArgumentParser(description="train Soft-IntroVAE 2D") 19 | parser.add_argument("-d", "--dataset", type=str, 20 | help="dataset to train on: ['8Gaussians', '2spirals', 'checkerboard', rings']") 21 | parser.add_argument("-n", "--num_iter", type=int, help="total number of 
iterations to run", default=30_000) 22 | parser.add_argument("-z", "--z_dim", type=int, help="latent dimensions", default=2) 23 | parser.add_argument("-l", "--lr", type=float, help="learning rate", default=2e-4) 24 | parser.add_argument("-b", "--batch_size", type=int, help="batch size", default=512) 25 | parser.add_argument("-v", "--num_vae", type=int, help="number of iterations for vanilla vae training", default=2000) 26 | parser.add_argument("-r", "--beta_rec", type=float, help="beta coefficient for the reconstruction loss", 27 | default=0.2) 28 | parser.add_argument("-k", "--beta_kl", type=float, help="beta coefficient for the kl divergence", 29 | default=0.3) 30 | parser.add_argument("-e", "--beta_neg", type=float, 31 | help="beta coefficient for the kl divergence in the expELBO function", default=0.9) 32 | parser.add_argument("-g", "--gamma_r", type=float, 33 | help="coefficient for the reconstruction loss for fake data in the decoder", default=1e-8) 34 | parser.add_argument("-s", "--seed", type=int, help="seed", default=-1) 35 | parser.add_argument("-p", "--pretrained", type=str, help="path to pretrained model, to continue training", 36 | default="None") 37 | parser.add_argument("-c", "--device", type=int, help="device: -1 for cpu, 0 and up for specific cuda device", 38 | default=-1) 39 | args = parser.parse_args() 40 | 41 | device = torch.device("cpu") if args.device <= -1 else torch.device("cuda:" + str(args.device)) 42 | pretrained = None if args.pretrained == "None" else args.pretrained 43 | if args.dataset == '8Gaussians': 44 | scale = 1 45 | else: 46 | scale = 2 47 | # train 48 | model = train_soft_intro_vae_toy(z_dim=args.z_dim, lr_e=args.lr, lr_d=args.lr, batch_size=args.batch_size, 49 | n_iter=args.num_iter, num_vae=args.num_vae, save_interval=5000, 50 | recon_loss_type="mse", beta_kl=args.beta_kl, beta_rec=args.beta_rec, 51 | beta_neg=args.beta_neg, test_iter=5000, seed=args.seed, scale=scale, 52 | device=device, dataset=args.dataset) 53 | -------------------------------------------------------------------------------- /style_soft_intro_vae/environment.yml: -------------------------------------------------------------------------------- 1 | name: tf_torch 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.6.24=0 10 | - certifi=2020.6.20=py36_0 11 | - cffi=1.14.0=py36he30daa8_1 12 | - cudatoolkit=10.1.243=h6bb024c_0 13 | - cycler=0.10.0=py_2 14 | - dbus=1.13.6=he372182_0 15 | - expat=2.2.9=he1b5a44_2 16 | - fontconfig=2.13.1=he4413a7_1000 17 | - freetype=2.10.2=he06d7ca_0 18 | - glib=2.65.0=h3eb4bd4_0 19 | - gst-plugins-base=1.14.0=hbbd80ab_1 20 | - gstreamer=1.14.0=hb31296c_0 21 | - icu=58.2=hf484d3e_1000 22 | - imageio=2.8.0=py_0 23 | - intel-openmp=2020.1=217 24 | - joblib=0.16.0=py_0 25 | - jpeg=9d=h516909a_0 26 | - kiwisolver=1.2.0=py36hdb11119_0 27 | - ld_impl_linux-64=2.33.1=h53a641e_7 28 | - libedit=3.1.20191231=h7b6447c_0 29 | - libffi=3.3=he6710b0_1 30 | - libgcc-ng=9.1.0=hdf63c60_0 31 | - libgfortran-ng=7.5.0=hdf63c60_6 32 | - libopenblas=0.3.7=h5ec1e0e_6 33 | - libpng=1.6.37=hed695b0_1 34 | - libstdcxx-ng=9.1.0=hdf63c60_0 35 | - libtiff=4.1.0=hc7e4089_6 36 | - libuuid=2.32.1=h14c3975_1000 37 | - libwebp-base=1.1.0=h516909a_3 38 | - libxcb=1.13=h14c3975_1002 39 | - libxml2=2.9.10=he19cac6_1 40 | - lz4-c=1.9.2=he1b5a44_1 41 | - matplotlib=3.2.2=0 42 | - matplotlib-base=3.2.2=py36hef1b27d_0 43 | - mkl=2020.1=217 44 | - mkl-service=2.3.0=py36he904b0f_0 45 | - 
mkl_fft=1.1.0=py36h23d657b_0 46 | - mkl_random=1.1.1=py36h0573a6f_0 47 | - ncurses=6.2=he6710b0_1 48 | - ninja=1.9.0=py36hfd86e86_0 49 | - olefile=0.46=py_0 50 | - openssl=1.1.1g=h516909a_0 51 | - packaging=20.4=pyh9f0ad1d_0 52 | - pcre=8.44=he1b5a44_0 53 | - pillow=7.2.0=py36h8328e55_0 54 | - pip=20.1.1=py36_1 55 | - pthread-stubs=0.4=h14c3975_1001 56 | - pycparser=2.20=py_0 57 | - pyparsing=2.4.7=pyh9f0ad1d_0 58 | - pyqt=5.9.2=py36hcca6a23_4 59 | - python=3.6.10=h7579374_2 60 | - python-dateutil=2.8.1=py_0 61 | - python_abi=3.6=1_cp36m 62 | - pytorch=1.5.1=py3.6_cuda10.1.243_cudnn7.6.3_0 63 | - pyyaml=5.3.1=py36h8c4c3a4_0 64 | - qt=5.9.7=h5867ecd_1 65 | - readline=8.0=h7b6447c_0 66 | - scikit-learn=0.23.1=py36h423224d_0 67 | - scipy=1.5.0=py36h0b6359f_0 68 | - sip=4.19.8=py36hf484d3e_0 69 | - six=1.15.0=pyh9f0ad1d_0 70 | - sqlite=3.32.3=h62c20be_0 71 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 72 | - tk=8.6.10=hbc83047_0 73 | - torchvision=0.6.1=py36_cu101 74 | - tornado=6.0.4=py36h8c4c3a4_1 75 | - tqdm=4.47.0=pyh9f0ad1d_0 76 | - wheel=0.34.2=py36_0 77 | - xorg-libxau=1.0.9=h14c3975_0 78 | - xorg-libxdmcp=1.1.3=h516909a_0 79 | - xz=5.2.5=h7b6447c_0 80 | - yacs=0.1.6=py_0 81 | - yaml=0.2.5=h516909a_0 82 | - zlib=1.2.11=h7b6447c_3 83 | - zstd=1.4.4=h6597ccf_3 84 | - pip: 85 | - absl-py==0.9.0 86 | - astor==0.8.1 87 | - bimpy==0.0.13 88 | - dareblopy==0.0.3 89 | - gast==0.3.3 90 | - grpcio==1.30.0 91 | - h5py==2.10.0 92 | - importlib-metadata==1.7.0 93 | - keras-applications==1.0.8 94 | - keras-preprocessing==1.1.2 95 | - markdown==3.2.2 96 | - mock==4.0.2 97 | - numpy==1.14.5 98 | - protobuf==3.12.2 99 | - setuptools==39.1.0 100 | - tensorboard==1.13.1 101 | - tensorflow-estimator==1.13.0 102 | - tensorflow-gpu==1.13.1 103 | - termcolor==1.1.0 104 | - werkzeug==1.0.1 105 | - zipp==3.1.0 106 | prefix: /home/tal/anaconda3/envs/tf_torch 107 | 108 | -------------------------------------------------------------------------------- /soft_intro_vae/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main function for arguments parsing 3 | Author: Tal Daniel 4 | """ 5 | # imports 6 | import torch 7 | import argparse 8 | from train_soft_intro_vae import train_soft_intro_vae 9 | 10 | if __name__ == "__main__": 11 | """ 12 | Recommended hyper-parameters: 13 | - CIFAR10: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 128, batch_size: 32 14 | - SVHN: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 128, batch_size: 32 15 | - MNIST: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 32, batch_size: 128 16 | - FashionMNIST: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 32, batch_size: 128 17 | - Monsters: beta_kl: 0.2, beta_rec: 0.2, beta_neg: 256, z_dim: 128, batch_size: 16 18 | - CelebA-HQ: beta_kl: 1.0, beta_rec: 0.5, beta_neg: 1024, z_dim: 256, batch_size: 8 19 | """ 20 | parser = argparse.ArgumentParser(description="train Soft-IntroVAE") 21 | parser.add_argument("-d", "--dataset", type=str, 22 | help="dataset to train on: ['cifar10', 'mnist', 'fmnist', 'svhn', 'monsters128', 'celeb128', " 23 | "'celeb256', 'celeb1024']") 24 | parser.add_argument("-n", "--num_epochs", type=int, help="total number of epochs to run", default=250) 25 | parser.add_argument("-z", "--z_dim", type=int, help="latent dimensions", default=128) 26 | parser.add_argument("-l", "--lr", type=float, help="learning rate", default=2e-4) 27 | parser.add_argument("-b", "--batch_size", type=int, help="batch size", default=32) 28 | parser.add_argument("-v", "--num_vae", type=int, 
help="number of epochs for vanilla vae training", default=0) 29 | parser.add_argument("-r", "--beta_rec", type=float, help="beta coefficient for the reconstruction loss", 30 | default=1.0) 31 | parser.add_argument("-k", "--beta_kl", type=float, help="beta coefficient for the kl divergence", 32 | default=1.0) 33 | parser.add_argument("-e", "--beta_neg", type=float, 34 | help="beta coefficient for the kl divergence in the expELBO function", default=1.0) 35 | parser.add_argument("-g", "--gamma_r", type=float, 36 | help="coefficient for the reconstruction loss for fake data in the decoder", default=1e-8) 37 | parser.add_argument("-s", "--seed", type=int, help="seed", default=-1) 38 | parser.add_argument("-p", "--pretrained", type=str, help="path to pretrained model, to continue training", 39 | default="None") 40 | parser.add_argument("-c", "--device", type=int, help="device: -1 for cpu, 0 and up for specific cuda device", 41 | default=-1) 42 | parser.add_argument('-f', "--fid", action='store_true', help="if specified, FID wil be calculated during training") 43 | args = parser.parse_args() 44 | 45 | device = torch.device("cpu") if args.device <= -1 else torch.device("cuda:" + str(args.device)) 46 | pretrained = None if args.pretrained == "None" else args.pretrained 47 | train_soft_intro_vae(dataset=args.dataset, z_dim=args.z_dim, batch_size=args.batch_size, num_workers=0, 48 | num_epochs=args.num_epochs, 49 | num_vae=args.num_vae, beta_kl=args.beta_kl, beta_neg=args.beta_neg, beta_rec=args.beta_rec, 50 | device=device, save_interval=50, start_epoch=0, lr_e=args.lr, lr_d=args.lr, 51 | pretrained=pretrained, seed=args.seed, 52 | test_iter=1000, with_fid=args.fid) 53 | -------------------------------------------------------------------------------- /soft_intro_vae_bootstrap/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main function for arguments parsing 3 | Author: Tal Daniel 4 | """ 5 | # imports 6 | import torch 7 | import argparse 8 | from train_soft_intro_vae_bootstrap import train_soft_intro_vae 9 | 10 | if __name__ == "__main__": 11 | """ 12 | Recommended hyper-parameters: 13 | - CIFAR10: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 128, batch_size: 32 14 | - SVHN: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 128, batch_size: 32 15 | - MNIST: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 32, batch_size: 128 16 | - FashionMNIST: beta_kl: 1.0, beta_rec: 1.0, beta_neg: 256, z_dim: 32, batch_size: 128 17 | - Monsters: beta_kl: 0.2, beta_rec: 0.2, beta_neg: 256, z_dim: 128, batch_size: 16 18 | """ 19 | parser = argparse.ArgumentParser(description="train Soft-IntroVAE") 20 | parser.add_argument("-d", "--dataset", type=str, 21 | help="dataset to train on: ['cifar10', 'mnist', 'fmnist', 'svhn', 'monsters128', 'celeb128', " 22 | "'celeb256', 'celeb1024']") 23 | parser.add_argument("-n", "--num_epochs", type=int, help="total number of epochs to run", default=250) 24 | parser.add_argument("-z", "--z_dim", type=int, help="latent dimensions", default=128) 25 | parser.add_argument("-l", "--lr", type=float, help="learning rate", default=2e-4) 26 | parser.add_argument("-b", "--batch_size", type=int, help="batch size", default=32) 27 | parser.add_argument("-v", "--num_vae", type=int, help="number of epochs for vanilla vae training", default=0) 28 | parser.add_argument("-r", "--beta_rec", type=float, help="beta coefficient for the reconstruction loss", 29 | default=1.0) 30 | parser.add_argument("-k", "--beta_kl", type=float, help="beta 
coefficient for the kl divergence", 31 | default=1.0) 32 | parser.add_argument("-e", "--beta_neg", type=float, 33 | help="beta coefficient for the kl divergence in the expELBO function", default=1.0) 34 | parser.add_argument("-g", "--gamma_r", type=float, 35 | help="coefficient for the reconstruction loss for fake data in the decoder", default=1.0) 36 | parser.add_argument("-s", "--seed", type=int, help="seed", default=-1) 37 | parser.add_argument("-p", "--pretrained", type=str, help="path to pretrained model, to continue training", 38 | default="None") 39 | parser.add_argument("-c", "--device", type=int, help="device: -1 for cpu, 0 and up for specific cuda device", 40 | default=-1) 41 | parser.add_argument('-f', "--fid", action='store_true', help="if specified, FID wil be calculated during training") 42 | parser.add_argument("-o", "--freq", type=int, help="epochs between copying weights from decoder to target decoder", 43 | default=1) 44 | args = parser.parse_args() 45 | 46 | device = torch.device("cpu") if args.device <= -1 else torch.device("cuda:" + str(args.device)) 47 | pretrained = None if args.pretrained == "None" else args.pretrained 48 | train_soft_intro_vae(dataset=args.dataset, z_dim=args.z_dim, batch_size=args.batch_size, num_workers=0, 49 | num_epochs=args.num_epochs, copy_to_target_freq=args.freq, 50 | num_vae=args.num_vae, beta_kl=args.beta_kl, beta_neg=args.beta_neg, beta_rec=args.beta_rec, 51 | device=device, save_interval=50, start_epoch=0, lr_e=args.lr, lr_d=args.lr, 52 | pretrained=pretrained, seed=args.seed, 53 | test_iter=1000, with_fid=args.fid) 54 | -------------------------------------------------------------------------------- /style_soft_intro_vae/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from torch import nn 17 | import torch 18 | import threading 19 | import hashlib 20 | import pickle 21 | import os 22 | 23 | 24 | class cache: 25 | def __init__(self, function): 26 | self.function = function 27 | self.pickle_name = self.function.__name__ 28 | 29 | def __call__(self, *args, **kwargs): 30 | m = hashlib.sha256() 31 | m.update(pickle.dumps((self.function.__name__, args, frozenset(kwargs.items())))) 32 | output_path = os.path.join('.cache', "%s_%s" % (m.hexdigest(), self.pickle_name)) 33 | try: 34 | with open(output_path, 'rb') as f: 35 | data = pickle.load(f) 36 | except (FileNotFoundError, pickle.PickleError): 37 | data = self.function(*args, **kwargs) 38 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 39 | with open(output_path, 'wb') as f: 40 | pickle.dump(data, f) 41 | return data 42 | 43 | 44 | def save_model(x, name): 45 | if isinstance(x, nn.DataParallel): 46 | torch.save(x.module.state_dict(), name) 47 | else: 48 | torch.save(x.state_dict(), name) 49 | 50 | 51 | class AsyncCall(object): 52 | def __init__(self, fnc, callback=None): 53 | self.Callable = fnc 54 | self.Callback = callback 55 | self.result = None 56 | 57 | def __call__(self, *args, **kwargs): 58 | self.Thread = threading.Thread(target=self.run, name=self.Callable.__name__, args=args, kwargs=kwargs) 59 | self.Thread.start() 60 | return self 61 | 62 | def wait(self, timeout=None): 63 | self.Thread.join(timeout) 64 | if self.Thread.isAlive(): 65 | raise TimeoutError 66 | else: 67 | return self.result 68 | 69 | def run(self, *args, **kwargs): 70 | self.result = self.Callable(*args, **kwargs) 71 | if self.Callback: 72 | self.Callback(self.result) 73 | 74 | 75 | class AsyncMethod(object): 76 | def __init__(self, fnc, callback=None): 77 | self.Callable = fnc 78 | self.Callback = callback 79 | 80 | def __call__(self, *args, **kwargs): 81 | return AsyncCall(self.Callable, self.Callback)(*args, **kwargs) 82 | 83 | 84 | def async_func(fnc=None, callback=None): 85 | if fnc is None: 86 | def add_async_callback(f): 87 | return AsyncMethod(f, callback) 88 | return add_async_callback 89 | else: 90 | return AsyncMethod(fnc, callback) 91 | 92 | 93 | class Registry(dict): 94 | def __init__(self, *args, **kwargs): 95 | super(Registry, self).__init__(*args, **kwargs) 96 | 97 | def register(self, module_name): 98 | def register_fn(module): 99 | assert module_name not in self 100 | self[module_name] = module 101 | return module 102 | return register_fn 103 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/test_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test a trained model (on the test split of the data) 3 | """ 4 | 5 | import json 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn.parallel 10 | import torch.utils.data 11 | from torch.utils.data import DataLoader 12 | 13 | from utils.util import cuda_setup 14 | from metrics.jsd import jsd_between_point_cloud_sets 15 | 16 | from models.vae import SoftIntroVAE 17 | 18 | 19 | def prepare_model(config, path_to_weights, device=torch.device("cpu")): 20 | model = SoftIntroVAE(config).to(device) 21 | model.load_state_dict(torch.load(path_to_weights, map_location=device)) 22 | model.eval() 23 | return model 24 | 25 | 26 | def prepare_data(config, split='train', batch_size=32): 27 | dataset_name = config['dataset'].lower() 28 | if dataset_name == 'shapenet': 29 
| from datasets.shapenet import ShapeNetDataset 30 | dataset = ShapeNetDataset(root_dir=config['data_dir'], 31 | classes=config['classes'], split=split) 32 | else: 33 | raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' 34 | f'`faust`. Got: `{dataset_name}`') 35 | data_loader = DataLoader(dataset, batch_size=batch_size, 36 | shuffle=False, num_workers=4, 37 | drop_last=False, pin_memory=True) 38 | return data_loader 39 | 40 | 41 | def calc_jsd_valid(model, config, prior_std=1.0, split='valid'): 42 | model.eval() 43 | device = cuda_setup(config['cuda'], config['gpu']) 44 | dataset_name = config['dataset'].lower() 45 | if dataset_name == 'shapenet': 46 | from datasets.shapenet import ShapeNetDataset 47 | dataset = ShapeNetDataset(root_dir=config['data_dir'], 48 | classes=config['classes'], split=split) 49 | else: 50 | raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' 51 | f'`faust`. Got: `{dataset_name}`') 52 | classes_selected = ('all' if not config['classes'] 53 | else ','.join(config['classes'])) 54 | num_samples = len(dataset.point_clouds_names_valid) 55 | data_loader = DataLoader(dataset, batch_size=num_samples, 56 | shuffle=False, num_workers=4, 57 | drop_last=False, pin_memory=True) 58 | # We take 3 times as many samples as there are in test data in order to 59 | # perform JSD calculation in the same manner as in the reference publication 60 | 61 | x, _ = next(iter(data_loader)) 62 | x = x.to(device) 63 | 64 | # We average JSD computation from 3 independent trials. 65 | js_results = [] 66 | for _ in range(3): 67 | noise = prior_std * torch.randn(3 * num_samples, model.zdim) 68 | noise = noise.to(device) 69 | 70 | with torch.no_grad(): 71 | x_g = model.decode(noise) 72 | if x_g.shape[-2:] == (3, 2048): 73 | x_g.transpose_(1, 2) 74 | 75 | jsd = jsd_between_point_cloud_sets(x, x_g, voxels=28) 76 | js_results.append(jsd) 77 | js_result = np.mean(js_results) 78 | return js_result 79 | 80 | 81 | if __name__ == "__main__": 82 | path_to_weights = './results/vae/soft_intro_vae_chair/weights/00350_jsd_0.106.pth' 83 | config_path = 'config/soft_intro_vae_hp.json' 84 | config = None 85 | if config_path is not None and config_path.endswith('.json'): 86 | with open(config_path) as f: 87 | config = json.load(f) 88 | assert config is not None 89 | device = cuda_setup(config['cuda'], config['gpu']) 90 | print("using device: ", device) 91 | model = prepare_model(config, path_to_weights, device=device) 92 | test_jsd = calc_jsd_valid(model, config, prior_std=config["prior_std"], split='test') 93 | print(f'test jsd: {test_jsd}') 94 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/evaluation/generate_data_for_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | from importlib import import_module 6 | from os.path import join 7 | 8 | import numpy as np 9 | import torch 10 | from torch.distributions import Beta 11 | from torch.utils.data import DataLoader 12 | 13 | from datasets.shapenet.shapenet import ShapeNetDataset 14 | from models.vae import SoftIntroVAE, reparameterize 15 | from utils.util import cuda_setup 16 | 17 | 18 | def prepare_model(config, path_to_weights, device=torch.device("cpu")): 19 | model = SoftIntroVAE(config).to(device) 20 | model.load_state_dict(torch.load(path_to_weights, map_location=device)) 21 | model.eval() 22 | return model 23 | 24 | 25 | def main(eval_config): 26 | # Load 
hyperparameters as they were during training 27 | train_results_path = join(eval_config['results_root'], eval_config['arch'], 28 | eval_config['experiment_name']) 29 | with open(join(train_results_path, 'config.json')) as f: 30 | train_config = json.load(f) 31 | 32 | random.seed(train_config['seed']) 33 | torch.manual_seed(train_config['seed']) 34 | torch.cuda.manual_seed_all(train_config['seed']) 35 | 36 | device = cuda_setup(config['cuda'], config['gpu']) 37 | print("using device: ", device) 38 | 39 | # 40 | # Dataset 41 | # 42 | dataset_name = train_config['dataset'].lower() 43 | if dataset_name == 'shapenet': 44 | dataset = ShapeNetDataset(root_dir=train_config['data_dir'], 45 | classes=train_config['classes'], split='test') 46 | else: 47 | raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' 48 | f'`faust`. Got: `{dataset_name}`') 49 | classes_selected = ('all' if not train_config['classes'] 50 | else ','.join(train_config['classes'])) 51 | 52 | # 53 | # Models 54 | # 55 | 56 | model = prepare_model(config, path_to_weights, device=device) 57 | model.eval() 58 | 59 | num_samples = len(dataset.point_clouds_names_test) 60 | data_loader = DataLoader(dataset, batch_size=num_samples, 61 | shuffle=False, num_workers=4, 62 | drop_last=False, pin_memory=True) 63 | 64 | # We take 3 times as many samples as there are in test data in order to 65 | # perform JSD calculation in the same manner as in the reference publication 66 | 67 | x, _ = next(iter(data_loader)) 68 | x = x.to(device) 69 | 70 | np.save(join(train_results_path, 'results', f'_X'), x) 71 | 72 | prior_std = config["prior_std"] 73 | 74 | for i in range(3): 75 | noise = prior_std * torch.randn(3 * num_samples, model.zdim) 76 | noise = noise.to(device) 77 | 78 | with torch.no_grad(): 79 | x_g = model.decode(noise) 80 | if x_g.shape[-2:] == (3, 2048): 81 | x_g.transpose_(1, 2) 82 | np.save(join(train_results_path, 'results', f'_Xg_{i}'), x_g) 83 | 84 | with torch.no_grad(): 85 | mu_z, logvar_z = model.encode(x) 86 | data_z = reparameterize(mu_z, logvar_z) 87 | # x_rec = model.decode(data_z) # stochastic 88 | x_rec = model.decode(mu_z) # deterministic 89 | if x_rec.shape[-2:] == (3, 2048): 90 | x_rec.transpose_(1, 2) 91 | 92 | np.save(join(train_results_path, 'results', f'_Xrec'), x_rec) 93 | 94 | 95 | if __name__ == '__main__': 96 | path_to_weights = './results/vae/soft_intro_vae_chair/weights/00350_jsd_0.106.pth' 97 | config_path = 'config/soft_intro_vae_hp.json' 98 | config = None 99 | if config_path is not None and config_path.endswith('.json'): 100 | with open(config_path) as f: 101 | config = json.load(f) 102 | assert config is not None 103 | 104 | main(config) 105 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: torch 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h516909a_3 10 | - ca-certificates=2020.12.5=ha878542_0 11 | - cairo=1.16.0=h18b612c_1001 12 | - certifi=2020.12.5=py36h5fab9bb_0 13 | - cudatoolkit=10.1.243=h6bb024c_0 14 | - cycler=0.10.0=py_2 15 | - dataclasses=0.7=py36_0 16 | - dbus=1.13.6=he372182_0 17 | - expat=2.2.9=he1b5a44_2 18 | - ffmpeg=4.0=hcdf2ecd_0 19 | - fontconfig=2.13.1=he4413a7_1000 20 | - freeglut=3.2.1=h58526e2_0 21 | - freetype=2.10.4=h5ab3b9f_0 22 | - glib=2.66.1=h92f7085_0 23 | - graphite2=1.3.13=h58526e2_1001 24 | - gst-plugins-base=1.14.0=hbbd80ab_1 
25 | - gstreamer=1.14.0=hb31296c_0 26 | - harfbuzz=1.8.8=hffaf4a1_0 27 | - hdf5=1.10.2=hc401514_3 28 | - icu=58.2=hf484d3e_1000 29 | - intel-openmp=2020.2=254 30 | - jasper=2.0.14=h07fcdf6_1 31 | - joblib=0.17.0=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - kiwisolver=1.3.1=py36h51d7077_0 34 | - lcms2=2.11=h396b838_0 35 | - ld_impl_linux-64=2.33.1=h53a641e_7 36 | - libblas=3.9.0=1_h6e990d7_netlib 37 | - libcblas=3.9.0=3_h893e4fe_netlib 38 | - libedit=3.1.20191231=h14c3975_1 39 | - libffi=3.3=he6710b0_2 40 | - libgcc-ng=9.1.0=hdf63c60_0 41 | - libgfortran=3.0.0=1 42 | - libgfortran-ng=7.5.0=hae1eefd_17 43 | - libgfortran4=7.5.0=hae1eefd_17 44 | - libglu=9.0.0=he1b5a44_1001 45 | - liblapack=3.9.0=3_h893e4fe_netlib 46 | - libopencv=3.4.2=hb342d67_1 47 | - libopus=1.3.1=h7b6447c_0 48 | - libpng=1.6.37=hbc83047_0 49 | - libstdcxx-ng=9.1.0=hdf63c60_0 50 | - libtiff=4.1.0=h2733197_1 51 | - libuuid=2.32.1=h14c3975_1000 52 | - libuv=1.40.0=h7b6447c_0 53 | - libvpx=1.7.0=h439df22_0 54 | - libxcb=1.13=h14c3975_1002 55 | - libxml2=2.9.10=hb55368b_3 56 | - lz4-c=1.9.2=heb0550a_3 57 | - matplotlib=3.3.3=py36h5fab9bb_0 58 | - matplotlib-base=3.3.3=py36he12231b_0 59 | - mkl=2020.2=256 60 | - mkl-service=2.3.0=py36he8ac12f_0 61 | - mkl_fft=1.2.0=py36h23d657b_0 62 | - mkl_random=1.1.1=py36h0573a6f_0 63 | - ncurses=6.2=he6710b0_1 64 | - ninja=1.10.2=py36hff7bd54_0 65 | - numpy=1.19.2=py36h54aff64_0 66 | - numpy-base=1.19.2=py36hfa32c7d_0 67 | - olefile=0.46=py36_0 68 | - opencv=3.4.2=py36h6fd60c2_1 69 | - openssl=1.1.1h=h516909a_0 70 | - pcre=8.44=he1b5a44_0 71 | - pillow=8.0.1=py36he98fc37_0 72 | - pip=20.3=py36h06a4308_0 73 | - pixman=0.38.0=h516909a_1003 74 | - pthread-stubs=0.4=h36c2ea0_1001 75 | - py-opencv=3.4.2=py36hb342d67_1 76 | - pyparsing=2.4.7=pyh9f0ad1d_0 77 | - pyqt=5.9.2=py36hcca6a23_4 78 | - python=3.6.12=hcff3b4d_2 79 | - python-dateutil=2.8.1=py_0 80 | - python_abi=3.6=1_cp36m 81 | - pytorch=1.7.0=py3.6_cuda10.1.243_cudnn7.6.3_0 82 | - qt=5.9.7=h5867ecd_1 83 | - readline=8.0=h7b6447c_0 84 | - scikit-learn=0.23.2=py36hb6e6923_3 85 | - scipy=1.5.3=py36h976291a_0 86 | - setuptools=50.3.2=py36h06a4308_2 87 | - sip=4.19.8=py36hf484d3e_0 88 | - six=1.15.0=py36h06a4308_0 89 | - sqlite=3.33.0=h62c20be_0 90 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 91 | - tk=8.6.10=hbc83047_0 92 | - torchaudio=0.7.0=py36 93 | - torchvision=0.8.1=py36_cu101 94 | - tornado=6.1=py36h1d69622_0 95 | - tqdm=4.54.1=pyhd8ed1ab_0 96 | - typing_extensions=3.7.4.3=py_0 97 | - wheel=0.36.0=pyhd3eb1b0_0 98 | - xorg-fixesproto=5.0=h14c3975_1002 99 | - xorg-inputproto=2.3.2=h14c3975_1002 100 | - xorg-kbproto=1.0.7=h14c3975_1002 101 | - xorg-libice=1.0.10=h516909a_0 102 | - xorg-libsm=1.2.3=h84519dc_1000 103 | - xorg-libx11=1.6.12=h516909a_0 104 | - xorg-libxau=1.0.9=h14c3975_0 105 | - xorg-libxdmcp=1.1.3=h516909a_0 106 | - xorg-libxext=1.3.4=h516909a_0 107 | - xorg-libxfixes=5.0.3=h516909a_1004 108 | - xorg-libxi=1.7.10=h516909a_0 109 | - xorg-libxrender=0.9.10=h516909a_1002 110 | - xorg-renderproto=0.11.1=h14c3975_1002 111 | - xorg-xextproto=7.3.0=h14c3975_1002 112 | - xorg-xproto=7.0.31=h14c3975_1007 113 | - xz=5.2.5=h7b6447c_0 114 | - zlib=1.2.11=h7b6447c_3 115 | - zstd=1.4.5=h9ceee32_0 116 | - pip: 117 | - future==0.18.2 118 | - kornia==0.4.1 119 | - protobuf==3.14.0 120 | - tensorboardx==2.1 121 | 122 | -------------------------------------------------------------------------------- /style_soft_intro_vae/scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | 
import torch 3 | import numpy as np 4 | 5 | 6 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__( 8 | self, 9 | optimizer, 10 | milestones, 11 | gamma=0.1, 12 | warmup_factor=1.0 / 1.0, 13 | warmup_iters=1, 14 | last_epoch=-1, 15 | reference_batch_size=128, 16 | lr=[] 17 | ): 18 | if not list(milestones) == sorted(milestones): 19 | raise ValueError( 20 | "Milestones should be a list of" " increasing integers. Got {}", 21 | milestones, 22 | ) 23 | self.milestones = milestones 24 | self.gamma = gamma 25 | self.warmup_factor = warmup_factor 26 | self.warmup_iters = warmup_iters 27 | self.batch_size = 1 28 | self.lod = 0 29 | self.reference_batch_size = reference_batch_size 30 | 31 | self.optimizer = optimizer 32 | self.base_lrs = [] 33 | for _ in self.optimizer.param_groups: 34 | self.base_lrs.append(lr) 35 | 36 | self.last_epoch = last_epoch 37 | 38 | if not isinstance(optimizer, torch.optim.Optimizer): 39 | raise TypeError('{} is not an Optimizer'.format( 40 | type(optimizer).__name__)) 41 | self.optimizer = optimizer 42 | 43 | if last_epoch == -1: 44 | for group in optimizer.param_groups: 45 | group.setdefault('initial_lr', group['lr']) 46 | last_epoch = 0 47 | 48 | self.last_epoch = last_epoch 49 | 50 | self.optimizer._step_count = 0 51 | self._step_count = 0 52 | # self.step(last_epoch) 53 | self.step() 54 | 55 | def set_batch_size(self, batch_size, lod): 56 | self.batch_size = batch_size 57 | self.lod = lod 58 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 59 | param_group['lr'] = lr 60 | 61 | def get_lr(self): 62 | warmup_factor = 1 63 | if self.last_epoch < self.warmup_iters: 64 | alpha = float(self.last_epoch) / self.warmup_iters 65 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 66 | return [ 67 | base_lr[self.lod] 68 | * warmup_factor 69 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 70 | # * float(self.batch_size) 71 | # / float(self.reference_batch_size) 72 | for base_lr in self.base_lrs 73 | ] 74 | 75 | def state_dict(self): 76 | return { 77 | "last_epoch": self.last_epoch 78 | } 79 | 80 | def load_state_dict(self, state_dict): 81 | self.__dict__.update(dict(last_epoch=state_dict["last_epoch"])) 82 | 83 | 84 | class ComboMultiStepLR: 85 | def __init__( 86 | self, 87 | optimizers, base_lr, 88 | **kwargs 89 | ): 90 | self.schedulers = dict() 91 | for name, opt in optimizers.items(): 92 | self.schedulers[name] = WarmupMultiStepLR(opt, lr=base_lr, **kwargs) 93 | self.last_epoch = 0 94 | 95 | def set_batch_size(self, batch_size, lod): 96 | for x in self.schedulers.values(): 97 | x.set_batch_size(batch_size, lod) 98 | 99 | def step(self, epoch=None): 100 | for x in self.schedulers.values(): 101 | # x.step(epoch) 102 | x.step() 103 | if epoch is None: 104 | epoch = self.last_epoch + 1 105 | self.last_epoch = epoch 106 | 107 | def state_dict(self): 108 | return {key: value.state_dict() for key, value in self.schedulers.items()} 109 | 110 | def load_state_dict(self, state_dict): 111 | for k, x in self.schedulers.items(): 112 | x.load_state_dict(state_dict[k]) 113 | 114 | last_epochs = [x.last_epoch for k, x in self.schedulers.items()] 115 | assert np.all(np.asarray(last_epochs) == last_epochs[0]) 116 | self.last_epoch = last_epochs[0] 117 | 118 | def start_epoch(self): 119 | return self.last_epoch 120 | -------------------------------------------------------------------------------- /soft_intro_vae_2d/README.md: -------------------------------------------------------------------------------- 1 | # 
soft-intro-vae-pytorch-2d 2 | 3 | Implementation of Soft-IntroVAE for tabular (2D) data. 4 | 5 | A step-by-step tutorial can be found in [Soft-IntroVAE Jupyter Notebook Tutorials](https://github.com/taldatech/soft-intro-vae-pytorch/tree/main/soft_intro_vae_tutorial). 6 | 7 |
11 | 12 | - [soft-intro-vae-pytorch-2d](#soft-intro-vae-pytorch-2d) 13 | * [Training](#training) 14 | * [Recommended hyperparameters](#recommended-hyperparameters) 15 | * [What to expect](#what-to-expect) 16 | * [Files and directories in the repository](#files-and-directories-in-the-repository) 17 | * [Tutorial](#tutorial) 18 | 19 | ## Training 20 | 21 | `main.py --help` 22 | 23 | 24 | You should use the `main.py` file with the following arguments: 25 | 26 | |Argument | Description |Legal Values | 27 | |-------------------------|---------------------------------------------|-------------| 28 | |-h, --help | shows arguments description | | 29 | |-d, --dataset | dataset to train on |str: '8Gaussians', '2spirals', 'checkerboard', 'rings' | 30 | |-n, --num_iter | total number of iterations to run | int: default=30000| 31 | |-z, --z_dim| latent dimensions | int: default=2| 32 | |-s, --seed| random state to use; -1 for random | int: -1, 0, 1, 2, ...| 33 | |-v, --num_vae| number of iterations for vanilla vae training | int: default=2000| 34 | |-l, --lr| learning rate | float: default=2e-4 | 35 | |-r, --beta_rec | beta coefficient for the reconstruction loss |float: default=0.2| 36 | |-k, --beta_kl| beta coefficient for the kl divergence | float: default=0.3| 37 | |-e, --beta_neg| beta coefficient for the kl divergence in the expELBO function | float: default=0.9| 38 | |-g, --gamma_r| coefficient for the reconstruction loss for fake data in the decoder | float: default=1e-8| 39 | |-b, --batch_size| batch size | int: default=512 | 40 | |-p, --pretrained | path to pretrained model, to continue training |str: default="None" | 41 | |-c, --device| device: -1 for cpu, 0 and up for specific cuda device |int: default=-1| 42 | 43 | 44 | Examples: 45 | 46 | `python main.py --dataset 8Gaussians --device 0 --seed 92 --lr 2e-4 --num_vae 2000 --num_iter 30000 --beta_kl 0.3 --beta_rec 0.2 --beta_neg 0.9` 47 | 48 | `python main.py --dataset rings --device -1 --seed -1 --lr 2e-4 --num_vae 2000 --num_iter 30000 --beta_kl 0.2 --beta_rec 0.2 --beta_neg 1.0` 49 | 50 | ## Recommended hyperparameters 51 | 52 | |Dataset | `beta_kl` | `beta_rec`| `beta_neg`| 53 | |------------|------|----|---| 54 | |`8Gaussians`|0.3|0.2| 0.9| 55 | |`2spirals`|0.5|0.2|1.0| 56 | |`checkerboard`|0.1|0.2|0.2| 57 | |`rings`|0.2|0.2|1.0| 58 | 59 | 60 | ## What to expect 61 | 62 | * During training, figures of samples and density plots are saved locally. 63 | * During training, statistics are printed (reconstruction error, KLD, expELBO). 64 | * At the end of the training, the following quantities are calculated, printed and saved to a `.txt` file: grid-normalized ELBO (gnELBO), KL, JSD. 65 | * Tips: 66 | * KL of fake/rec samples should be >= KL of real data 67 | * You will see that the deterministic reconstruction error is printed in parentheses; it should be lower than the stochastic reconstruction error. 68 | * We found that for the 2D datasets, it is better to initialize the networks with vanilla vae training (about 2000 iterations is good).
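For reference, the first example in the Training section corresponds to the following direct call into `train_soft_intro_vae_2d.py`. This is a minimal sketch mirroring the call made by `main.py`, which fixes `save_interval=5000`, `test_iter=5000`, `recon_loss_type="mse"`, and `scale=1` for `8Gaussians`:

```python
import torch
from train_soft_intro_vae_2d import train_soft_intro_vae_toy

# Recommended hyperparameters for 8Gaussians: beta_kl=0.3, beta_rec=0.2, beta_neg=0.9
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = train_soft_intro_vae_toy(z_dim=2, lr_e=2e-4, lr_d=2e-4, batch_size=512,
                                 n_iter=30000, num_vae=2000, save_interval=5000,
                                 recon_loss_type="mse", beta_kl=0.3, beta_rec=0.2,
                                 beta_neg=0.9, test_iter=5000, seed=92,
                                 scale=1,  # main.py uses scale=1 for 8Gaussians, 2 otherwise
                                 device=device, dataset='8Gaussians')
```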
69 | 70 | 71 | ## Files and directories in the repository 72 | 73 | |File name | Purpose | 74 | |----------------------|------| 75 | |`main.py`| general purpose main application for training Soft-IntroVAE for 2D data| 76 | |`train_soft_intro_vae_2d.py`| main training function, datasets and architectures| 77 | 78 | 79 | ## Tutorial 80 | * [Jupyter Notebook tutorial for 2D datasets](https://github.com/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_2d_code_tutorial.ipynb) 81 | * [Open in Colab](https://colab.research.google.com/github/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_2d_code_tutorial.ipynb) -------------------------------------------------------------------------------- /style_soft_intro_vae/custom_adam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # lr_equalization_coef was added for LREQ 3 | 4 | # Copyright (c) 2016- Facebook, Inc (Adam Paszke) 5 | # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 6 | # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 7 | # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 8 | # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 9 | # Copyright (c) 2011-2013 NYU (Clement Farabet) 10 | # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 11 | # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 12 | # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 13 | 14 | # https://github.com/pytorch/pytorch/blob/master/LICENSE 15 | 16 | 17 | import math 18 | import torch 19 | from torch.optim.optimizer import Optimizer 20 | 21 | 22 | class LREQAdam(Optimizer): 23 | def __init__(self, params, lr=1e-3, betas=(0.0, 0.99), eps=1e-8, 24 | weight_decay=0): 25 | beta_2 = betas[1] 26 | if not 0.0 <= lr: 27 | raise ValueError("Invalid learning rate: {}".format(lr)) 28 | if not 0.0 <= eps: 29 | raise ValueError("Invalid epsilon value: {}".format(eps)) 30 | if not 0.0 == betas[0]: 31 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 32 | if not 0.0 <= beta_2 < 1.0: 33 | raise ValueError("Invalid beta parameter at index 1: {}".format(beta_2)) 34 | defaults = dict(lr=lr, beta_2=beta_2, eps=eps, 35 | weight_decay=weight_decay) 36 | super(LREQAdam, self).__init__(params, defaults) 37 | 38 | def __setstate__(self, state): 39 | super(LREQAdam, self).__setstate__(state) 40 | 41 | def step(self, closure=None): 42 | """Performs a single optimization step. 43 | 44 | Arguments: 45 | closure (callable, optional): A closure that reevaluates the model 46 | and returns the loss. 
47 | """ 48 | loss = None 49 | if closure is not None: 50 | loss = closure() 51 | 52 | for group in self.param_groups: 53 | for p in group['params']: 54 | if p.grad is None: 55 | continue 56 | grad = p.grad.data 57 | if grad.is_sparse: 58 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 59 | 60 | state = self.state[p] 61 | 62 | # State initialization 63 | if len(state) == 0: 64 | state['step'] = 0 65 | # Exponential moving average of gradient values 66 | # state['exp_avg'] = torch.zeros_like(p.data) 67 | # Exponential moving average of squared gradient values 68 | state['exp_avg_sq'] = torch.zeros_like(p.data) 69 | 70 | # exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 71 | exp_avg_sq = state['exp_avg_sq'] 72 | beta_2 = group['beta_2'] 73 | 74 | state['step'] += 1 75 | 76 | if group['weight_decay'] != 0: 77 | grad.add_(group['weight_decay'], p.data / p.coef) 78 | 79 | # Decay the first and second moment running average coefficient 80 | # exp_avg.mul_(beta1).add_(1 - beta1, grad) 81 | # exp_avg_sq.mul_(beta_2).addcmul_(1 - beta_2, grad, grad) 82 | exp_avg_sq.mul_(beta_2).addcmul_(grad, grad, value=1 - beta_2) 83 | denom = exp_avg_sq.sqrt().add_(group['eps']) 84 | 85 | # bias_correction1 = 1 - beta1 ** state['step'] # 1 86 | bias_correction2 = 1 - beta_2 ** state['step'] 87 | # step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 88 | step_size = group['lr'] * math.sqrt(bias_correction2) 89 | 90 | # p.data.addcdiv_(-step_size, exp_avg, denom) 91 | if hasattr(p, 'lr_equalization_coef'): 92 | step_size *= p.lr_equalization_coef 93 | 94 | # p.data.addcdiv_(-step_size, grad, denom) 95 | p.data.addcdiv_(grad, denom, value=-step_size) 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/utils/data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import pickle 6 | 7 | from decimal import Decimal 8 | from itertools import accumulate, tee, chain 9 | from typing import List, Tuple, Dict, Optional, Any, Set 10 | 11 | from utils.plyfile import load_ply 12 | 13 | READERS = { 14 | '.ply': load_ply, 15 | '.np': lambda file_path: pickle.load(open(file_path, 'rb')), 16 | } 17 | 18 | 19 | def load_file(file_path): 20 | _, ext = os.path.splitext(file_path) 21 | return READERS[ext](file_path) 22 | 23 | 24 | def add_float(a, b): 25 | return float(Decimal(str(a)) + Decimal(str(b))) 26 | 27 | 28 | def ranges(values: List[float]) -> List[Tuple[float]]: 29 | lower, upper = tee(accumulate(values, add_float)) 30 | lower = chain([0], lower) 31 | 32 | return zip(lower, upper) 33 | 34 | 35 | def make_slices(values: List[float], N: int): 36 | slices = [slice(int(N * s), int(N * e)) for s, e in ranges(values)] 37 | return slices 38 | 39 | 40 | def make_splits( 41 | data: pd.DataFrame, 42 | splits: Dict[str, float], 43 | seed: Optional[int] = None): 44 | 45 | # assert correctness 46 | if not math.isclose(sum(splits.values()), 1.0): 47 | values = " ".join([f"{k} : {v}" for k, v in splits.items()]) 48 | raise ValueError(f"{values} should sum up to 1") 49 | 50 | # shuffle with random seed 51 | data = data.iloc[np.random.permutation(len(data))] 52 | slices = make_slices(list(splits.values()), len(data)) 53 | 54 | return { 55 | name: data[idxs].reset_index(drop=True) for name, idxs in zip(splits.keys(), slices) 56 | } 57 | 58 | 59 | def 
sample_other_than(black_list: Set[int], x: np.ndarray) -> int: 60 | res = np.random.randint(0, len(x)) 61 | while res in black_list: 62 | res = np.random.randint(0, len(x)) 63 | 64 | return res 65 | 66 | 67 | def clip_cloud(p: np.ndarray) -> np.ndarray: 68 | # create list of extreme points 69 | black_list = set(np.hstack([ 70 | np.argmax(p, axis=0), np.argmin(p, axis=0) 71 | ])) 72 | 73 | # swap any other point 74 | for idx in black_list: 75 | p[idx] = p[sample_other_than(black_list, p)] 76 | 77 | return p 78 | 79 | 80 | def find_extrema(xs, n_cols: int=3, clip: bool=True) -> Dict[Any, List[float]]: 81 | from collections import defaultdict 82 | 83 | mins = defaultdict(lambda: [np.inf for _ in range(n_cols)]) 84 | maxs = defaultdict(lambda: [-np.inf for _ in range(n_cols)]) 85 | 86 | for x, c in xs: 87 | x = clip_cloud(x) if clip else x 88 | mins[c] = [min(old, new) for old, new in zip(mins[c], np.min(x, axis=0))] 89 | maxs[c] = [max(old, new) for old, new in zip(maxs[c], np.max(x, axis=0))] 90 | 91 | return mins, maxs 92 | 93 | 94 | def merge_dicts( 95 | dict_old: Dict[Any, List[float]], 96 | dict_new: Dict[Any, List[float]], op=min) -> Dict[Any, List[float]]: 97 | ''' 98 | Simply takes values on List of floats for given key 99 | ''' 100 | d_out = {** dict_old} 101 | for k, v in dict_new.items(): 102 | if k in dict_old: 103 | d_out[k] = [op(new, old) for old, new in zip(dict_new[k], dict_old[k])] 104 | else: 105 | d_out[k] = dict_new[k] 106 | 107 | return d_out 108 | 109 | 110 | def save_extrema(clazz, root_dir, splits=('train', 'test', 'valid')): 111 | ''' 112 | Maybe this should be class dependent normalization? 113 | ''' 114 | min_dict, max_dict = {}, {} 115 | for split in splits: 116 | data = clazz(root_dir=root_dir, split=split, remap=False) 117 | mins, maxs = find_extrema(data) 118 | min_dict = merge_dicts(min_dict, mins, min) 119 | max_dict = merge_dicts(max_dict, maxs, max) 120 | 121 | # vectorzie values 122 | for d in (min_dict, max_dict): 123 | for k in d: 124 | d[k] = np.array(d[k]) 125 | 126 | with open(os.path.join(root_dir, 'extrema.np'), 'wb') as f: 127 | pickle.dump((min_dict, max_dict), f) 128 | 129 | 130 | def remap(old_value: np.ndarray, 131 | old_min: np.ndarray, old_max: np.ndarray, 132 | new_min: float = -0.5, new_max: float = 0.5) -> np.ndarray: 133 | ''' 134 | Remap reange 135 | ''' 136 | old_range = (old_max - old_min) 137 | new_range = (new_max - new_min) 138 | new_value = (((old_value - old_min) * new_range) / old_range) + new_min 139 | 140 | return new_value 141 | -------------------------------------------------------------------------------- /style_soft_intro_vae/make_figures/generate_samples.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2021 Tal Daniel 2 | # Copyright 2019-2020 Stanislav Pidhorskyi 3 | # 4 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # This work is licensed under the Creative Commons Attribution-NonCommercial 7 | # 4.0 International License. To view a copy of this license, visit 8 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 9 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
10 | 11 | from net import * 12 | from model import SoftIntroVAEModelTL 13 | from launcher import run 14 | from dataloader import * 15 | from checkpointer import Checkpointer 16 | # from dlutils.pytorch import count_parameters 17 | from defaults import get_cfg_defaults 18 | from PIL import Image 19 | import PIL 20 | from pathlib import Path 21 | from tqdm import tqdm 22 | 23 | 24 | def millify(n): 25 | millnames = ['', 'k', 'M', 'G', 'T', 'P'] 26 | n = float(n) 27 | millidx = max(0, min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)))) 28 | 29 | return '{:.1f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx]) 30 | 31 | 32 | def count_parameters(model, print_func=print, verbose=False): 33 | for n, p in model.named_parameters(): 34 | if p.requires_grad and verbose: 35 | print_func(n, millify(p.numel())) 36 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | 38 | 39 | def generate_samples(cfg, model, save_dir, num_samples, device=torch.device("cpu")): 40 | for i in tqdm(range(num_samples)): 41 | samplez = torch.randn(size=(1, cfg.MODEL.LATENT_SPACE_SIZE)).float().to(device) 42 | image = model.generate(cfg.DATASET.MAX_RESOLUTION_LEVEL - 2, 1, samplez, 1, mixing=True)[0] 43 | im = image.data.cpu().numpy() 44 | im = im.transpose(1, 2, 0) 45 | im = im * 0.5 + 0.5 46 | image = PIL.Image.fromarray(np.clip(im * 255, 0, 255).astype(np.uint8), 'RGB') 47 | image.save(os.path.join(save_dir, 'image_{}.jpg'.format(i))) 48 | 49 | 50 | def sample(cfg, logger): 51 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 52 | torch.cuda.set_device(2) 53 | model = SoftIntroVAEModelTL( 54 | startf=cfg.MODEL.START_CHANNEL_COUNT, 55 | layer_count=cfg.MODEL.LAYER_COUNT, 56 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 57 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 58 | dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, 59 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 60 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 61 | channels=cfg.MODEL.CHANNELS, 62 | generator=cfg.MODEL.GENERATOR, 63 | encoder=cfg.MODEL.ENCODER, 64 | beta_kl=cfg.MODEL.BETA_KL, 65 | beta_rec=cfg.MODEL.BETA_REC, 66 | beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], 67 | scale=cfg.MODEL.SCALE 68 | ) 69 | 70 | model.to(device) 71 | model.eval() 72 | model.requires_grad_(False) 73 | 74 | decoder = model.decoder 75 | encoder = model.encoder 76 | mapping_tl = model.mapping_tl 77 | mapping_fl = model.mapping_fl 78 | 79 | dlatent_avg = model.dlatent_avg 80 | 81 | logger.info("Trainable parameters decoder:") 82 | print(count_parameters(decoder)) 83 | 84 | logger.info("Trainable parameters encoder:") 85 | print(count_parameters(encoder)) 86 | 87 | arguments = dict() 88 | arguments["iteration"] = 0 89 | 90 | model_dict = { 91 | 'discriminator_s': encoder, 92 | 'generator_s': decoder, 93 | 'mapping_tl_s': mapping_tl, 94 | 'mapping_fl_s': mapping_fl, 95 | 'dlatent_avg': dlatent_avg 96 | } 97 | 98 | checkpointer = Checkpointer(cfg, 99 | model_dict, 100 | {}, 101 | logger=logger, 102 | save=False) 103 | 104 | checkpointer.load() 105 | 106 | model.eval() 107 | 108 | path = './make_figures/output' 109 | os.makedirs(path, exist_ok=True) 110 | os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True) 111 | with torch.no_grad(): 112 | generate_samples(cfg, model, path, 5, device=device) 113 | 114 | 115 | if __name__ == "__main__": 116 | gpu_count = 1 117 | run(sample, get_cfg_defaults(), description='SoftIntroVAEVAE-generations', 118 | default_config='./configs/ffhq256.yaml', 119 | world_size=gpu_count, 
write_log=False) 120 | -------------------------------------------------------------------------------- /style_soft_intro_vae/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import sys 18 | import argparse 19 | import logging 20 | import torch 21 | import torch.multiprocessing as mp 22 | from torch import distributed 23 | import inspect 24 | 25 | 26 | def setup(rank, world_size): 27 | os.environ['MASTER_ADDR'] = 'localhost' 28 | os.environ['MASTER_PORT'] = '12355' 29 | distributed.init_process_group("nccl", rank=rank, world_size=world_size) 30 | 31 | 32 | def cleanup(): 33 | distributed.destroy_process_group() 34 | 35 | 36 | def _run(rank, world_size, fn, defaults, write_log, no_cuda, args): 37 | if world_size > 1: 38 | setup(rank, world_size) 39 | if not no_cuda: 40 | torch.cuda.set_device(rank) 41 | 42 | cfg = defaults 43 | config_file = args.config_file 44 | if len(os.path.splitext(config_file)[1]) == 0: 45 | config_file += '.yaml' 46 | if not os.path.exists(config_file) and os.path.exists(os.path.join('configs', config_file)): 47 | config_file = os.path.join('configs', config_file) 48 | cfg.merge_from_file(config_file) 49 | cfg.merge_from_list(args.opts) 50 | cfg.freeze() 51 | 52 | logger = logging.getLogger("logger") 53 | logger.setLevel(logging.DEBUG) 54 | 55 | output_dir = cfg.OUTPUT_DIR 56 | os.makedirs(output_dir, exist_ok=True) 57 | 58 | if rank == 0: 59 | ch = logging.StreamHandler(stream=sys.stdout) 60 | ch.setLevel(logging.DEBUG) 61 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 62 | ch.setFormatter(formatter) 63 | logger.addHandler(ch) 64 | 65 | if write_log: 66 | filepath = os.path.join(output_dir, 'log.txt') 67 | if isinstance(write_log, str): 68 | filepath = write_log 69 | fh = logging.FileHandler(filepath) 70 | fh.setLevel(logging.DEBUG) 71 | fh.setFormatter(formatter) 72 | logger.addHandler(fh) 73 | 74 | logger.info(args) 75 | 76 | logger.info("World size: {}".format(world_size)) 77 | 78 | logger.info("Loaded configuration file {}".format(config_file)) 79 | with open(config_file, "r") as cf: 80 | config_str = "\n" + cf.read() 81 | logger.info(config_str) 82 | logger.info("Running with config:\n{}".format(cfg)) 83 | 84 | if not no_cuda: 85 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 86 | device = torch.cuda.current_device() 87 | print("Running on ", torch.cuda.get_device_name(device)) 88 | 89 | args.distributed = world_size > 1 90 | args_to_pass = dict(cfg=cfg, logger=logger, local_rank=rank, world_size=world_size, distributed=args.distributed) 91 | signature = inspect.signature(fn) 92 | matching_args = {} 93 | for key in args_to_pass.keys(): 94 | if key in signature.parameters.keys(): 95 | matching_args[key] = args_to_pass[key] 96 | 
fn(**matching_args) 97 | 98 | if world_size > 1: 99 | cleanup() 100 | 101 | 102 | def run(fn, defaults, description='', default_config='configs/experiment.yaml', world_size=1, write_log=True, no_cuda=False): 103 | parser = argparse.ArgumentParser(description=description) 104 | parser.add_argument( 105 | "-c", "--config-file", 106 | default=default_config, 107 | metavar="FILE", 108 | help="path to config file", 109 | type=str, 110 | ) 111 | parser.add_argument( 112 | "opts", 113 | help="Modify config options using the command-line", 114 | default=None, 115 | nargs=argparse.REMAINDER, 116 | ) 117 | 118 | import multiprocessing 119 | cpu_count = multiprocessing.cpu_count() 120 | os.environ["OMP_NUM_THREADS"] = str(max(1, int(cpu_count / world_size))) 121 | del multiprocessing 122 | 123 | args = parser.parse_args() 124 | 125 | if world_size > 1: 126 | mp.spawn(_run, 127 | args=(world_size, fn, defaults, write_log, no_cuda, args), 128 | nprocs=world_size, 129 | join=True) 130 | else: 131 | _run(0, world_size, fn, defaults, write_log, no_cuda, args) 132 | -------------------------------------------------------------------------------- /style_soft_intro_vae/lod_driver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import torch 17 | import math 18 | import time 19 | from collections import defaultdict 20 | 21 | 22 | class LODDriver: 23 | def __init__(self, cfg, logger, world_size, dataset_size): 24 | if world_size == 8: 25 | self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_8GPU 26 | if world_size == 4: 27 | self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_4GPU 28 | if world_size == 2: 29 | self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_2GPU 30 | if world_size == 1: 31 | self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_1GPU 32 | 33 | self.world_size = world_size 34 | self.minibatch_base = 16 35 | self.cfg = cfg 36 | self.dataset_size = dataset_size 37 | self.current_epoch = 0 38 | self.lod = -1 39 | self.in_transition = False 40 | self.logger = logger 41 | self.iteration = 0 42 | self.epoch_end_time = 0 43 | self.epoch_start_time = 0 44 | self.per_epoch_ptime = 0 45 | self.reports = cfg.TRAIN.REPORT_FREQ 46 | self.snapshots = cfg.TRAIN.SNAPSHOT_FREQ 47 | self.tick_start_nimg_report = 0 48 | self.tick_start_nimg_snapshot = 0 49 | 50 | def get_lod_power2(self): 51 | return self.lod + 2 52 | 53 | def get_batch_size(self): 54 | return self.lod_2_batch[min(self.lod, len(self.lod_2_batch) - 1)] 55 | 56 | def get_dataset_size(self): 57 | return self.dataset_size 58 | 59 | def get_per_GPU_batch_size(self): 60 | return self.get_batch_size() // self.world_size 61 | 62 | def get_blend_factor(self): 63 | if self.cfg.TRAIN.EPOCHS_PER_LOD == 0: 64 | return 1 65 | blend_factor = float((self.current_epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) * self.dataset_size + self.iteration) 66 | blend_factor /= float(self.cfg.TRAIN.EPOCHS_PER_LOD // 2 * self.dataset_size) 67 | blend_factor = math.sin(blend_factor * math.pi - 0.5 * math.pi) * 0.5 + 0.5 68 | 69 | if not self.in_transition: 70 | blend_factor = 1 71 | 72 | return blend_factor 73 | 74 | def is_time_to_report(self): 75 | if self.iteration >= self.tick_start_nimg_report + self.reports[min(self.lod, len(self.reports) - 1)] * 1000: 76 | self.tick_start_nimg_report = self.iteration 77 | return True 78 | return False 79 | 80 | def is_time_to_save(self): 81 | if self.iteration >= self.tick_start_nimg_snapshot + self.snapshots[ 82 | min(self.lod, len(self.snapshots) - 1)] * 1000: 83 | self.tick_start_nimg_snapshot = self.iteration 84 | return True 85 | return False 86 | 87 | def step(self): 88 | self.iteration += self.get_batch_size() 89 | self.epoch_end_time = time.time() 90 | self.per_epoch_ptime = self.epoch_end_time - self.epoch_start_time 91 | 92 | def set_epoch(self, epoch, optimizers): 93 | self.current_epoch = epoch 94 | self.iteration = 0 95 | self.tick_start_nimg_report = 0 96 | self.tick_start_nimg_snapshot = 0 97 | self.epoch_start_time = time.time() 98 | 99 | if self.cfg.TRAIN.EPOCHS_PER_LOD == 0: 100 | self.lod = self.cfg.MODEL.LAYER_COUNT - 1 101 | return 102 | 103 | new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD) 104 | if new_lod != self.lod: 105 | self.lod = new_lod 106 | self.logger.info("#" * 80) 107 | self.logger.info("# Switching LOD to %d" % self.lod) 108 | self.logger.info("# Starting transition") 109 | self.logger.info("#" * 80) 110 | self.in_transition = True 111 | for opt in optimizers: 112 | opt.state = defaultdict(dict) 113 | 114 | is_in_first_half_of_cycle = (epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) < (self.cfg.TRAIN.EPOCHS_PER_LOD // 2) 115 | is_growing = epoch // self.cfg.TRAIN.EPOCHS_PER_LOD == self.lod > 0 116 | new_in_transition = 
is_in_first_half_of_cycle and is_growing 117 | 118 | if new_in_transition != self.in_transition: 119 | self.in_transition = new_in_transition 120 | self.logger.info("#" * 80) 121 | self.logger.info("# Transition ended") 122 | self.logger.info("#" * 80) 123 | -------------------------------------------------------------------------------- /style_soft_intro_vae/tracker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import csv 17 | from collections import OrderedDict 18 | import matplotlib 19 | matplotlib.use('Agg') 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import torch 23 | import os 24 | 25 | 26 | class RunningMean: 27 | def __init__(self): 28 | self.mean = 0.0 29 | self.n = 0 30 | 31 | def __iadd__(self, value): 32 | self.mean = (float(value) + self.mean * self.n)/(self.n + 1) 33 | self.n += 1 34 | return self 35 | 36 | def reset(self): 37 | self.mean = 0.0 38 | self.n = 0 39 | 40 | def mean(self): 41 | return self.mean 42 | 43 | 44 | class RunningMeanTorch: 45 | def __init__(self): 46 | self.values = [] 47 | 48 | def __iadd__(self, value): 49 | with torch.no_grad(): 50 | self.values.append(value.detach().cpu().unsqueeze(0)) 51 | return self 52 | 53 | def reset(self): 54 | self.values = [] 55 | 56 | def mean(self): 57 | with torch.no_grad(): 58 | if len(self.values) == 0: 59 | return 0.0 60 | return float(torch.cat(self.values).mean().item()) 61 | 62 | 63 | class LossTracker: 64 | def __init__(self, output_folder='.'): 65 | self.tracks = OrderedDict() 66 | self.epochs = [] 67 | self.means_over_epochs = OrderedDict() 68 | self.output_folder = output_folder 69 | 70 | def update(self, d): 71 | for k, v in d.items(): 72 | if k not in self.tracks: 73 | self.add(k) 74 | self.tracks[k] += v 75 | 76 | def add(self, name, pytorch=True): 77 | assert name not in self.tracks, "Name is already used" 78 | if pytorch: 79 | track = RunningMeanTorch() 80 | else: 81 | track = RunningMean() 82 | self.tracks[name] = track 83 | self.means_over_epochs[name] = [] 84 | return track 85 | 86 | def register_means(self, epoch): 87 | self.epochs.append(epoch) 88 | 89 | for key in self.means_over_epochs.keys(): 90 | if key in self.tracks: 91 | value = self.tracks[key] 92 | self.means_over_epochs[key].append(value.mean()) 93 | value.reset() 94 | else: 95 | self.means_over_epochs[key].append(None) 96 | 97 | with open(os.path.join(self.output_folder, 'log.csv'), mode='w') as csv_file: 98 | fieldnames = ['epoch'] + list(self.tracks.keys()) 99 | writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 100 | writer.writerow(fieldnames) 101 | for i in range(len(self.epochs)): 102 | writer.writerow([self.epochs[i]] + [self.means_over_epochs[x][i] for x in self.tracks.keys()]) 103 | 104 | def __str__(self): 105 | 
result = "" 106 | for key, value in self.tracks.items(): 107 | result += "%s: %.7f, " % (key, value.mean()) 108 | return result[:-2] 109 | 110 | def plot(self): 111 | plt.figure(figsize=(12, 8)) 112 | for key in self.tracks.keys(): 113 | plt.plot(self.epochs, self.means_over_epochs[key], label=key) 114 | 115 | plt.xlabel('Epoch') 116 | plt.ylabel('Loss') 117 | 118 | plt.legend(loc=2) 119 | plt.grid(True) 120 | plt.tight_layout() 121 | 122 | plt.savefig(os.path.join(self.output_folder, 'plot.png')) 123 | plt.close() 124 | 125 | def state_dict(self): 126 | return { 127 | 'tracks': self.tracks, 128 | 'epochs': self.epochs, 129 | 'means_over_epochs': self.means_over_epochs} 130 | 131 | def load_state_dict(self, state_dict): 132 | self.tracks = state_dict['tracks'] 133 | self.epochs = state_dict['epochs'] 134 | self.means_over_epochs = state_dict['means_over_epochs'] 135 | 136 | counts = list(map(len, self.means_over_epochs.values())) 137 | 138 | if len(counts) == 0: 139 | counts = [0] 140 | m = min(counts) 141 | 142 | if m < len(self.epochs): 143 | self.epochs = self.epochs[:m] 144 | 145 | for key in self.means_over_epochs.keys(): 146 | if len(self.means_over_epochs[key]) > m: 147 | self.means_over_epochs[key] = self.means_over_epochs[key][:m] 148 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/generate_for_rendering.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate point clouds from a trained model for rendering using Mitsuba 3 | """ 4 | 5 | import json 6 | import numpy as np 7 | import os 8 | 9 | import torch 10 | import torch.nn.parallel 11 | import torch.utils.data 12 | from torch.utils.data import DataLoader 13 | 14 | from utils.util import cuda_setup 15 | from models.vae import SoftIntroVAE, reparameterize 16 | 17 | 18 | def generate_from_model(model, num_samples, prior_std=0.2, device=torch.device("cpu")): 19 | model.eval() 20 | noise = prior_std * torch.randn(size=(num_samples, model.zdim)).to(device) 21 | with torch.no_grad(): 22 | x_g = model.decode(noise) 23 | if x_g.shape[-2:] == (3, 2048): 24 | x_g.transpose_(1, 2) 25 | return x_g 26 | 27 | 28 | def interpolate(model, data, num_steps=20, device=torch.device("cpu")): 29 | assert data.shape[0] >= 2, "must supply at least 2 data points" 30 | model.eval() 31 | steps = np.linspace(0, 1, num_steps) 32 | data = data[:2].to(device) 33 | # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS] 34 | if data.size(-1) == 3: 35 | data.transpose_(data.dim() - 2, data.dim() - 1) 36 | mu_z, logvar_z = model.encode(data) 37 | data_z = reparameterize(mu_z, logvar_z) 38 | interpolations = [data_z[0][None,]] 39 | for step in steps: 40 | interpolation = step * data_z[1] + (1 - step) * data_z[0] 41 | interpolations.append(interpolation[None,]) 42 | interpolations.append(data_z[1][None,]) 43 | interpolations = torch.cat(interpolations, dim=0) 44 | data_interpolation = model.decode(interpolations) 45 | return data_interpolation 46 | 47 | 48 | def save_point_cloud_np(save_path, data_tensor): 49 | # Change dim 50 | if data_tensor.size(-1) != 3: 51 | data_tensor.transpose_(data_tensor.dim() - 1, data_tensor.dim() - 2) 52 | data_np = data_tensor.data.cpu().numpy() 53 | np.save(save_path, data_np) 54 | print(f'saved data @ {save_path}') 55 | 56 | 57 | def prepare_model(config, path_to_weights, device=torch.device("cpu")): 58 | model = SoftIntroVAE(config).to(device) 59 | model.load_state_dict(torch.load(path_to_weights, map_location=device)) 
60 | model.eval() 61 | return model 62 | 63 | 64 | def prepare_data(config, split='train', batch_size=32): 65 | dataset_name = config['dataset'].lower() 66 | if dataset_name == 'shapenet': 67 | from datasets.shapenet import ShapeNetDataset 68 | dataset = ShapeNetDataset(root_dir=config['data_dir'], 69 | classes=config['classes'], split=split) 70 | else: 71 | raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' 72 | f'`faust`. Got: `{dataset_name}`') 73 | data_loader = DataLoader(dataset, batch_size=batch_size, 74 | shuffle=True, num_workers=4, 75 | drop_last=False, pin_memory=True) 76 | return data_loader 77 | 78 | 79 | def prepare_dataset(config, split='train', batch_size=32): 80 | dataset_name = config['dataset'].lower() 81 | if dataset_name == 'shapenet': 82 | from datasets.shapenet import ShapeNetDataset 83 | dataset = ShapeNetDataset(root_dir=config['data_dir'], 84 | classes=config['classes'], split=split) 85 | else: 86 | raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' 87 | f'`faust`. Got: `{dataset_name}`') 88 | return dataset 89 | 90 | 91 | if __name__ == "__main__": 92 | """ 93 | cars + airplane: cars:[60, 5150], planes: [6450, 6550, 6950] 94 | """ 95 | save_path = './results/generated_data' 96 | os.makedirs(save_path, exist_ok=True) 97 | path_generated = os.path.join(save_path, 'generated.npy') 98 | path_interpolated = os.path.join(save_path, 'interpolations.npy') 99 | path_to_weights = './results/vae/soft_intro_vae_chair/weights/01618_jsd_0.0175.pth' 100 | config_path = 'config/soft_intro_vae_hp.json' 101 | config = None 102 | if config_path is not None and config_path.endswith('.json'): 103 | with open(config_path) as f: 104 | config = json.load(f) 105 | assert config is not None 106 | device = cuda_setup(config['cuda'], config['gpu']) 107 | print("using device: ", device) 108 | model = prepare_model(config, path_to_weights, device=device) 109 | dataset = prepare_dataset(config, split='train', batch_size=config['batch_size']) 110 | batch = torch.stack([torch.from_numpy(dataset[60][0]), torch.from_numpy(dataset[6450][0])], dim=0) 111 | # generate 112 | x_g = generate_from_model(model, num_samples=5, device=device) 113 | save_point_cloud_np(path_generated, x_g) 114 | print(f'save generations in {path_generated}') 115 | # interpolate 116 | x_interpolated = interpolate(model, batch, num_steps=50, device=device) 117 | save_point_cloud_np(path_interpolated, x_interpolated) 118 | print(f'save interpolations in {path_interpolated}') 119 | print("use these .npy files to render beautiful point clouds with Mitsuba, see the 'render' directory for instructions") 120 | -------------------------------------------------------------------------------- /style_soft_intro_vae/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import os 17 | from torch import nn 18 | import torch 19 | import utils 20 | 21 | 22 | def get_model_dict(x): 23 | if x is None: 24 | return None 25 | if isinstance(x, nn.DataParallel): 26 | return x.module.state_dict() 27 | else: 28 | return x.state_dict() 29 | 30 | 31 | def load_model(x, state_dict): 32 | if isinstance(x, nn.DataParallel): 33 | x.module.load_state_dict(state_dict) 34 | else: 35 | x.load_state_dict(state_dict) 36 | 37 | 38 | class Checkpointer(object): 39 | def __init__(self, cfg, models, auxiliary=None, logger=None, save=True): 40 | self.models = models 41 | self.auxiliary = auxiliary 42 | self.cfg = cfg 43 | self.logger = logger 44 | self._save = save 45 | 46 | def save(self, _name, **kwargs): 47 | if not self._save: 48 | return 49 | data = dict() 50 | data["models"] = dict() 51 | data["auxiliary"] = dict() 52 | for name, model in self.models.items(): 53 | data["models"][name] = get_model_dict(model) 54 | 55 | if self.auxiliary is not None: 56 | for name, item in self.auxiliary.items(): 57 | data["auxiliary"][name] = item.state_dict() 58 | data.update(kwargs) 59 | 60 | @utils.async_func 61 | def save_data(): 62 | save_file = os.path.join(self.cfg.OUTPUT_DIR, "%s.pth" % _name) 63 | self.logger.info("Saving checkpoint to %s" % save_file) 64 | torch.save(data, save_file) 65 | self.tag_last_checkpoint(save_file) 66 | 67 | return save_data() 68 | 69 | def load(self, ignore_last_checkpoint=False, file_name=None): 70 | save_file = os.path.join(self.cfg.OUTPUT_DIR, "last_checkpoint") 71 | try: 72 | with open(save_file, "r") as last_checkpoint: 73 | f = last_checkpoint.read().strip() 74 | except IOError: 75 | self.logger.info("No checkpoint found. Initializing model from scratch") 76 | if file_name is None: 77 | return {} 78 | 79 | if ignore_last_checkpoint: 80 | self.logger.info("Forced to Initialize model from scratch") 81 | return {} 82 | if file_name is not None: 83 | f = file_name 84 | 85 | self.logger.info("Loading checkpoint from {}".format(f)) 86 | checkpoint = torch.load(f, map_location=torch.device("cpu")) 87 | for name, model in self.models.items(): 88 | if name in checkpoint["models"]: 89 | try: 90 | model_dict = checkpoint["models"].pop(name) 91 | if model_dict is not None: 92 | self.models[name].load_state_dict(model_dict, strict=False) 93 | else: 94 | self.logger.warning("State dict for model \"%s\" is None " % name) 95 | except RuntimeError as e: 96 | self.logger.warning('%s\nFailed to load: %s\n%s' % ('!' * 160, name, '!' * 160)) 97 | self.logger.warning('\nFailed to load: %s' % str(e)) 98 | else: 99 | self.logger.warning("No state dict for model: %s" % name) 100 | checkpoint.pop('models') 101 | if "auxiliary" in checkpoint and self.auxiliary: 102 | self.logger.info("Loading auxiliary from {}".format(f)) 103 | for name, item in self.auxiliary.items(): 104 | try: 105 | if name in checkpoint["auxiliary"]: 106 | self.auxiliary[name].load_state_dict(checkpoint["auxiliary"].pop(name)) 107 | if "optimizers" in checkpoint and name in checkpoint["optimizers"]: 108 | self.auxiliary[name].load_state_dict(checkpoint["optimizers"].pop(name)) 109 | if name in checkpoint: 110 | self.auxiliary[name].load_state_dict(checkpoint.pop(name)) 111 | except IndexError: 112 | self.logger.warning('%s\nFailed to load: %s\n%s' % ('!' * 160, name, '!' 
* 160)) 113 | checkpoint.pop('auxiliary') 114 | 115 | return checkpoint 116 | 117 | def tag_last_checkpoint(self, last_filename): 118 | save_file = os.path.join(self.cfg.OUTPUT_DIR, "last_checkpoint") 119 | with open(save_file, "w") as f: 120 | f.write(last_filename) 121 | -------------------------------------------------------------------------------- /style_soft_intro_vae/make_figures/make_generation_figure.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2021 Tal Daniel 2 | # Copyright 2019-2020 Stanislav Pidhorskyi 3 | # 4 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # This work is licensed under the Creative Commons Attribution-NonCommercial 7 | # 4.0 International License. To view a copy of this license, visit 8 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 9 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 10 | 11 | from net import * 12 | from model import SoftIntroVAEModelTL 13 | from launcher import run 14 | from dataloader import * 15 | from checkpointer import Checkpointer 16 | from defaults import get_cfg_defaults 17 | from PIL import Image 18 | import PIL 19 | import os 20 | 21 | 22 | def millify(n): 23 | millnames = ['', 'k', 'M', 'G', 'T', 'P'] 24 | n = float(n) 25 | millidx = max(0, min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)))) 26 | 27 | return '{:.1f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx]) 28 | 29 | 30 | def count_parameters(model, print_func=print, verbose=False): 31 | for n, p in model.named_parameters(): 32 | if p.requires_grad and verbose: 33 | print_func(n, millify(p.numel())) 34 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 35 | 36 | 37 | def draw_uncurated_result_figure(cfg, png, model, cx, cy, cw, ch, rows, lods, seed): 38 | print(png) 39 | N = sum(rows * 2 ** lod for lod in lods) 40 | images = [] 41 | 42 | rnd = np.random.RandomState(5) 43 | for i in range(N): 44 | latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE) 45 | samplez = torch.tensor(latents).float().cuda() 46 | image = model.generate(cfg.DATASET.MAX_RESOLUTION_LEVEL - 2, 1, samplez, 1, mixing=True) 47 | images.append(image[0]) 48 | 49 | canvas = PIL.Image.new('RGB', (sum(cw // 2 ** lod for lod in lods), ch * rows), 'white') 50 | image_iter = iter(list(images)) 51 | for col, lod in enumerate(lods): 52 | for row in range(rows * 2 ** lod): 53 | im = next(image_iter).cpu().numpy() 54 | im = im.transpose(1, 2, 0) 55 | im = im * 0.5 + 0.5 56 | image = PIL.Image.fromarray(np.clip(im * 255, 0, 255).astype(np.uint8), 'RGB') 57 | image = image.crop((cx, cy, cx + cw, cy + ch)) 58 | image = image.resize((cw // 2 ** lod, ch // 2 ** lod), PIL.Image.ANTIALIAS) 59 | canvas.paste(image, (sum(cw // 2 ** lod for lod in lods[:col]), row * ch // 2 ** lod)) 60 | canvas.save(png) 61 | 62 | 63 | def sample(cfg, logger): 64 | torch.cuda.set_device(0) 65 | model = SoftIntroVAEModelTL( 66 | startf=cfg.MODEL.START_CHANNEL_COUNT, 67 | layer_count=cfg.MODEL.LAYER_COUNT, 68 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 69 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 70 | dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, 71 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 72 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 73 | channels=cfg.MODEL.CHANNELS, 74 | generator=cfg.MODEL.GENERATOR, 75 | encoder=cfg.MODEL.ENCODER, 76 | beta_kl=cfg.MODEL.BETA_KL, 77 | beta_rec=cfg.MODEL.BETA_REC, 78 | beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], 79 | 
scale=cfg.MODEL.SCALE 80 | ) 81 | 82 | model.cuda(0) 83 | model.eval() 84 | model.requires_grad_(False) 85 | 86 | decoder = model.decoder 87 | encoder = model.encoder 88 | mapping_tl = model.mapping_tl 89 | mapping_fl = model.mapping_fl 90 | 91 | dlatent_avg = model.dlatent_avg 92 | 93 | logger.info("Trainable parameters decoder:") 94 | print(count_parameters(decoder)) 95 | 96 | logger.info("Trainable parameters encoder:") 97 | print(count_parameters(encoder)) 98 | 99 | arguments = dict() 100 | arguments["iteration"] = 0 101 | 102 | model_dict = { 103 | 'discriminator_s': encoder, 104 | 'generator_s': decoder, 105 | 'mapping_tl_s': mapping_tl, 106 | 'mapping_fl_s': mapping_fl, 107 | 'dlatent_avg': dlatent_avg 108 | } 109 | 110 | checkpointer = Checkpointer(cfg, 111 | model_dict, 112 | {}, 113 | logger=logger, 114 | save=False) 115 | 116 | checkpointer.load() 117 | 118 | model.eval() 119 | 120 | im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1) 121 | seed = np.random.randint(0, 999999) 122 | print("seed:", seed) 123 | with torch.no_grad(): 124 | path = './make_figures/output' 125 | os.makedirs(path, exist_ok=True) 126 | os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True) 127 | draw_uncurated_result_figure(cfg, './make_figures/output/%s/generations.jpg' % cfg.NAME, 128 | model, cx=0, cy=0, cw=im_size, ch=im_size, rows=6, lods=[0, 0, 0, 1, 1, 2], seed=seed) 129 | 130 | 131 | if __name__ == "__main__": 132 | gpu_count = 1 133 | run(sample, get_cfg_defaults(), description='SoftIntroVAE-generations', default_config='./configs/ffhq256.yaml', 134 | world_size=gpu_count, write_log=False) 135 | -------------------------------------------------------------------------------- /soft_intro_vae/README.md: -------------------------------------------------------------------------------- 1 | # soft-intro-vae-pytorch-images 2 | 3 | Implementation of Soft-IntroVAE for image data. 4 | 5 | A step-by-step tutorial can be found in [Soft-IntroVAE Jupyter Notebook Tutorials](https://github.com/taldatech/soft-intro-vae-pytorch/tree/main/soft_intro_vae_tutorial). 6 | 7 |

8 | 9 |
10 | 11 | - [soft-intro-vae-pytorch-images](#soft-intro-vae-pytorch-images) 12 | * [Training](#training) 13 | * [Datasets](#datasets) 14 | * [Recommended hyperparameters](#recommended-hyperparameters) 15 | * [What to expect](#what-to-expect) 16 | * [Files and directories in the repository](#files-and-directories-in-the-repository) 17 | * [Tutorial](#tutorial) 18 | 19 | ## Training 20 | 21 | `main.py --help` 22 | 23 | 24 | You should use the `main.py` file with the following arguments: 25 | 26 | |Argument | Description |Legal Values | 27 | |-------------------------|---------------------------------------------|-------------| 28 | |-h, --help | shows arguments description | | 29 | |-d, --dataset | dataset to train on |str: 'cifar10', 'mnist', 'fmnist', 'svhn', 'monsters128', 'celeb128', 'celeb256', 'celeb1024' | 30 | |-n, --num_epochs | total number of epochs to run | int: default=250| 31 | |-z, --z_dim| latent dimensions | int: default=128| 32 | |-s, --seed| random state to use. for random: -1 | int: -1 , 0, 1, 2 ,....| 33 | |-v, --num_vae| number of iterations for vanilla vae training | int: default=0| 34 | |-l, --lr| learning rate | float: default=2e-4 | 35 | |-r, --beta_rec | beta coefficient for the reconstruction loss |float: default=1.0| 36 | |-k, --beta_kl| beta coefficient for the kl divergence | float: default=1.0| 37 | |-e, --beta_neg| beta coefficient for the kl divergence in the expELBO function | float: default=256.0| 38 | |-g, --gamma_r| coefficient for the reconstruction loss for fake data in the decoder | float: default=1e-8| 39 | |-b, --batch_size| batch size | int: default=32 | 40 | |-p, --pretrained | path to pretrained model, to continue training |str: default="None" | 41 | |-c, --device| device: -1 for cpu, 0 and up for specific cuda device |int: default=-1| 42 | |-f, --fid| if specified, FID will be calculated during training |bool: default=False| 43 | 44 | Examples: 45 | 46 | `python main.py --dataset cifar10 --device 0 --lr 2e-4 --num_epochs 250 --beta_kl 1.0 --beta_rec 1.0 --beta_neg 256 --z_dim 128 --batch_size 32` 47 | 48 | `python main.py --dataset mnist --device 0 --lr 2e-4 --num_epochs 200 --beta_kl 1.0 --beta_rec 1.0 --beta_neg 256 --z_dim 32 --batch_size 128` 49 | 50 | ## Datasets 51 | * CelebHQ: please follow [ALAE](https://github.com/podgorskiy/ALAE#datasets) instructions. 52 | * Digital-Monsters dataset: we curated a “Digital Monsters” dataset: ~4000 images of Pokemon, Digimon and Nexomon (yes, it’s a thing). We currently don't provide a download link for this dataset (not because we are bad people), but please contact us if you wish to create it yourself (a minimal folder-dataset sketch is shown below). 53 | 54 | On the left is a sample from the (very diverse) Digital-Monsters dataset (we used augmentations to enrich it), and on the right, samples generated from S-IntroVAE. 55 | We hope this does not give you nightmares. 56 | 57 |
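If you wish to build such a dataset yourself, below is a minimal sketch of a folder-backed PyTorch dataset with light augmentations; the directory path, image size and augmentation choices are placeholders, and the repository's actual dataset classes are in `datasets.py`.

```python
# Minimal sketch (not the repository's datasets.py): `root` and `image_size` are
# placeholders for a folder of custom images such as a Digital-Monsters-style set.
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageFolderDataset(Dataset):
    def __init__(self, root, image_size=128):
        exts = ('.png', '.jpg', '.jpeg')
        self.paths = [os.path.join(root, f) for f in sorted(os.listdir(root))
                      if f.lower().endswith(exts)]
        # light augmentations help enrich a small dataset
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.transform(img)
```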

58 | 59 | 60 |
61 | 62 | ## Recommended hyperparameters 63 | 64 | |Dataset | `beta_kl` | `beta_rec`| `beta_neg`|`z_dim`|`batch_size`| 65 | |------------|------|----|---|----|---| 66 | |CIFAR10 (`cifar10`)|1.0|1.0| 256|128| 32| 67 | |SVHN (`svhn`)|1.0|1.0| 256|128| 32| 68 | |MNIST (`mnist`)|1.0|1.0|256|32|128| 69 | |FashionMNIST (`fmnist`)|1.0|1.0|256|32|128| 70 | |Monsters (`monsters128`)|0.2|0.2|256|128|16| 71 | |CelebA-HQ (`celeb256`)|0.5|1.0|1024|256|8| 72 | 73 | 74 | ## What to expect 75 | 76 | * During the training, figures of samples and reconstructions are saved locally. 77 | * During training, statistics are printed (reconstruction error, KLD, expELBO). 78 | * At the end of each epoch, a summary of statistics will be printed. 79 | * Tips: 80 | * KL of fake/rec samples should be >= KL of real data. 81 | * It is usually better to choose `beta_kl` >= `beta_rec`. 82 | * FID calculation is not so fast, so turn it off if you don't care about it. 83 | 84 | ## Files and directories in the repository 85 | 86 | |File name | Purpose | 87 | |----------------------|------| 88 | |`main.py`| general purpose main application for training Soft-IntroVAE for image data| 89 | |`train_soft_intro_vae.py`| main training function, datasets and architectures| 90 | |`datasets.py`| classes for creating PyTorch dataset classes from images| 91 | |`metrics/fid.py`, `metrics/inception.py`| functions for FID calculation from datasets, using the pretrained Inception network| 92 | 93 | 94 | ## Tutorial 95 | * [Jupyter Notebook tutorial for image datasets](https://github.com/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_image_code_tutorial.ipynb) 96 | * [Open in Colab](https://colab.research.google.com/github/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_image_code_tutorial.ipynb) 97 | 98 | 99 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html 2 | 3 | import numbers 4 | import random 5 | import math 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class LinearTransformation(object): 11 | r"""Transforms node positions with a square transformation matrix computed 12 | offline. 13 | 14 | Args: 15 | matrix (Tensor): tensor with shape :math:`[D, D]` where :math:`D` 16 | corresponds to the dimensionality of node positions. 17 | """ 18 | 19 | def __init__(self, matrix): 20 | assert matrix.dim() == 2, ( 21 | 'Transformation matrix should be two-dimensional.') 22 | assert matrix.size(0) == matrix.size(1), ( 23 | 'Transformation matrix should be square. Got [{} x {}] rectangular' 24 | 'matrix.'.format(*matrix.size())) 25 | 26 | self.matrix = matrix 27 | 28 | def __call__(self, data): 29 | pos = data.pos.view(-1, 1) if data.pos.dim() == 1 else data.pos 30 | 31 | assert pos.size(-1) == self.matrix.size(-2), ( 32 | 'Node position matrix and transformation matrix have incompatible ' 33 | 'shape.') 34 | 35 | data.pos = torch.matmul(pos, self.matrix.to(pos.dtype).to(pos.device)) 36 | 37 | return data 38 | 39 | def __repr__(self): 40 | return '{}({})'.format(self.__class__.__name__, self.matrix.tolist()) 41 | 42 | 43 | class RandomRotate(object): 44 | r"""Rotates node positions around a specific axis by a randomly sampled 45 | factor within a given interval. 
46 | 47 | Args: 48 | degrees (tuple or float): Rotation interval from which the rotation 49 | angle is sampled. If :obj:`degrees` is a number instead of a 50 | tuple, the interval is given by :math:`[-\mathrm{degrees}, 51 | \mathrm{degrees}]`. 52 | axis (int, optional): The rotation axis. (default: :obj:`0`) 53 | """ 54 | 55 | def __init__(self, degrees, axis=0): 56 | if isinstance(degrees, numbers.Number): 57 | degrees = (-abs(degrees), abs(degrees)) 58 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 59 | self.degrees = degrees 60 | self.axis = axis 61 | 62 | def __call__(self, data): 63 | degree = math.pi * random.uniform(*self.degrees) / 180.0 64 | sin, cos = math.sin(degree), math.cos(degree) 65 | 66 | if data.pos.size(-1) == 2: 67 | matrix = [[cos, sin], [-sin, cos]] 68 | else: 69 | if self.axis == 0: 70 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] 71 | elif self.axis == 1: 72 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] 73 | else: 74 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] 75 | return LinearTransformation(torch.tensor(matrix))(data) 76 | 77 | def __repr__(self): 78 | return '{}({}, axis={})'.format(self.__class__.__name__, self.degrees, 79 | self.axis) 80 | 81 | 82 | # https://github.com/hxdengBerkeley/PointCNN.Pytorch/blob/master/provider.py 83 | 84 | def rotate_point_cloud(batch_data): 85 | """ Randomly rotate the point clouds to augument the dataset 86 | rotation is per shape based along up direction 87 | Input: 88 | BxNx3 array, original batch of point clouds 89 | Return: 90 | BxNx3 array, rotated batch of point clouds 91 | """ 92 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 93 | for k in range(batch_data.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_data[k, ...] 101 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 102 | return rotated_data 103 | 104 | 105 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 106 | """ Rotate the point cloud along up direction with certain angle. 107 | Input: 108 | BxNx3 array, original batch of point clouds 109 | Return: 110 | BxNx3 array, rotated batch of point clouds 111 | """ 112 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 113 | for k in range(batch_data.shape[0]): 114 | # rotation_angle = np.random.uniform() * 2 * np.pi 115 | cosval = np.cos(rotation_angle) 116 | sinval = np.sin(rotation_angle) 117 | rotation_matrix = np.array([[cosval, 0, sinval], 118 | [0, 1, 0], 119 | [-sinval, 0, cosval]]) 120 | shape_pc = batch_data[k, ...] 121 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 122 | return rotated_data 123 | 124 | 125 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 126 | """ Randomly jitter points. jittering is per point. 
127 | Input: 128 | BxNx3 array, original batch of point clouds 129 | Return: 130 | BxNx3 array, jittered batch of point clouds 131 | """ 132 | B, N, C = batch_data.shape 133 | assert (clip > 0) 134 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip) 135 | jittered_data += batch_data 136 | return jittered_data 137 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/utils/pcutil.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from numpy.linalg import norm 5 | # Don't delete this line, even if PyCharm says it's an unused import. 6 | # It is required for projection='3d' in add_subplot() 7 | from mpl_toolkits.mplot3d import Axes3D 8 | matplotlib.use('Agg') 9 | 10 | 11 | def rand_rotation_matrix(deflection=1.0, seed=None): 12 | """Creates a random rotation matrix. 13 | 14 | Args: 15 | deflection: the magnitude of the rotation. For 0, no rotation; for 1, 16 | completely random rotation. Small deflection => small 17 | perturbation. 18 | 19 | DOI: http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c 20 | http://blog.lostinmyterminal.com/python/2015/05/12/random-rotation-matrix.html 21 | """ 22 | if seed is not None: 23 | np.random.seed(seed) 24 | 25 | theta, phi, z = np.random.uniform(size=(3,)) 26 | 27 | theta = theta * 2.0 * deflection * np.pi # Rotation about the pole (Z). 28 | phi = phi * 2.0 * np.pi # For direction of pole deflection. 29 | z = z * 2.0 * deflection # For magnitude of pole deflection. 30 | 31 | # Compute a vector V used for distributing points over the sphere 32 | # via the reflection I - V Transpose(V). This formulation of V 33 | # will guarantee that if x[1] and x[2] are uniformly distributed, 34 | # the reflected points will be uniform on the sphere. Note that V 35 | # has length sqrt(2) to eliminate the 2 in the Householder matrix. 36 | 37 | r = np.sqrt(z) 38 | V = (np.sin(phi) * r, 39 | np.cos(phi) * r, 40 | np.sqrt(2.0 - z)) 41 | 42 | st = np.sin(theta) 43 | ct = np.cos(theta) 44 | 45 | R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) 46 | 47 | # Construct the rotation matrix ( V Transpose(V) - I ) R. 
48 | M = (np.outer(V, V) - np.eye(3)).dot(R) 49 | return M 50 | 51 | 52 | def add_gaussian_noise_to_pcloud(pcloud, mu=0, sigma=1): 53 | gnoise = np.random.normal(mu, sigma, pcloud.shape[0]) 54 | gnoise = np.tile(gnoise, (3, 1)).T 55 | pcloud += gnoise 56 | return pcloud 57 | 58 | 59 | def add_rotation_to_pcloud(pcloud): 60 | r_rotation = rand_rotation_matrix() 61 | 62 | if len(pcloud.shape) == 2: 63 | return pcloud.dot(r_rotation) 64 | else: 65 | return np.asarray([e.dot(r_rotation) for e in pcloud]) 66 | 67 | 68 | def apply_augmentations(batch, conf): 69 | if conf.gauss_augment is not None or conf.z_rotate: 70 | batch = batch.copy() 71 | 72 | if conf.gauss_augment is not None: 73 | mu = conf.gauss_augment['mu'] 74 | sigma = conf.gauss_augment['sigma'] 75 | batch += np.random.normal(mu, sigma, batch.shape) 76 | 77 | if conf.z_rotate: 78 | r_rotation = rand_rotation_matrix() 79 | r_rotation[0, 2] = 0 80 | r_rotation[2, 0] = 0 81 | r_rotation[1, 2] = 0 82 | r_rotation[2, 1] = 0 83 | r_rotation[2, 2] = 1 84 | batch = batch.dot(r_rotation) 85 | return batch 86 | 87 | 88 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False): 89 | """Returns the center coordinates of each cell of a 3D grid with 90 | resolution^3 cells, that is placed in the unit-cube. 91 | If clip_sphere it True it drops the "corner" cells that lie outside 92 | the unit-sphere. 93 | """ 94 | grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) 95 | spacing = 1.0 / float(resolution - 1) 96 | for i in range(resolution): 97 | for j in range(resolution): 98 | for k in range(resolution): 99 | grid[i, j, k, 0] = i * spacing - 0.5 100 | grid[i, j, k, 1] = j * spacing - 0.5 101 | grid[i, j, k, 2] = k * spacing - 0.5 102 | 103 | if clip_sphere: 104 | grid = grid.reshape(-1, 3) 105 | grid = grid[norm(grid, axis=1) <= 0.5] 106 | 107 | return grid, spacing 108 | 109 | 110 | def plot_3d_point_cloud(x, y, z, show=True, show_axis=True, in_u_sphere=False, 111 | marker='.', s=8, alpha=.8, figsize=(5, 5), elev=10, 112 | azim=240, axis=None, title=None, *args, **kwargs): 113 | # plt.switch_backend('agg') 114 | if axis is None: 115 | fig = plt.figure(figsize=figsize) 116 | ax = fig.add_subplot(111, projection='3d') 117 | else: 118 | ax = axis 119 | fig = axis 120 | 121 | if title is not None: 122 | plt.title(title) 123 | 124 | sc = ax.scatter(x, y, z, marker=marker, s=s, alpha=alpha, *args, **kwargs) 125 | ax.view_init(elev=elev, azim=azim) 126 | 127 | if in_u_sphere: 128 | ax.set_xlim3d(-0.5, 0.5) 129 | ax.set_ylim3d(-0.5, 0.5) 130 | ax.set_zlim3d(-0.5, 0.5) 131 | else: 132 | # Multiply with 0.7 to squeeze free-space. 
133 | miv = 0.7 * np.min([np.min(x), np.min(y), np.min(z)]) 134 | mav = 0.7 * np.max([np.max(x), np.max(y), np.max(z)]) 135 | ax.set_xlim(miv, mav) 136 | ax.set_ylim(miv, mav) 137 | ax.set_zlim(miv, mav) 138 | plt.tight_layout() 139 | 140 | if not show_axis: 141 | # plt.axis('off') 142 | ax.set_axis_off() 143 | 144 | if 'c' in kwargs: 145 | plt.colorbar(sc) 146 | 147 | if show: 148 | plt.show() 149 | 150 | return fig 151 | 152 | 153 | def transform_point_clouds(X, only_z_rotation=False, deflection=1.0): 154 | r_rotation = rand_rotation_matrix(deflection) 155 | if only_z_rotation: 156 | r_rotation[0, 2] = 0 157 | r_rotation[2, 0] = 0 158 | r_rotation[1, 2] = 0 159 | r_rotation[2, 1] = 0 160 | r_rotation[2, 2] = 1 161 | X = X.dot(r_rotation).astype(np.float32) 162 | return X 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # soft-intro-vae-pytorch 2 | 3 |

[CVPR 2021 Oral] Soft-IntroVAE: Analyzing and Improving Introspective Variational Autoencoders

Tal Daniel • Aviv Tamar

Official repository of the paper

CVPR 2021 Oral

Project Website • Video

Open In Colab

32 | 33 | # Soft-IntroVAE 34 | 35 | > **Soft-IntroVAE: Analyzing and Improving Introspective Variational Autoencoders**
36 | > Tal Daniel, Aviv Tamar
37 | > 38 | > **Abstract:** *The recently introduced introspective variational autoencoder (IntroVAE) exhibits outstanding image generations, and allows for amortized inference using an image encoder. The main idea in IntroVAE is to train a VAE adversarially, using the VAE encoder to discriminate between generated and real data samples. However, the original IntroVAE loss function relied on a particular hinge-loss formulation that is very hard to stabilize in practice, and its theoretical convergence analysis ignored important terms in the loss. In this work, we take a step towards better understanding of the IntroVAE model, its practical implementation, and its applications. We propose the Soft-IntroVAE, a modified IntroVAE that replaces the hinge-loss terms with a smooth exponential loss on generated samples. This change significantly improves training stability, and also enables theoretical analysis of the complete algorithm. Interestingly, we show that the IntroVAE converges to a distribution that minimizes a sum of KL distance from the data distribution and an entropy term. We discuss the implications of this result, and demonstrate that it induces competitive image generation and reconstruction. Finally, we describe two applications of Soft-IntroVAE to unsupervised image translation and out-of-distribution detection, and demonstrate compelling results.* 39 | 40 | ## Citation 41 | Daniel, Tal, and Aviv Tamar. "Soft-IntroVAE: Analyzing and Improving the Introspective Variational Autoencoder." arXiv preprint arXiv:2012.13253 (2020). 42 | > 43 | @InProceedings{Daniel_2021_CVPR, 44 | author = {Daniel, Tal and Tamar, Aviv}, 45 | title = {Soft-IntroVAE: Analyzing and Improving the Introspective Variational Autoencoder}, 46 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 47 | month = {June}, 48 | year = {2021}, 49 | pages = {4391-4400} 50 | } 51 | 52 |

Preprint on arXiv: 2012.13253

53 | 54 | 55 | - [soft-intro-vae-pytorch](#soft-intro-vae-pytorch) 56 | - [Soft-IntroVAE](#soft-introvae) 57 | * [Citation](#citation) 58 | * [Prerequisites](#prerequisites) 59 | * [Repository Organization](#repository-organization) 60 | * [Credits](#credits) 61 | 62 | 63 | ## Prerequisites 64 | 65 | * For your convenience, we provide an `environment.yml` file which installs the required packages in a `conda` environment named `torch`. 66 | * Use the terminal or an Anaconda Prompt and run the following command: `conda env create -f environment.yml`. 67 | * For Style-SoftIntroVAE, more packages are required, and we provide them in the `style_soft_intro_vae` directory. 68 | 69 | 70 | |Library | Version | 71 | |----------------------|----| 72 | |`Python`| `3.6 (Anaconda)`| 73 | |`torch`| >= `1.2` (tested on `1.7`)| 74 | |`torchvision`| >= `0.4`| 75 | |`matplotlib`| >= `2.2.2`| 76 | |`numpy`| >= `1.17`| 77 | |`opencv`| >= `3.4.2`| 78 | |`tqdm`| >= `4.36.1`| 79 | |`scipy`| >= `1.3.1`| 80 | 81 | 82 | 83 | ## Repository Organization 84 | 85 | |File name | Content | 86 | |----------------------|------| 87 | |`/soft_intro_vae`| directory containing the implementation for image data| 88 | |`/soft_intro_vae_2d`| directory containing implementations for 2D datasets| 89 | |`/soft_intro_vae_3d`| directory containing implementations for 3D point cloud data| 90 | |`/soft_intro_vae_bootstrap`| directory containing the implementation for image data using bootstrapping (using a target decoder)| 91 | |`/style_soft_intro_vae`| directory containing the implementation for image data using ALAE's style-based architecture| 92 | |`/soft_intro_vae_tutorials`| directory containing Jupyter Notebook tutorials for the various types of Soft-IntroVAE| 93 | 94 | ## Related Projects 95 | 96 | * March 2022: `augmentation-enhanced-soft-intro-vae` - GitHub - using differentiable augmentations to improve image generation FID score. 97 | 98 | 99 | ## Credits 100 | * Adversarial Latent Autoencoders, Pidhorskyi et al., CVPR 2020 - [Code](https://github.com/podgorskiy/ALAE), [Paper](https://arxiv.org/abs/2004.04467). 101 | * FID is calculated natively in PyTorch using Seitzer's implementation - [Code](https://github.com/mseitzer/pytorch-fid) 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/metrics/jsd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from numpy.linalg import norm 4 | from scipy.stats import entropy 5 | from sklearn.neighbors import NearestNeighbors 6 | 7 | 8 | __all__ = ['js_divercence_between_pc', 'jsd_between_point_cloud_sets'] 9 | 10 | 11 | # 12 | # Compute JS divergence 13 | # 14 | 15 | 16 | def js_divercence_between_pc(pc1: torch.Tensor, pc2: torch.Tensor, 17 | voxels: int = 64) -> float: 18 | """Method for computing JSD from two sets of point clouds.""" 19 | pc1_ = _pc_to_voxel_distribution(pc1, voxels) 20 | pc2_ = _pc_to_voxel_distribution(pc2, voxels) 21 | jsd = _js_divergence(pc1_, pc2_) 22 | return jsd 23 | 24 | 25 | def _js_divergence(P, Q): 26 | # Ensure probabilities. 27 | P_ = P / np.sum(P) 28 | Q_ = Q / np.sum(Q) 29 | 30 | # Calculate JSD using scipy.stats.entropy() 31 | e1 = entropy(P_, base=2) 32 | e2 = entropy(Q_, base=2) 33 | e_sum = entropy((P_ + Q_) / 2.0, base=2) 34 | res1 = e_sum - ((e1 + e2) / 2.0) 35 | 36 | # Calculate JS-Div using manually defined KL divergence.
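    # Both computations rely on the identity JSD(P, Q) = H((P + Q) / 2) - (H(P) + H(Q)) / 2,
    # where H is the Shannon entropy (in bits here, via base=2). `res1` above evaluates it
    # directly with scipy's entropy; the commented-out `_jsdiv` check below re-derives the
    # same quantity from two KL terms against the mixture M = (P + Q) / 2.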
37 | # res2 = _jsdiv(P_, Q_) 38 | # 39 | # if not np.allclose(res1, res2, atol=10e-5, rtol=0): 40 | # warnings.warn('Numerical values of two JSD methods don\'t agree.') 41 | 42 | return res1 43 | 44 | 45 | def _jsdiv(P, Q): 46 | """Another way of computing JSD to check numerical stability.""" 47 | def _kldiv(A, B): 48 | a = A.copy() 49 | b = B.copy() 50 | idx = np.logical_and(a > 0, b > 0) 51 | a = a[idx] 52 | b = b[idx] 53 | return np.sum([v for v in a * np.log2(a / b)]) 54 | 55 | P_ = P / np.sum(P) 56 | Q_ = Q / np.sum(Q) 57 | 58 | M = 0.5 * (P_ + Q_) 59 | 60 | return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) 61 | 62 | 63 | def _pc_to_voxel_distribution(pc: torch.Tensor, n_voxels: int = 64) -> np.ndarray: 64 | pc_ = pc.clamp(-0.5, 0.4999) + 0.5 65 | # Because points are in range [0, 1], simple multiplication will bin them. 66 | pc_ = (pc_ * n_voxels).int() 67 | pc_ = pc_[:, :, 0] * n_voxels ** 2 + pc_[:, :, 1] * n_voxels + pc_[:, :, 2] 68 | 69 | B = np.zeros(n_voxels**3, dtype=np.int32) 70 | values, amounts = np.unique(pc_, return_counts=True) 71 | B[values] = amounts 72 | return B 73 | 74 | 75 | # 76 | # Stanford way to calculate JSD 77 | # 78 | 79 | 80 | def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, voxels=28, 81 | in_unit_sphere=True): 82 | """Computes the JSD between two sets of point-clouds, as introduced in the 83 | paper ```Learning Representations And Generative Models For 3D Point 84 | Clouds```. 85 | Args: 86 | sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. 87 | ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. 88 | voxels: (int) grid-resolution. Affects granularity of measurements. 89 | """ 90 | sample_grid_var = _entropy_of_occupancy_grid(sample_pcs, voxels, 91 | in_unit_sphere)[1] 92 | ref_grid_var = _entropy_of_occupancy_grid(ref_pcs, voxels, 93 | in_unit_sphere)[1] 94 | return _js_divergence(sample_grid_var, ref_grid_var) 95 | 96 | 97 | def _entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False): 98 | """Given a collection of point-clouds, estimate the entropy of the random 99 | variables corresponding to occupancy-grid activation patterns. 100 | Inputs: 101 | pclouds: (numpy array) #point-clouds x points per point-cloud x 3 102 | grid_resolution (int) size of occupancy grid that will be used. 
103 | """ 104 | pclouds = pclouds.cpu().numpy() 105 | epsilon = 10e-4 106 | bound = 0.5 + epsilon 107 | # if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: 108 | # warnings.warn('Point-clouds are not in unit cube.') 109 | # 110 | # if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: 111 | # warnings.warn('Point-clouds are not in unit sphere.') 112 | 113 | grid_coordinates, _ = _unit_cube_grid_point_cloud(grid_resolution, in_sphere) 114 | grid_coordinates = grid_coordinates.reshape(-1, 3) 115 | grid_counters = np.zeros(len(grid_coordinates)) 116 | grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) 117 | nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) 118 | 119 | for pc in pclouds: 120 | _, indices = nn.kneighbors(pc) 121 | indices = np.squeeze(indices) 122 | for i in indices: 123 | grid_counters[i] += 1 124 | indices = np.unique(indices) 125 | for i in indices: 126 | grid_bernoulli_rvars[i] += 1 127 | 128 | acc_entropy = 0.0 129 | n = float(len(pclouds)) 130 | for g in grid_bernoulli_rvars: 131 | p = 0.0 132 | if g > 0: 133 | p = float(g) / n 134 | acc_entropy += entropy([p, 1.0 - p]) 135 | 136 | return acc_entropy / len(grid_counters), grid_counters 137 | 138 | 139 | def _unit_cube_grid_point_cloud(resolution, clip_sphere=False): 140 | """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells, 141 | that is placed in the unit-cube. 142 | If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere. 143 | """ 144 | grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) 145 | spacing = 1.0 / float(resolution - 1) 146 | for i in range(resolution): 147 | for j in range(resolution): 148 | for k in range(resolution): 149 | grid[i, j, k, 0] = i * spacing - 0.5 150 | grid[i, j, k, 1] = j * spacing - 0.5 151 | grid[i, j, k, 2] = k * spacing - 0.5 152 | 153 | if clip_sphere: 154 | grid = grid.reshape(-1, 3) 155 | grid = grid[norm(grid, axis=1) <= 0.5] 156 | 157 | return grid, spacing 158 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/README.md: -------------------------------------------------------------------------------- 1 | # 3d-soft-intro-vae-pytorch 2 | 3 | Implementation of 3D Soft-IntroVAE for point clouds. 4 | 5 | This codes builds upon the code base of [3D-AAE](https://github.com/MaciejZamorski/3d-AAE) 6 | from the paper "Adversarial Autoencoders for Compact Representations of 3D Point Clouds" 7 | by Maciej Zamorski, Maciej Zięba, Piotr Klukowski, Rafał Nowak, Karol Kurach, Wojciech Stokowiec, and Tomasz Trzciński 8 |

16 | 17 | - [3d-soft-intro-vae-pytorch](#3d-soft-intro-vae-pytorch) 18 | * [Requirements](#requirements) 19 | * [Training](#training) 20 | * [Testing](#testing) 21 | * [Rendering](#rendering) 22 | * [Datasets](#datasets) 23 | * [Pretrained models](#pretrained-models) 24 | * [Recommended hyperparameters](#recommended-hyperparameters) 25 | * [What to expect](#what-to-expect) 26 | * [Files and directories in the repository](#files-and-directories-in-the-repository) 27 | * [Credits](#credits) 28 | 29 | ## Requirements 30 | 31 | * The required packages are located in the `requirements.txt` file, nothing special. 32 | * `pip install -r requirements.txt` 33 | * We provide an `environment.yml` file for `conda` (at the repo's root), which installs all that is needed to run the files. 34 | * `conda env create -f environment.yml` 35 | 36 | ## Training 37 | 38 | To run training: 39 | 40 | * Modify the hyperparameters in `/config/soft_intro_vae_hp.json` 41 | 42 | * Run: `python train_soft_intro_vae_3d.py` 43 | 44 | ## Testing 45 | 46 | * To test the generations from a trained model in terms of JSD, modify `path_to_weights` and `config_path` in `test_model.py` and run it: `python test_model.py`. 47 | * To produce reconstructed and generated point clouds in a form of NumPy array to be used with validation methods from ["Learning Representations and Generative Models For 3D Point Clouds" repository](https://github.com/optas/latent_3d_points/blob/master/notebooks/compute_evaluation_metrics.ipynb) 48 | modify `path_to_weights` and `config_path` in `evaluation/generate_data_for_metrics.py` and run: `python evaluation/generate_data_for_metrics.py` 49 | 50 | ## Rendering 51 | * To render beautiful point clouds from a trained model, we provide a script that uses Mitsuba 2 renderer. Instructions can be found in `/render`. 52 | 53 | ## Datasets 54 | * We currently support [ShapeNet](https://shapenet.org/), which will be downloaded automatically on first run. 55 | 56 | ## Pretrained models 57 | |Dataset/Class | Filename | Validation Sample JSD| Links| 58 | |------------|------|----|---| 59 | |ShapeNet-Chair|`chair_01618_jsd_0.0175.pth` |0.0175|[MEGA.co.nz](https://mega.nz/file/RJ8mmIjL#DKuvWImRZdzKL_JN9JwwsvZw3F4Iv0i5g0qaLiSL84Q), [Mediafire](http://www.mediafire.com/file/i9ozb2yv4bv1i76/chair_01618_jsd_0.0175.pth/file) | 60 | |ShapeNet-Table|`table_01592_jsd_0.0143.pth` |0.0143 | [MEGA.co.nz](https://mega.nz/file/ZQ8GjSQB#ctGaJXgvUsgaMYQm1R3bfMUzKld7nGO-oUAGGA9EOX8), [Mediafire](http://www.mediafire.com/file/hvygeusesaa58y2/table_01592_jsd_0.0143.pth/file) | 61 | |ShapeNet-Car|`car_01344_jsd_0.0113.pth` |0.0113 | [MEGA.co.nz](https://mega.nz/file/kZ0AQQQL#hecHNlPyh0ww3_RZOvrXCE48yr5ZmfL3RZ01MSz2NwU), [Mediafire](http://www.mediafire.com/file/ja1p9wjnc58uab4/car_01344_jsd_0.0113.pth/file) | 62 | |ShapeNet-Airplane|`airplane_00536_jsd_0.0191.pth` |0.0191 | [MEGA.co.nz](https://mega.nz/file/xA9g0ajA#jyhBgPQC4VxLwgDPfk-xo_xAbCUQofzVz9jdP0OUvDc), [Mediafire](http://www.mediafire.com/file/79ett5dhhwm2yl8/airplane_00536_jsd_0.0191.pth/file) | 63 | 64 | 65 | ## Recommended hyperparameters 66 | 67 | |Dataset | `beta_kl` | `beta_rec`| `beta_neg`|`z_dim`| 68 | |------------|------|----|---|----| 69 | |ShapeNet|0.2|1.0| 20.0|128| 70 | 71 | 72 | ## What to expect 73 | 74 | * During the training, figures of samples and reconstructions are saved locally. 75 | * First row - real, second row - reconstructions, third row - random samples 76 | * During training, statistics are printed (reconstruction error, KLD, expELBO). 
77 | * Checkpoint is saved every epoch, and JSD is calculated on the validation split. 78 | * Tips: 79 | * KL of fake/rec samples should be > KL of real data (by a fair margin). 80 | * Currently, this code only supports the Chamfer Distance loss, which requires a high `beta_rec`. 81 | * Since the practice is to train on a single class, it is usually better to use a narrower Gaussian for the prior (e.g., N(0, 0.2)). 82 | 83 | 84 | ## Files and directories in the repository 85 | 86 | |File name | Purpose | 87 | |----------------------|------| 88 | |`train_soft_intro_vae_3d.py`| main training function.| 89 | |`generate_for_rendering.py`| generate samples (+interpolation) from a trained model for rendering with Mitsuba.| 90 | |`test_model.py`| test sampling JSD of a trained model (w.r.t. the test split).| 91 | |`config/soft_intro_vae_hp.json`| contains the hyperparameters of the model.| 92 | |`/datasets`| directory containing various dataset files (e.g., data loader for ShapeNet).| 93 | |`/evaluations`| directory containing evaluation scripts (e.g., generating data for evaluation metrics).| 94 | |`/losses/chamfer_loss.py`| PyTorch implementation of the Chamfer distance loss function.| 95 | |`/metrics/jsd.py`| functions to measure JSD between point clouds.| 96 | |`/models/vae.py`| VAE module and architecture.| 97 | |`/render`| directory containing scripts and instructions to render point clouds with the Mitsuba 2 renderer.| 98 | |`/utils`| various utility functions to process the data.| 99 | 100 | ## Credits 101 | * Adversarial Autoencoders for Compact Representations of 3D Point Clouds, Zamorski et al., 2018 - [Code](https://github.com/MaciejZamorski/3d-AAE), [Paper](https://arxiv.org/abs/1811.07605). 102 | 103 | -------------------------------------------------------------------------------- /soft_intro_vae_bootstrap/README.md: -------------------------------------------------------------------------------- 1 | # soft-intro-vae-pytorch-images-bootstrap 2 | 3 | Implementation of Soft-IntroVAE for image data, using "bootstrapping". This version was not used in the paper. 4 | 5 | A step-by-step tutorial can be found in [Soft-IntroVAE Jupyter Notebook Tutorials](https://github.com/taldatech/soft-intro-vae-pytorch/tree/main/soft_intro_vae_tutorial). 6 | 7 |


10 | 11 | - [soft-intro-vae-pytorch-images-bootstrap](#soft-intro-vae-pytorch-images-bootstrap) 12 | * [What is different?](#what-is-different-) 13 | * [Training](#training) 14 | * [Datasets](#datasets) 15 | * [Recommended hyperparameters](#recommended-hyperparameters) 16 | * [What to expect](#what-to-expect) 17 | * [Files and directories in the repository](#files-and-directories-in-the-repository) 18 | * [Tutorial](#tutorial) 19 | 20 | ## What is different? 21 | 22 | The idea is to use a `target` decoder to update both the encoder and the decoder. This makes 23 | the optimization a bit simpler, and allows more flexible values for `gamma_r` (e.g., 1.0 instead of 1e-8), 24 | the coefficient of the reconstruction error for the fake data in the decoder. 25 | Implementation-wise, the `target` decoder is not trained, but uses the weights of the original 26 | decoder, and it lags one epoch behind (we simply copy the weights of the decoder to the target decoder every epoch). 27 | 28 | * In `train_soft_intro_vae_bootstrap.py`: 29 | * In the `SoftIntroVAE` class, another decoder is added (`self.target_decoder`); the `forward()` function uses the target decoder by default. 30 | * In the decoder training step: no need to `detach()` the reconstructions of fake data. 31 | * At the end of each epoch, weights are copied from `model.decoder` to `model.target_decoder` (a minimal sketch of this mechanism appears below, after the Datasets section). 32 | 33 | ## Training 34 | 35 | `main.py --help` 36 | 37 | 38 | You should use the `main.py` file with the following arguments: 39 | 40 | |Argument | Description |Legal Values | 41 | |-------------------------|---------------------------------------------|-------------| 42 | |-h, --help | shows argument descriptions | | 43 | |-d, --dataset | dataset to train on |str: 'cifar10', 'mnist', 'fmnist', 'svhn', 'monsters128', 'celeb128', 'celeb256', 'celeb1024' | 44 | |-n, --num_epochs | total number of epochs to run | int: default=250| 45 | |-z, --z_dim| latent dimensions | int: default=128| 46 | |-s, --seed| random state to use; for random: -1 | int: -1, 0, 1, 2, ...| 47 | |-v, --num_vae| number of iterations for vanilla VAE training | int: default=0| 48 | |-l, --lr| learning rate | float: default=2e-4 | 49 | |-r, --beta_rec | beta coefficient for the reconstruction loss |float: default=1.0| 50 | |-k, --beta_kl| beta coefficient for the kl divergence | float: default=1.0| 51 | |-e, --beta_neg| beta coefficient for the kl divergence in the expELBO function | float: default=256.0| 52 | |-g, --gamma_r| coefficient for the reconstruction loss for fake data in the decoder | float: default=1e-8| 53 | |-b, --batch_size| batch size | int: default=32 | 54 | |-p, --pretrained | path to pretrained model, to continue training |str: default="None" | 55 | |-c, --device| device: -1 for cpu, 0 and up for specific cuda device |int: default=-1| 56 | |-f, --fid| if specified, FID will be calculated during training |bool: default=False| 57 | |-o, --freq| epochs between copying weights from decoder to target decoder |int: default=1| 58 | 59 | Examples: 60 | 61 | `python main.py --dataset cifar10 --device 0 --lr 2e-4 --num_epochs 250 --beta_kl 1.0 --beta_rec 1.0 --beta_neg 256 --z_dim 128 --batch_size 32` 62 | 63 | `python main.py --dataset mnist --device 0 --lr 2e-4 --num_epochs 200 --beta_kl 1.0 --beta_rec 1.0 --beta_neg 256 --z_dim 32 --batch_size 128` 64 | 65 | ## Datasets 66 | * CelebHQ: please follow [ALAE](https://github.com/podgorskiy/ALAE#datasets) instructions.
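As referenced in the "What is different?" section above, here is a minimal sketch of the target-decoder bootstrapping idea. This is not the repository's exact code: the encoder/decoder internals are assumed to be passed in, and the `reparameterize` helper shown here is a standard stand-in; only the parts relevant to the frozen target decoder and the weight copy are illustrated.

```python
import copy
import torch
import torch.nn as nn


def reparameterize(mu, logvar):
    # standard VAE reparameterization trick (assumed helper)
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)


class SoftIntroVAE(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder  # assumed to return (mu, logvar)
        self.decoder = decoder  # the decoder that is actually optimized
        # frozen copy of the decoder; it lags one epoch behind and is never optimized directly
        self.target_decoder = copy.deepcopy(decoder)
        for p in self.target_decoder.parameters():
            p.requires_grad = False

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        # reconstructions go through the target decoder by default
        return self.target_decoder(z), mu, logvar


# every `--freq` epochs (default: 1), sync the target decoder with the trained decoder:
# model.target_decoder.load_state_dict(model.decoder.state_dict())
```

As the section above notes, with this setup the decoder training step does not need to `detach()` the fake reconstructions, and larger `gamma_r` values (e.g., 1.0) become practical.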
67 | 68 | ## Recommended hyperparameters 69 | 70 | |Dataset | `beta_kl` | `beta_rec`| `beta_neg`|`z_dim`|`batch_size`| 71 | |------------|------|----|---|----|---| 72 | |CIFAR10 (`cifar10`)|1.0|1.0| 256|128| 32| 73 | |SVHN (`svhn`)|1.0|1.0| 256|128| 32| 74 | |MNIST (`mnist`)|1.0|1.0|256|32|128| 75 | |FashionMNIST (`fmnist`)|1.0|1.0|256|32|128| 76 | |Monsters (`monsters128`)|0.2|0.2|256|128|16| 77 | |CelebA-HQ (`celeb256`)|0.5|1.0|1024|256|8| 78 | 79 | 80 | ## What to expect 81 | 82 | * During the training, figures of samples and reconstructions are saved locally. 83 | * During training, statistics are printed (reconstruction error, KLD, expELBO). 84 | * At the end of each epoch, a summary of statistics will be printed. 85 | * Tips: 86 | * KL of fake/rec samples should be >= KL of real data. 87 | * It is usually better to choose `beta_kl` >= `beta_rec`. 88 | * FID calculation is not so fast, so turn it off if you don't care about it. 89 | * `gamma_r` can be set to values such as 0.5, 1.0, and etc... 90 | 91 | ## Files and directories in the repository 92 | 93 | |File name | Purpose | 94 | |----------------------|------| 95 | |`main.py`| general purpose main application for training Soft-IntroVAE for image data| 96 | |`train_soft_intro_vae_bootstrap.py`| main training function, datasets and architectures| 97 | |`datasets.py`| classes for creating PyTorch dataset classes from images| 98 | |`metrics/fid.py`, `metrics/inception.py`| functions for FID calculation from datasets, using the pretrained Inception network| 99 | 100 | 101 | ## Tutorial 102 | * [Jupyter Notebook tutorial for image datasets with bootstrapping](https://github.com/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_bootstrap_code_tutorial.ipynb) 103 | * [Open in Colab](https://colab.research.google.com/github/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_bootstrap_code_tutorial.ipynb) 104 | 105 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/datasets/shapenet.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import shutil 3 | from os import listdir, makedirs, remove 4 | from os.path import exists, join 5 | from zipfile import ZipFile 6 | 7 | import pandas as pd 8 | from torch.utils.data import Dataset 9 | 10 | from utils.plyfile import load_ply 11 | 12 | synth_id_to_category = { 13 | '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', 14 | '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', 15 | '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', 16 | '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', 17 | '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', 18 | '02954340': 'cap', '02958343': 'car', '03001627': 'chair', 19 | '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', 20 | '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', 21 | '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', 22 | '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', 23 | '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', 24 | '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', 25 | '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', 26 | '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', 27 | '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', 28 | '03991062': 'pot', 
'04004475': 'printer', '04074963': 'remote_control', 29 | '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', 30 | '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', 31 | '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' 32 | } 33 | 34 | category_to_synth_id = {v: k for k, v in synth_id_to_category.items()} 35 | synth_id_to_number = {k: i for i, k in enumerate(synth_id_to_category.keys())} 36 | 37 | 38 | class ShapeNetDataset(Dataset): 39 | def __init__(self, root_dir='/home/datasets/shapenet', classes=[], 40 | transform=None, split='train'): 41 | """ 42 | Args: 43 | root_dir (string): Directory with all the point clouds. 44 | transform (callable, optional): Optional transform to be applied 45 | on a sample. 46 | """ 47 | self.root_dir = root_dir 48 | self.transform = transform 49 | self.split = split 50 | 51 | self._maybe_download_data() 52 | 53 | pc_df = self._get_names() 54 | if classes: 55 | if classes[0] not in synth_id_to_category.keys(): 56 | classes = [category_to_synth_id[c] for c in classes] 57 | pc_df = pc_df[pc_df.category.isin(classes)].reset_index(drop=True) 58 | else: 59 | classes = synth_id_to_category.keys() 60 | 61 | self.point_clouds_names_train = pd.concat([pc_df[pc_df['category'] == c][:int(0.85*len(pc_df[pc_df['category'] == c]))].reset_index(drop=True) for c in classes]) 62 | self.point_clouds_names_valid = pd.concat([pc_df[pc_df['category'] == c][int(0.85*len(pc_df[pc_df['category'] == c])):int(0.9*len(pc_df[pc_df['category'] == c]))].reset_index(drop=True) for c in classes]) 63 | self.point_clouds_names_test = pd.concat([pc_df[pc_df['category'] == c][int(0.9*len(pc_df[pc_df['category'] == c])):].reset_index(drop=True) for c in classes]) 64 | 65 | def __len__(self): 66 | if self.split == 'train': 67 | pc_names = self.point_clouds_names_train 68 | elif self.split == 'valid': 69 | pc_names = self.point_clouds_names_valid 70 | elif self.split == 'test': 71 | pc_names = self.point_clouds_names_test 72 | else: 73 | raise ValueError('Invalid split. Should be train, valid or test.') 74 | return len(pc_names) 75 | 76 | def __getitem__(self, idx): 77 | if self.split == 'train': 78 | pc_names = self.point_clouds_names_train 79 | elif self.split == 'valid': 80 | pc_names = self.point_clouds_names_valid 81 | elif self.split == 'test': 82 | pc_names = self.point_clouds_names_test 83 | else: 84 | raise ValueError('Invalid split. Should be train, valid or test.') 85 | 86 | pc_category, pc_filename = pc_names.iloc[idx].values 87 | 88 | pc_filepath = join(self.root_dir, pc_category, pc_filename) 89 | sample = load_ply(pc_filepath) 90 | 91 | if self.transform: 92 | sample = self.transform(sample) 93 | 94 | return sample, synth_id_to_number[pc_category] 95 | 96 | def _get_names(self) -> pd.DataFrame: 97 | filenames = [] 98 | for category_id in synth_id_to_category.keys(): 99 | for f in listdir(join(self.root_dir, category_id)): 100 | if f not in ['.DS_Store']: 101 | filenames.append((category_id, f)) 102 | return pd.DataFrame(filenames, columns=['category', 'filename']) 103 | 104 | def _maybe_download_data(self): 105 | if exists(self.root_dir): 106 | return 107 | 108 | print(f'ShapeNet doesn\'t exist in root directory {self.root_dir}. 
' 109 | f'Downloading...') 110 | makedirs(self.root_dir) 111 | 112 | url = 'https://www.dropbox.com/s/vmsdrae6x5xws1v/shape_net_core_uniform_samples_2048.zip?dl=1' 113 | 114 | data = urllib.request.urlopen(url) 115 | filename = url.rpartition('/')[2][:-5] 116 | file_path = join(self.root_dir, filename) 117 | with open(file_path, mode='wb') as f: 118 | d = data.read() 119 | f.write(d) 120 | 121 | print('Extracting...') 122 | with ZipFile(file_path, mode='r') as zip_f: 123 | zip_f.extractall(self.root_dir) 124 | 125 | remove(file_path) 126 | 127 | extracted_dir = join(self.root_dir, 128 | 'shape_net_core_uniform_samples_2048') 129 | for d in listdir(extracted_dir): 130 | shutil.move(src=join(extracted_dir, d), 131 | dst=self.root_dir) 132 | 133 | shutil.rmtree(extracted_dir) 134 | 135 | -------------------------------------------------------------------------------- /soft_intro_vae/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | from os import listdir 5 | from os.path import join 6 | from PIL import Image, ImageOps 7 | import random 8 | import torchvision.transforms as transforms 9 | import os 10 | 11 | 12 | def load_image(file_path, input_height=128, input_width=None, output_height=128, output_width=None, 13 | crop_height=None, crop_width=None, is_random_crop=True, is_mirror=True, is_gray=False): 14 | if input_width is None: 15 | input_width = input_height 16 | if output_width is None: 17 | output_width = output_height 18 | if crop_width is None: 19 | crop_width = crop_height 20 | 21 | img = Image.open(file_path) 22 | if is_gray is False and img.mode is not 'RGB': 23 | img = img.convert('RGB') 24 | if is_gray and img.mode is not 'L': 25 | img = img.convert('L') 26 | 27 | if is_mirror and random.randint(0, 1) is 0: 28 | img = ImageOps.mirror(img) 29 | 30 | if input_height is not None: 31 | img = img.resize((input_width, input_height), Image.BICUBIC) 32 | 33 | if crop_height is not None: 34 | [w, h] = img.size 35 | if is_random_crop: 36 | # print([w,cropSize]) 37 | cx1 = random.randint(0, w - crop_width) 38 | cx2 = w - crop_width - cx1 39 | cy1 = random.randint(0, h - crop_height) 40 | cy2 = h - crop_height - cy1 41 | else: 42 | cx2 = cx1 = int(round((w - crop_width) / 2.)) 43 | cy2 = cy1 = int(round((h - crop_height) / 2.)) 44 | img = ImageOps.crop(img, (cx1, cy1, cx2, cy2)) 45 | 46 | img = img.resize((output_width, output_height), Image.BICUBIC) 47 | return img 48 | 49 | 50 | class ImageDatasetFromFile(data.Dataset): 51 | def __init__(self, image_list, root_path, 52 | input_height=128, input_width=None, output_height=128, output_width=None, 53 | crop_height=None, crop_width=None, is_random_crop=False, is_mirror=True, is_gray=False): 54 | super(ImageDatasetFromFile, self).__init__() 55 | 56 | self.image_filenames = image_list 57 | self.is_random_crop = is_random_crop 58 | self.is_mirror = is_mirror 59 | self.input_height = input_height 60 | self.input_width = input_width 61 | self.output_height = output_height 62 | self.output_width = output_width 63 | self.root_path = root_path 64 | self.crop_height = crop_height 65 | self.crop_width = crop_width 66 | self.is_gray = is_gray 67 | 68 | self.input_transform = transforms.Compose([ 69 | transforms.ToTensor() 70 | ]) 71 | 72 | def __getitem__(self, index): 73 | img = load_image(join(self.root_path, self.image_filenames[index]), 74 | self.input_height, self.input_width, self.output_height, self.output_width, 75 | self.crop_height, 
self.crop_width, self.is_random_crop, self.is_mirror, self.is_gray) 76 | 77 | img = self.input_transform(img) 78 | 79 | return img 80 | 81 | def __len__(self): 82 | return len(self.image_filenames) 83 | 84 | 85 | def list_images_in_dir(path): 86 | valid_images = [".jpg", ".gif", ".png"] 87 | img_list = [] 88 | for f in os.listdir(path): 89 | ext = os.path.splitext(f)[1] 90 | if ext.lower() not in valid_images: 91 | continue 92 | img_list.append(os.path.join(path, f)) 93 | return img_list 94 | 95 | 96 | class DigitalMonstersDataset(data.Dataset): 97 | def __init__(self, root_path, 98 | input_height=None, input_width=None, output_height=128, output_width=None, is_gray=False, pokemon=True, 99 | digimon=True, nexomon=True): 100 | super(DigitalMonstersDataset, self).__init__() 101 | image_list = [] 102 | if pokemon: 103 | print("collecting pokemon...") 104 | image_list.extend(list_images_in_dir(os.path.join(root_path, 'pokemon'))) 105 | if digimon: 106 | print("collecting digimon...") 107 | image_list.extend(list_images_in_dir(os.path.join(root_path, 'digimon', '200'))) 108 | if nexomon: 109 | print("collecting nexomon...") 110 | image_list.extend(list_images_in_dir(os.path.join(root_path, 'nexomon'))) 111 | print(f'total images: {len(image_list)}') 112 | 113 | self.image_filenames = image_list 114 | self.input_height = input_height 115 | self.input_width = input_width 116 | self.output_height = output_height 117 | self.output_width = output_width 118 | self.root_path = root_path 119 | self.is_gray = is_gray 120 | 121 | # self.input_transform = transforms.Compose([ 122 | # transforms.RandomAffine(0, translate=(5 / output_height, 5 / output_height), fillcolor=(255, 255, 255)), 123 | # transforms.ColorJitter(hue=0.5), 124 | # transforms.RandomHorizontalFlip(p=0.5), 125 | # transforms.ToTensor(), 126 | # transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,)) 127 | # ]) 128 | 129 | self.input_transform = transforms.Compose([ 130 | transforms.RandomAffine(0, translate=(5 / output_height, 5 / output_height), fillcolor=(255, 255, 255)), 131 | transforms.ColorJitter(hue=0.5), 132 | transforms.RandomHorizontalFlip(p=0.5), 133 | transforms.ToTensor() 134 | ]) 135 | 136 | # self.input_transform = transforms.Compose([ 137 | # transforms.ToTensor() 138 | # ]) 139 | 140 | def __getitem__(self, index): 141 | img = load_image(self.image_filenames[index], input_height=self.input_height, input_width=self.input_width, 142 | output_height=self.output_height, output_width=self.output_width, 143 | crop_height=None, crop_width=None, is_random_crop=False, is_mirror=False, is_gray=False) 144 | img = self.input_transform(img) 145 | 146 | return img 147 | 148 | def __len__(self): 149 | return len(self.image_filenames) 150 | 151 | 152 | if __name__ == "__main__": 153 | ds = DigitalMonstersDataset(root_path='./pokemon_ds') 154 | print(ds) 155 | -------------------------------------------------------------------------------- /soft_intro_vae_bootstrap/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | from os import listdir 5 | from os.path import join 6 | from PIL import Image, ImageOps 7 | import random 8 | import torchvision.transforms as transforms 9 | import os 10 | 11 | 12 | def load_image(file_path, input_height=128, input_width=None, output_height=128, output_width=None, 13 | crop_height=None, crop_width=None, is_random_crop=True, is_mirror=True, is_gray=False): 14 | if input_width is None: 15 | input_width = 
input_height 16 | if output_width is None: 17 | output_width = output_height 18 | if crop_width is None: 19 | crop_width = crop_height 20 | 21 | img = Image.open(file_path) 22 | if is_gray is False and img.mode is not 'RGB': 23 | img = img.convert('RGB') 24 | if is_gray and img.mode is not 'L': 25 | img = img.convert('L') 26 | 27 | if is_mirror and random.randint(0, 1) is 0: 28 | img = ImageOps.mirror(img) 29 | 30 | if input_height is not None: 31 | img = img.resize((input_width, input_height), Image.BICUBIC) 32 | 33 | if crop_height is not None: 34 | [w, h] = img.size 35 | if is_random_crop: 36 | # print([w,cropSize]) 37 | cx1 = random.randint(0, w - crop_width) 38 | cx2 = w - crop_width - cx1 39 | cy1 = random.randint(0, h - crop_height) 40 | cy2 = h - crop_height - cy1 41 | else: 42 | cx2 = cx1 = int(round((w - crop_width) / 2.)) 43 | cy2 = cy1 = int(round((h - crop_height) / 2.)) 44 | img = ImageOps.crop(img, (cx1, cy1, cx2, cy2)) 45 | 46 | img = img.resize((output_width, output_height), Image.BICUBIC) 47 | return img 48 | 49 | 50 | class ImageDatasetFromFile(data.Dataset): 51 | def __init__(self, image_list, root_path, 52 | input_height=128, input_width=None, output_height=128, output_width=None, 53 | crop_height=None, crop_width=None, is_random_crop=False, is_mirror=True, is_gray=False): 54 | super(ImageDatasetFromFile, self).__init__() 55 | 56 | self.image_filenames = image_list 57 | self.is_random_crop = is_random_crop 58 | self.is_mirror = is_mirror 59 | self.input_height = input_height 60 | self.input_width = input_width 61 | self.output_height = output_height 62 | self.output_width = output_width 63 | self.root_path = root_path 64 | self.crop_height = crop_height 65 | self.crop_width = crop_width 66 | self.is_gray = is_gray 67 | 68 | self.input_transform = transforms.Compose([ 69 | transforms.ToTensor() 70 | ]) 71 | 72 | def __getitem__(self, index): 73 | img = load_image(join(self.root_path, self.image_filenames[index]), 74 | self.input_height, self.input_width, self.output_height, self.output_width, 75 | self.crop_height, self.crop_width, self.is_random_crop, self.is_mirror, self.is_gray) 76 | 77 | img = self.input_transform(img) 78 | 79 | return img 80 | 81 | def __len__(self): 82 | return len(self.image_filenames) 83 | 84 | 85 | def list_images_in_dir(path): 86 | valid_images = [".jpg", ".gif", ".png"] 87 | img_list = [] 88 | for f in os.listdir(path): 89 | ext = os.path.splitext(f)[1] 90 | if ext.lower() not in valid_images: 91 | continue 92 | img_list.append(os.path.join(path, f)) 93 | return img_list 94 | 95 | 96 | class DigitalMonstersDataset(data.Dataset): 97 | def __init__(self, root_path, 98 | input_height=None, input_width=None, output_height=128, output_width=None, is_gray=False, pokemon=True, 99 | digimon=True, nexomon=True): 100 | super(DigitalMonstersDataset, self).__init__() 101 | image_list = [] 102 | if pokemon: 103 | print("collecting pokemon...") 104 | image_list.extend(list_images_in_dir(os.path.join(root_path, 'pokemon'))) 105 | if digimon: 106 | print("collecting digimon...") 107 | image_list.extend(list_images_in_dir(os.path.join(root_path, 'digimon', '200'))) 108 | if nexomon: 109 | print("collecting nexomon...") 110 | image_list.extend(list_images_in_dir(os.path.join(root_path, 'nexomon'))) 111 | print(f'total images: {len(image_list)}') 112 | 113 | self.image_filenames = image_list 114 | self.input_height = input_height 115 | self.input_width = input_width 116 | self.output_height = output_height 117 | self.output_width = output_width 118 | 
self.root_path = root_path 119 | self.is_gray = is_gray 120 | 121 | # self.input_transform = transforms.Compose([ 122 | # transforms.RandomAffine(0, translate=(5 / output_height, 5 / output_height), fillcolor=(255, 255, 255)), 123 | # transforms.ColorJitter(hue=0.5), 124 | # transforms.RandomHorizontalFlip(p=0.5), 125 | # transforms.ToTensor(), 126 | # transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,)) 127 | # ]) 128 | 129 | self.input_transform = transforms.Compose([ 130 | transforms.RandomAffine(0, translate=(5 / output_height, 5 / output_height), fillcolor=(255, 255, 255)), 131 | transforms.ColorJitter(hue=0.5), 132 | transforms.RandomHorizontalFlip(p=0.5), 133 | transforms.ToTensor() 134 | ]) 135 | 136 | # self.input_transform = transforms.Compose([ 137 | # transforms.ToTensor() 138 | # ]) 139 | 140 | def __getitem__(self, index): 141 | img = load_image(self.image_filenames[index], input_height=self.input_height, input_width=self.input_width, 142 | output_height=self.output_height, output_width=self.output_width, 143 | crop_height=None, crop_width=None, is_random_crop=False, is_mirror=False, is_gray=False) 144 | img = self.input_transform(img) 145 | 146 | return img 147 | 148 | def __len__(self): 149 | return len(self.image_filenames) 150 | 151 | 152 | if __name__ == "__main__": 153 | ds = DigitalMonstersDataset(root_path='./pokemon_ds') 154 | print(ds) 155 | -------------------------------------------------------------------------------- /style_soft_intro_vae/make_figures/make_recon_figure_ffhq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2021 Tal Daniel 2 | # Copyright 2019-2020 Stanislav Pidhorskyi 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import torch.utils.data 18 | from torchvision.utils import save_image 19 | from net import * 20 | from model import SoftIntroVAEModelTL 21 | from launcher import run 22 | from checkpointer import Checkpointer 23 | from defaults import get_cfg_defaults 24 | import lreq 25 | from dataloader import * 26 | 27 | lreq.use_implicit_lreq.set(True) 28 | 29 | 30 | def millify(n): 31 | millnames = ['', 'k', 'M', 'G', 'T', 'P'] 32 | n = float(n) 33 | millidx = max(0, min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)))) 34 | 35 | return '{:.1f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx]) 36 | 37 | 38 | def count_parameters(model, print_func=print, verbose=False): 39 | for n, p in model.named_parameters(): 40 | if p.requires_grad and verbose: 41 | print_func(n, millify(p.numel())) 42 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 43 | 44 | 45 | def place(canvas, image, x, y): 46 | im_size = image.shape[2] 47 | if len(image.shape) == 4: 48 | image = image[0] 49 | canvas[:, y: y + im_size, x: x + im_size] = image * 0.5 + 0.5 50 | 51 | 52 | def save_sample(model, sample, i): 53 | os.makedirs('results', exist_ok=True) 54 | 55 | with torch.no_grad(): 56 | model.eval() 57 | x_rec = model.generate(model.generator.layer_count - 1, 1, z=sample) 58 | 59 | def save_pic(x_rec): 60 | resultsample = x_rec * 0.5 + 0.5 61 | resultsample = resultsample.cpu() 62 | save_image(resultsample, 63 | 'sample_%i_lr.png' % i, nrow=16) 64 | 65 | save_pic(x_rec) 66 | 67 | 68 | def sample(cfg, logger): 69 | torch.cuda.set_device(0) 70 | model = SoftIntroVAEModelTL( 71 | startf=cfg.MODEL.START_CHANNEL_COUNT, 72 | layer_count=cfg.MODEL.LAYER_COUNT, 73 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 74 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 75 | dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, 76 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 77 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 78 | channels=cfg.MODEL.CHANNELS, 79 | generator=cfg.MODEL.GENERATOR, 80 | encoder=cfg.MODEL.ENCODER, 81 | beta_kl=cfg.MODEL.BETA_KL, 82 | beta_rec=cfg.MODEL.BETA_REC, 83 | beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], 84 | scale=cfg.MODEL.SCALE 85 | ) 86 | model.cuda(0) 87 | model.eval() 88 | model.requires_grad_(False) 89 | 90 | decoder = model.decoder 91 | encoder = model.encoder 92 | mapping_tl = model.mapping_tl 93 | mapping_fl = model.mapping_fl 94 | dlatent_avg = model.dlatent_avg 95 | 96 | logger.info("Trainable parameters decoder:") 97 | print(count_parameters(decoder)) 98 | 99 | logger.info("Trainable parameters encoder:") 100 | print(count_parameters(encoder)) 101 | 102 | arguments = dict() 103 | arguments["iteration"] = 0 104 | 105 | model_dict = { 106 | 'discriminator_s': encoder, 107 | 'generator_s': decoder, 108 | 'mapping_tl_s': mapping_tl, 109 | 'mapping_fl_s': mapping_fl, 110 | 'dlatent_avg': dlatent_avg 111 | } 112 | 113 | checkpointer = Checkpointer(cfg, 114 | model_dict, 115 | {}, 116 | logger=logger, 117 | save=False) 118 | 119 | extra_checkpoint_data = checkpointer.load() 120 | 121 | model.eval() 122 | 123 | layer_count = cfg.MODEL.LAYER_COUNT 124 | 125 | def encode(x): 126 | z, mu, _ = model.encode(x, layer_count - 1, 1) 127 | styles = model.mapping_fl(mu) 128 | return styles 129 | 130 | def decode(x): 131 | return model.decoder(x, layer_count - 1, 1, noise=True) 132 | 133 | rnd = np.random.RandomState(5) 134 | 135 | dataset = TFRecordsDataset(cfg, logger, rank=0, 
world_size=1, buffer_size_mb=10, channels=cfg.MODEL.CHANNELS, 136 | train=False) 137 | 138 | dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 10) 139 | b = iter(make_dataloader(cfg, logger, dataset, 10, 0, numpy=True)) 140 | 141 | def make(sample): 142 | canvas = [] 143 | with torch.no_grad(): 144 | for img in sample: 145 | x = torch.tensor(np.asarray(img, dtype=np.float32), device='cpu', 146 | requires_grad=True).cuda() / 127.5 - 1. 147 | if x.shape[0] == 4: 148 | x = x[:3] 149 | latents = encode(x[None, ...].cuda()) 150 | f = decode(latents) 151 | r = torch.cat([x[None, ...].detach().cpu(), f.detach().cpu()], dim=3) 152 | canvas.append(r) 153 | return canvas 154 | 155 | sample = next(b) 156 | canvas = make(sample) 157 | canvas = torch.cat(canvas, dim=0) 158 | 159 | save_image(canvas * 0.5 + 0.5, './make_figures/reconstructions_ffhq_real_1.png', nrow=2, pad_value=1.0) 160 | 161 | sample = next(b) 162 | canvas = make(sample) 163 | canvas = torch.cat(canvas, dim=0) 164 | 165 | save_image(canvas * 0.5 + 0.5, './make_figures/reconstructions_ffhq_real_2.png', nrow=2, pad_value=1.0) 166 | 167 | 168 | if __name__ == "__main__": 169 | gpu_count = 1 170 | run(sample, get_cfg_defaults(), description='SoftIntroVAE-reconstruction-ffhq', 171 | default_config='./configs/ffhq256.yaml', 172 | world_size=gpu_count, write_log=False) 173 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/datasets/modelnet40.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import shutil 3 | from os import listdir, makedirs, remove 4 | from os.path import exists, join 5 | from zipfile import ZipFile 6 | 7 | import h5py 8 | import numpy as np 9 | import pandas as pd 10 | from torch.utils.data import Dataset 11 | 12 | from utils.pcutil import rand_rotation_matrix 13 | 14 | all_classes = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 15 | 'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 16 | 'dresser', 'flower_pot', 'glass_box', 'guitar', 'keyboard', 17 | 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 'person', 18 | 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 19 | 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 20 | 'wardrobe', 'xbox'] 21 | 22 | number_to_category = {i: c for i, c in enumerate(all_classes)} 23 | category_to_number = {c: i for i, c in enumerate(all_classes)} 24 | 25 | 26 | class ModelNet40(Dataset): 27 | def __init__(self, root_dir='/home/datasets/modelnet40', classes=[], 28 | transform=[], split='train', valid_percent=10, percent_supervised=0.0): 29 | """ 30 | Args: 31 | root_dir (string): Directory with all the point clouds. 32 | transform (callable, optional): Optional transform to be applied 33 | on a sample. 34 | split (string): `train` or `test` 35 | valid_percent (int): Percent of train (from the end) to use as valid set. 
36 | """ 37 | self.root_dir = root_dir 38 | self.transform = transform 39 | self.split = split.lower() 40 | self.valid_percent = valid_percent 41 | self.percent_supervised = percent_supervised 42 | 43 | self._maybe_download_data() 44 | 45 | if self.split in ('train', 'valid'): 46 | self.files_list = join(self.root_dir, 'train_files.txt') 47 | elif self.split == 'test': 48 | self.files_list = join(self.root_dir, 'test_files.txt') 49 | else: 50 | raise ValueError('Incorrect split') 51 | 52 | data, labels = self._load_files() 53 | 54 | if classes: 55 | if classes[0] in all_classes: 56 | classes = np.asarray([category_to_number[c] for c in classes]) 57 | filter = [label in classes for label in labels] 58 | data = data[filter] 59 | labels = labels[filter] 60 | else: 61 | classes = np.arange(len(all_classes)) 62 | 63 | if self.split in ('train', 'valid'): 64 | new_data, new_labels = [], [] 65 | if self.percent_supervised > 0.0: 66 | data_sup, labels_sub = [], [] 67 | for c in classes: 68 | pc_in_class = sum(labels.flatten() == c) 69 | 70 | if self.split == 'train': 71 | portion = slice(0, int(pc_in_class * (1 - (self.valid_percent / 100)))) 72 | else: 73 | portion = slice(int(pc_in_class * (1 - (self.valid_percent / 100))), pc_in_class) 74 | 75 | new_data.append(data[labels.flatten() == c][portion]) 76 | new_labels.append(labels[labels.flatten() == c][portion]) 77 | 78 | if self.percent_supervised > 0.0: 79 | n_max = int(self.percent_supervised * (portion.stop - 1)) 80 | data_sup.append(data[labels.flatten() == c][:n_max]) 81 | labels_sub.append(labels[labels.flatten() == c][:n_max]) 82 | data = np.vstack(new_data) 83 | labels = np.vstack(new_labels) 84 | if self.percent_supervised > 0.0: 85 | self.data_sup = np.vstack(data_sup) 86 | self.labels_sup = np.vstack(labels_sub) 87 | self.data = data 88 | self.labels = labels 89 | 90 | def _load_files(self) -> pd.DataFrame: 91 | 92 | with open(self.files_list) as f: 93 | files = [join(self.root_dir, line.rstrip().rsplit('/', 1)[1]) for line in f] 94 | 95 | data, labels = [], [] 96 | for file in files: 97 | with h5py.File(file) as f: 98 | data.extend(f['data'][:]) 99 | labels.extend(f['label'][:]) 100 | 101 | return np.asarray(data), np.asarray(labels) 102 | 103 | def __len__(self): 104 | return len(self.data) 105 | 106 | def __getitem__(self, idx): 107 | sample = self.data[idx] 108 | sample /= 2 # Scale to [-0.5, 0.5] range 109 | label = self.labels[idx] 110 | 111 | if 'rotate'.lower() in self.transform: 112 | r_rotation = rand_rotation_matrix() 113 | r_rotation[0, 2] = 0 114 | r_rotation[2, 0] = 0 115 | r_rotation[1, 2] = 0 116 | r_rotation[2, 1] = 0 117 | r_rotation[2, 2] = 1 118 | 119 | sample = sample.dot(r_rotation).astype(np.float32) 120 | if self.percent_supervised > 0.0: 121 | id_sup = np.random.randint(self.data_sup.shape[0]) 122 | sample_sup = self.data_sup[id_sup] 123 | sample_sup /= 2 124 | label_sup = self.labels_sup[id_sup] 125 | return sample, label, sample_sup, label_sup 126 | else: 127 | return sample, label 128 | 129 | def _maybe_download_data(self): 130 | if exists(self.root_dir): 131 | return 132 | 133 | print(f'ModelNet40 doesn\'t exist in root directory {self.root_dir}. 
' 134 | f'Downloading...') 135 | makedirs(self.root_dir) 136 | 137 | url = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 138 | 139 | data = urllib.request.urlopen(url) 140 | filename = url.rpartition('/')[2][:-5] 141 | file_path = join(self.root_dir, filename) 142 | with open(file_path, mode='wb') as f: 143 | d = data.read() 144 | f.write(d) 145 | 146 | print('Extracting...') 147 | with ZipFile(file_path, mode='r') as zip_f: 148 | zip_f.extractall(self.root_dir) 149 | 150 | remove(file_path) 151 | 152 | extracted_dir = join(self.root_dir, 'modelnet40_ply_hdf5_2048') 153 | for d in listdir(extracted_dir): 154 | shutil.move(src=join(extracted_dir, d), 155 | dst=self.root_dir) 156 | 157 | shutil.rmtree(extracted_dir) 158 | 159 | 160 | if __name__ == '__main__': 161 | ModelNet40() 162 | -------------------------------------------------------------------------------- /style_soft_intro_vae/dataset_preparation/split_tfrecords_ffhq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import sys 18 | import argparse 19 | import logging 20 | import tensorflow as tf 21 | # from defaults import get_cfg_defaults 22 | 23 | from yacs.config import CfgNode as CN 24 | 25 | _C = CN() 26 | 27 | _C.NAME = "" 28 | _C.PPL_CELEBA_ADJUSTMENT = False 29 | _C.OUTPUT_DIR = "results" 30 | 31 | _C.DATASET = CN() 32 | _C.DATASET.PATH = 'celeba/data_fold_%d_lod_%d.pkl' 33 | _C.DATASET.PATH_TEST = '' 34 | _C.DATASET.FFHQ_SOURCE = '/data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords' 35 | _C.DATASET.PART_COUNT = 1 36 | _C.DATASET.PART_COUNT_TEST = 1 37 | _C.DATASET.SIZE = 70000 38 | _C.DATASET.SIZE_TEST = 10000 39 | _C.DATASET.FLIP_IMAGES = True 40 | _C.DATASET.SAMPLES_PATH = 'dataset_samples/faces/realign128x128' 41 | 42 | _C.DATASET.STYLE_MIX_PATH = 'style_mixing/test_images/set_celeba/' 43 | 44 | _C.DATASET.MAX_RESOLUTION_LEVEL = 10 45 | 46 | _C.MODEL = CN() 47 | 48 | _C.MODEL.LAYER_COUNT = 6 49 | _C.MODEL.START_CHANNEL_COUNT = 64 50 | _C.MODEL.MAX_CHANNEL_COUNT = 512 51 | _C.MODEL.LATENT_SPACE_SIZE = 256 52 | _C.MODEL.DLATENT_AVG_BETA = 0.995 53 | _C.MODEL.TRUNCATIOM_PSI = 0.7 54 | _C.MODEL.TRUNCATIOM_CUTOFF = 8 55 | _C.MODEL.STYLE_MIXING_PROB = 0.9 56 | _C.MODEL.MAPPING_LAYERS = 5 57 | _C.MODEL.CHANNELS = 3 58 | _C.MODEL.GENERATOR = "GeneratorDefault" 59 | _C.MODEL.ENCODER = "EncoderDefault" 60 | _C.MODEL.MAPPING_TO_LATENT = "MappingToLatent" 61 | _C.MODEL.MAPPING_FROM_LATENT = "MappingFromLatent" 62 | _C.MODEL.Z_REGRESSION = False 63 | _C.MODEL.BETA_KL = 1.0 64 | _C.MODEL.BETA_REC = 1.0 65 | _C.MODEL.BETA_NEG = [2048, 2048, 1024, 512, 512, 128, 128, 64, 64] 66 | _C.MODEL.SCALE = 1 / (3 * 256 ** 2) 67 | 68 | _C.TRAIN = CN() 69 | 70 | _C.TRAIN.EPOCHS_PER_LOD = 15 71 | 72 | _C.TRAIN.BASE_LEARNING_RATE = 0.0015 73 | 
_C.TRAIN.ADAM_BETA_0 = 0.0 74 | _C.TRAIN.ADAM_BETA_1 = 0.99 75 | _C.TRAIN.LEARNING_DECAY_RATE = 0.1 76 | _C.TRAIN.LEARNING_DECAY_STEPS = [] 77 | _C.TRAIN.TRAIN_EPOCHS = 110 78 | _C.TRAIN.NUM_VAE = 1 79 | 80 | _C.TRAIN.LOD_2_BATCH_8GPU = [512, 256, 128, 64, 32, 32] 81 | _C.TRAIN.LOD_2_BATCH_4GPU = [512, 256, 128, 64, 32, 16] 82 | _C.TRAIN.LOD_2_BATCH_2GPU = [256, 256, 128, 64, 32, 16] 83 | _C.TRAIN.LOD_2_BATCH_1GPU = [128, 128, 128, 64, 32, 16] 84 | 85 | _C.TRAIN.SNAPSHOT_FREQ = [300, 300, 300, 100, 50, 30, 20, 20, 10] 86 | 87 | _C.TRAIN.REPORT_FREQ = [100, 80, 60, 30, 20, 10, 10, 5, 5] 88 | 89 | _C.TRAIN.LEARNING_RATES = [0.002] 90 | 91 | def get_cfg_defaults(): 92 | return _C.clone() 93 | 94 | 95 | def split_tfrecord(cfg, logger): 96 | tfrecord_path = cfg.DATASET.FFHQ_SOURCE 97 | 98 | ffhq_train_size = 60000 99 | 100 | part_size = ffhq_train_size // cfg.DATASET.PART_COUNT 101 | 102 | logger.info("Splitting into % size parts" % part_size) 103 | 104 | for i in range(2, cfg.DATASET.MAX_RESOLUTION_LEVEL + 1): 105 | with tf.Graph().as_default(), tf.Session() as sess: 106 | ds = tf.data.TFRecordDataset(tfrecord_path % i) 107 | ds = ds.batch(part_size) 108 | batch = ds.make_one_shot_iterator().get_next() 109 | part_num = 0 110 | while True: 111 | try: 112 | records = sess.run(batch) 113 | if part_num < cfg.DATASET.PART_COUNT: 114 | part_path = cfg.DATASET.PATH % (i, part_num) 115 | os.makedirs(os.path.dirname(part_path), exist_ok=True) 116 | with tf.python_io.TFRecordWriter(part_path) as writer: 117 | for record in records: 118 | writer.write(record) 119 | else: 120 | part_path = cfg.DATASET.PATH_TEST % (i, part_num - cfg.DATASET.PART_COUNT) 121 | os.makedirs(os.path.dirname(part_path), exist_ok=True) 122 | with tf.python_io.TFRecordWriter(part_path) as writer: 123 | for record in records: 124 | writer.write(record) 125 | part_num += 1 126 | except tf.errors.OutOfRangeError: 127 | break 128 | 129 | 130 | def run(): 131 | parser = argparse.ArgumentParser(description="ALAE. 
Split FFHQ into parts for training and testing") 132 | parser.add_argument( 133 | "--config-file", 134 | default="/home/tal/tmp/StyleSandwichVAE2/configs/ffhq256.yaml", 135 | metavar="FILE", 136 | help="path to config file", 137 | type=str, 138 | ) 139 | parser.add_argument( 140 | "opts", 141 | help="Modify config options using the command-line", 142 | default=None, 143 | nargs=argparse.REMAINDER, 144 | ) 145 | 146 | args = parser.parse_args() 147 | cfg = get_cfg_defaults() 148 | cfg.merge_from_file(args.config_file) 149 | cfg.merge_from_list(args.opts) 150 | cfg.freeze() 151 | 152 | logger = logging.getLogger("logger") 153 | logger.setLevel(logging.DEBUG) 154 | 155 | output_dir = cfg.OUTPUT_DIR 156 | os.makedirs(output_dir, exist_ok=True) 157 | 158 | ch = logging.StreamHandler(stream=sys.stdout) 159 | ch.setLevel(logging.DEBUG) 160 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 161 | ch.setFormatter(formatter) 162 | logger.addHandler(ch) 163 | 164 | fh = logging.FileHandler(os.path.join(output_dir, 'log.txt')) 165 | fh.setLevel(logging.DEBUG) 166 | fh.setFormatter(formatter) 167 | logger.addHandler(fh) 168 | 169 | logger.info(args) 170 | 171 | logger.info("Loaded configuration file {}".format(args.config_file)) 172 | with open(args.config_file, "r") as cf: 173 | config_str = "\n" + cf.read() 174 | logger.info(config_str) 175 | logger.info("Running with config:\n{}".format(cfg)) 176 | 177 | split_tfrecord(cfg, logger) 178 | 179 | 180 | if __name__ == '__main__': 181 | run() 182 | 183 | -------------------------------------------------------------------------------- /style_soft_intro_vae/make_figures/make_recon_figure_interpolation_2_images.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2021 Tal Daniel 2 | # Copyright 2019-2020 Stanislav Pidhorskyi 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import torch.utils.data 18 | from torchvision.utils import save_image 19 | from net import * 20 | from model import SoftIntroVAEModelTL 21 | from launcher import run 22 | from checkpointer import Checkpointer 23 | from defaults import get_cfg_defaults 24 | import lreq 25 | from PIL import Image 26 | 27 | lreq.use_implicit_lreq.set(True) 28 | 29 | 30 | def millify(n): 31 | millnames = ['', 'k', 'M', 'G', 'T', 'P'] 32 | n = float(n) 33 | millidx = max(0, min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)))) 34 | 35 | return '{:.1f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx]) 36 | 37 | 38 | def count_parameters(model, print_func=print, verbose=False): 39 | for n, p in model.named_parameters(): 40 | if p.requires_grad and verbose: 41 | print_func(n, millify(p.numel())) 42 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 43 | 44 | 45 | def place(canvas, image, x, y): 46 | im_size = image.shape[2] 47 | if len(image.shape) == 4: 48 | image = image[0] 49 | canvas[:, y: y + im_size, x: x + im_size] = image * 0.5 + 0.5 50 | 51 | 52 | def save_sample(model, sample, i): 53 | os.makedirs('results', exist_ok=True) 54 | 55 | with torch.no_grad(): 56 | model.eval() 57 | x_rec = model.generate(model.generator.layer_count - 1, 1, z=sample) 58 | 59 | def save_pic(x_rec): 60 | resultsample = x_rec * 0.5 + 0.5 61 | resultsample = resultsample.cpu() 62 | save_image(resultsample, 63 | 'sample_%i_lr.png' % i, nrow=16) 64 | 65 | save_pic(x_rec) 66 | 67 | 68 | def sample(cfg, logger): 69 | torch.cuda.set_device(0) 70 | model = SoftIntroVAEModelTL( 71 | startf=cfg.MODEL.START_CHANNEL_COUNT, 72 | layer_count=cfg.MODEL.LAYER_COUNT, 73 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 74 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 75 | dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, 76 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 77 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 78 | channels=cfg.MODEL.CHANNELS, 79 | generator=cfg.MODEL.GENERATOR, 80 | encoder=cfg.MODEL.ENCODER, 81 | beta_kl=cfg.MODEL.BETA_KL, 82 | beta_rec=cfg.MODEL.BETA_REC, 83 | beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], 84 | scale=cfg.MODEL.SCALE 85 | ) 86 | model.cuda(0) 87 | model.eval() 88 | model.requires_grad_(False) 89 | 90 | decoder = model.decoder 91 | encoder = model.encoder 92 | mapping_tl = model.mapping_tl 93 | mapping_fl = model.mapping_fl 94 | dlatent_avg = model.dlatent_avg 95 | 96 | logger.info("Trainable parameters decoder:") 97 | print(count_parameters(decoder)) 98 | 99 | logger.info("Trainable parameters encoder:") 100 | print(count_parameters(encoder)) 101 | 102 | arguments = dict() 103 | arguments["iteration"] = 0 104 | 105 | model_dict = { 106 | 'discriminator_s': encoder, 107 | 'generator_s': decoder, 108 | 'mapping_tl_s': mapping_tl, 109 | 'mapping_fl_s': mapping_fl, 110 | 'dlatent_avg': dlatent_avg 111 | } 112 | 113 | checkpointer = Checkpointer(cfg, 114 | model_dict, 115 | {}, 116 | logger=logger, 117 | save=False) 118 | 119 | extra_checkpoint_data = checkpointer.load() 120 | 121 | model.eval() 122 | 123 | layer_count = cfg.MODEL.LAYER_COUNT 124 | 125 | def encode(x): 126 | z, mu, _ = model.encode(x, layer_count - 1, 1) 127 | styles = model.mapping_fl(mu) 128 | return styles 129 | 130 | def decode(x): 131 | return model.decoder(x, layer_count - 1, 1, noise=True) 132 | 133 | rnd = np.random.RandomState(4) 134 | 135 | path = cfg.DATASET.SAMPLES_PATH 136 | im_size = 2 ** 
(cfg.MODEL.LAYER_COUNT + 1) 137 | 138 | pathA = '17460.jpg' 139 | pathB = '02973.jpg' 140 | 141 | def open_image(filename): 142 | img = np.asarray(Image.open(path + '/' + filename)) 143 | if img.shape[2] == 4: 144 | img = img[:, :, :3] 145 | im = img.transpose((2, 0, 1)) 146 | x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1. 147 | if x.shape[0] == 4: 148 | x = x[:3] 149 | factor = x.shape[2] // im_size 150 | if factor != 1: 151 | x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0] 152 | assert x.shape[2] == im_size 153 | _latents = encode(x[None, ...].cuda()) 154 | latents = _latents[0, 0] 155 | return latents 156 | 157 | def make(w): 158 | with torch.no_grad(): 159 | w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) 160 | x_rec = decode(w) 161 | return x_rec 162 | 163 | wa = open_image(pathA) 164 | wb = open_image(pathB) 165 | 166 | width = 7 167 | 168 | images = [] 169 | 170 | for j in range(width): 171 | kh = j / (width - 1.0) 172 | 173 | ka = (1.0 - kh) 174 | kb = kh 175 | w = ka * wa + kb * wb 176 | 177 | interpolated = make(w) 178 | images.append(interpolated) 179 | 180 | images = torch.cat(images) 181 | 182 | path = './make_figures/output' 183 | os.makedirs(path, exist_ok=True) 184 | os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True) 185 | save_image(images * 0.5 + 0.5, './make_figures/output/%s/interpolations.png' % cfg.NAME, nrow=width) 186 | save_image(images * 0.5 + 0.5, './make_figures/output/%s/interpolations.jpg' % cfg.NAME, nrow=width) 187 | 188 | 189 | if __name__ == "__main__": 190 | gpu_count = 1 191 | run(sample, get_cfg_defaults(), description='SoftIntroVAE-interpolations', 192 | default_config='./configs/ffhq256.yaml', 193 | world_size=gpu_count, write_log=False) 194 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/evaluation/find_best_epoch_on_validation_soft.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | import re 6 | from datetime import datetime 7 | from os import listdir 8 | from os.path import join 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from datasets.shapenet import ShapeNetDataset 16 | from metrics.jsd import jsd_between_point_cloud_sets 17 | from utils.util import cuda_setup, setup_logging 18 | from models.vae import SoftIntroVAE, reparameterize 19 | 20 | 21 | def _get_epochs_by_regex(path, regex): 22 | reg = re.compile(regex) 23 | return {int(w[:5]) for w in listdir(path) if reg.match(w)} 24 | 25 | 26 | def main(eval_config): 27 | # Load hyperparameters as they were during training 28 | train_results_path = join(eval_config['results_root'], eval_config['arch'], 29 | eval_config['experiment_name']) 30 | with open(join(train_results_path, 'config.json')) as f: 31 | train_config = json.load(f) 32 | 33 | if train_config['seed'] >= 0: 34 | random.seed(train_config['seed']) 35 | torch.manual_seed(train_config['seed']) 36 | torch.cuda.manual_seed(train_config['seed']) 37 | np.random.seed(train_config['seed']) 38 | torch.backends.cudnn.deterministic = True 39 | print("random seed: ", train_config['seed']) 40 | 41 | setup_logging(join(train_results_path, 'results')) 42 | log = logging.getLogger(__name__) 43 | 44 | log.debug('Evaluating Jensen-Shannon divergences on validation set on all ' 45 | 'saved epochs.') 46 | 47 | 
weights_path = join(train_results_path, 'weights') 48 | 49 | # Find all epochs that have saved model weights 50 | v_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})\.pth') 51 | epochs = sorted(v_epochs) 52 | log.debug(f'Testing epochs: {epochs}') 53 | 54 | device = cuda_setup(eval_config['cuda'], eval_config['gpu']) 55 | log.debug(f'Device variable: {device}') 56 | if device.type == 'cuda': 57 | log.debug(f'Current CUDA device: {torch.cuda.current_device()}') 58 | 59 | # 60 | # Dataset 61 | # 62 | dataset_name = train_config['dataset'].lower() 63 | if dataset_name == 'shapenet': 64 | dataset = ShapeNetDataset(root_dir=train_config['data_dir'], 65 | classes=train_config['classes'], split='valid') 66 | # elif dataset_name == 'faust': 67 | # from datasets.dfaust import DFaustDataset 68 | # dataset = DFaustDataset(root_dir=train_config['data_dir'], 69 | # classes=train_config['classes'], split='valid') 70 | # elif dataset_name == 'mcgill': 71 | # from datasets.mcgill import McGillDataset 72 | # dataset = McGillDataset(root_dir=train_config['data_dir'], 73 | # classes=train_config['classes'], split='valid') 74 | else: 75 | raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' 76 | f'`faust`. Got: `{dataset_name}`') 77 | classes_selected = ('all' if not train_config['classes'] 78 | else ','.join(train_config['classes'])) 79 | log.debug(f'Selected {classes_selected} classes. Loaded {len(dataset)} ' 80 | f'samples.') 81 | 82 | # if 'distribution' in train_config: 83 | # distribution = train_config['distribution'] 84 | # elif 'distribution' in eval_config: 85 | # distribution = eval_config['distribution'] 86 | # else: 87 | # log.warning('No distribution type specified. Assumed normal = N(0, 0.2)') 88 | # distribution = 'normal' 89 | 90 | # 91 | # Models 92 | 93 | model = SoftIntroVAE(train_config).to(device) 94 | model.eval() 95 | 96 | num_samples = len(dataset.point_clouds_names_valid) 97 | data_loader = DataLoader(dataset, batch_size=num_samples, 98 | shuffle=False, num_workers=4, 99 | drop_last=False, pin_memory=True) 100 | 101 | # We take 3 times as many samples as there are in test data in order to 102 | # perform JSD calculation in the same manner as in the reference publication 103 | # noise = torch.FloatTensor(3 * num_samples, train_config['z_size'], 1) 104 | noise = torch.randn(3 * num_samples, model.zdim) 105 | noise = noise.to(device) 106 | 107 | x, _ = next(iter(data_loader)) 108 | x = x.to(device) 109 | 110 | results = {} 111 | 112 | for epoch in reversed(epochs): 113 | try: 114 | model.load_state_dict(torch.load( 115 | join(weights_path, f'{epoch:05}.pth'))) 116 | 117 | start_clock = datetime.now() 118 | 119 | # We average JSD computation over 3 independent trials.
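# Note: `jsd_between_point_cloud_sets` (imported from metrics/jsd.py above) is assumed
# to follow the standard point-cloud evaluation protocol: both sets are voxelized into
# occupancy grids (28 bins per axis here, via `voxels=28` below) and the Jensen-Shannon
# divergence between the two occupancy distributions is returned; lower is better.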
120 | js_results = [] 121 | for _ in range(3): 122 | # if distribution == 'normal': 123 | # noise.normal_(0, 0.2) 124 | # elif distribution == 'beta': 125 | # noise_np = np.random.beta(train_config['z_beta_a'], 126 | # train_config['z_beta_b'], 127 | # noise.shape) 128 | # noise = torch.tensor(noise_np).float().round().to(device) 129 | 130 | with torch.no_grad(): 131 | x_g = model.decode(noise) 132 | if x_g.shape[-2:] == (3, 2048): 133 | x_g.transpose_(1, 2) 134 | 135 | jsd = jsd_between_point_cloud_sets(x, x_g, voxels=28) 136 | js_results.append(jsd) 137 | 138 | js_result = np.mean(js_results) 139 | log.debug(f'Epoch: {epoch} JSD: {js_result: .6f} ' 140 | f'Time: {datetime.now() - start_clock}') 141 | results[epoch] = js_result 142 | except KeyboardInterrupt: 143 | log.debug(f'Interrupted during epoch: {epoch}') 144 | break 145 | 146 | results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd']) 147 | log.debug(f"Minimum JSD at epoch {results.idxmin()['jsd']}: " 148 | f"{results.min()['jsd']: .6f}") 149 | 150 | 151 | if __name__ == '__main__': 152 | logger = logging.getLogger() 153 | 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('-c', '--config', default=None, type=str, 156 | help='File path for evaluation config') 157 | args = parser.parse_args() 158 | 159 | args.config = './config/soft_intro_vae_hp.json' 160 | evaluation_config = None 161 | if args.config is not None and args.config.endswith('.json'): 162 | with open(args.config) as f: 163 | evaluation_config = json.load(f) 164 | assert evaluation_config is not None 165 | 166 | main(evaluation_config) 167 | -------------------------------------------------------------------------------- /style_soft_intro_vae/make_figures/make_recon_figure_paged.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2021 Tal Daniel 2 | # Copyright 2019-2020 Stanislav Pidhorskyi 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import torch.utils.data 18 | from torchvision.utils import save_image 19 | import random 20 | from net import * 21 | from model import SoftIntroVAEModelTL 22 | from launcher import run 23 | from checkpointer import Checkpointer 24 | from defaults import get_cfg_defaults 25 | import lreq 26 | from PIL import Image 27 | 28 | lreq.use_implicit_lreq.set(True) 29 | 30 | 31 | def millify(n): 32 | millnames = ['', 'k', 'M', 'G', 'T', 'P'] 33 | n = float(n) 34 | millidx = max(0, min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)))) 35 | 36 | return '{:.1f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx]) 37 | 38 | 39 | def count_parameters(model, print_func=print, verbose=False): 40 | for n, p in model.named_parameters(): 41 | if p.requires_grad and verbose: 42 | print_func(n, millify(p.numel())) 43 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 44 | 45 | 46 | def place(canvas, image, x, y): 47 | im_size = image.shape[2] 48 | if len(image.shape) == 4: 49 | image = image[0] 50 | canvas[:, y: y + im_size, x: x + im_size] = image * 0.5 + 0.5 51 | 52 | 53 | def save_sample(model, sample, i): 54 | os.makedirs('results', exist_ok=True) 55 | 56 | with torch.no_grad(): 57 | model.eval() 58 | x_rec = model.generate(model.generator.layer_count - 1, 1, z=sample) 59 | 60 | def save_pic(x_rec): 61 | resultsample = x_rec * 0.5 + 0.5 62 | resultsample = resultsample.cpu() 63 | save_image(resultsample, 64 | 'sample_%i_lr.png' % i, nrow=16) 65 | 66 | save_pic(x_rec) 67 | 68 | 69 | def sample(cfg, logger): 70 | torch.cuda.set_device(0) 71 | model = SoftIntroVAEModelTL( 72 | startf=cfg.MODEL.START_CHANNEL_COUNT, 73 | layer_count=cfg.MODEL.LAYER_COUNT, 74 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 75 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 76 | dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, 77 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 78 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 79 | channels=cfg.MODEL.CHANNELS, 80 | generator=cfg.MODEL.GENERATOR, 81 | encoder=cfg.MODEL.ENCODER, 82 | beta_kl=cfg.MODEL.BETA_KL, 83 | beta_rec=cfg.MODEL.BETA_REC, 84 | beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], 85 | scale=cfg.MODEL.SCALE 86 | ) 87 | model.cuda(0) 88 | model.eval() 89 | model.requires_grad_(False) 90 | 91 | decoder = model.decoder 92 | encoder = model.encoder 93 | mapping_tl = model.mapping_tl 94 | mapping_fl = model.mapping_fl 95 | dlatent_avg = model.dlatent_avg 96 | 97 | logger.info("Trainable parameters decoder:") 98 | print(count_parameters(decoder)) 99 | 100 | logger.info("Trainable parameters encoder:") 101 | print(count_parameters(encoder)) 102 | 103 | arguments = dict() 104 | arguments["iteration"] = 0 105 | 106 | model_dict = { 107 | 'discriminator_s': encoder, 108 | 'generator_s': decoder, 109 | 'mapping_tl_s': mapping_tl, 110 | 'mapping_fl_s': mapping_fl, 111 | 'dlatent_avg': dlatent_avg 112 | } 113 | 114 | checkpointer = Checkpointer(cfg, 115 | model_dict, 116 | {}, 117 | logger=logger, 118 | save=False) 119 | 120 | extra_checkpoint_data = checkpointer.load() 121 | 122 | model.eval() 123 | 124 | layer_count = cfg.MODEL.LAYER_COUNT 125 | 126 | def encode(x): 127 | z, mu, _ = model.encode(x, layer_count - 1, 1) 128 | styles = model.mapping_fl(mu) 129 | return styles 130 | 131 | def decode(x): 132 | return model.decoder(x, layer_count - 1, 1, noise=True) 133 | 134 | path = cfg.DATASET.SAMPLES_PATH 135 | im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1) 
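# Note: reconstructions are rendered at the model's full output resolution; for the
# 256x256 configs (presumably LAYER_COUNT = 7) this gives im_size = 2 ** 8 = 256, and
# larger source images are average-pooled down to im_size before encoding (see the
# `factor` logic below).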
136 | 137 | paths = list(os.listdir(path)) 138 | 139 | paths = sorted(paths) 140 | random.seed(1) 141 | random.shuffle(paths) 142 | 143 | def make(paths): 144 | canvas = [] 145 | with torch.no_grad(): 146 | for filename in paths: 147 | img = np.asarray(Image.open(path + '/' + filename)) 148 | if img.shape[2] == 4: 149 | img = img[:, :, :3] 150 | im = img.transpose((2, 0, 1)) 151 | x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1. 152 | if x.shape[0] == 4: 153 | x = x[:3] 154 | factor = x.shape[2] // im_size 155 | if factor != 1: 156 | x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0] 157 | assert x.shape[2] == im_size 158 | latents = encode(x[None, ...].cuda()) 159 | f = decode(latents) 160 | r = torch.cat([x[None, ...].detach().cpu(), f.detach().cpu()], dim=3) 161 | canvas.append(r) 162 | return canvas 163 | 164 | def chunker_list(seq, n): 165 | return [seq[i * n:(i + 1) * n] for i in range((len(seq) + n - 1) // n)] 166 | 167 | paths = chunker_list(paths, 8 * 3) 168 | 169 | path = './make_figures/output' 170 | os.makedirs(path, exist_ok=True) 171 | os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True) 172 | 173 | for i, chunk in enumerate(paths): 174 | canvas = make(chunk) 175 | canvas = torch.cat(canvas, dim=0) 176 | 177 | save_path = './make_figures/output/%s/reconstructions_%d.png' % (cfg.NAME, i) 178 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 179 | save_image(canvas * 0.5 + 0.5, save_path, 180 | nrow=3, 181 | pad_value=1.0) 182 | 183 | 184 | if __name__ == "__main__": 185 | gpu_count = 1 186 | run(sample, get_cfg_defaults(), description='SoftIntroVAE-figure-reconstructions-paged', 187 | default_config='./configs/ffhq256.yaml', 188 | world_size=gpu_count, write_log=False) 189 | -------------------------------------------------------------------------------- /style_soft_intro_vae/make_figures/make_recon_figure_interpolation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2021 Tal Daniel 2 | # Copyright 2019-2020 Stanislav Pidhorskyi 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import torch.utils.data 18 | from torchvision.utils import save_image 19 | from net import * 20 | from model import SoftIntroVAEModelTL 21 | from launcher import run 22 | from checkpointer import Checkpointer 23 | from defaults import get_cfg_defaults 24 | import lreq 25 | from PIL import Image 26 | 27 | lreq.use_implicit_lreq.set(True) 28 | 29 | 30 | def millify(n): 31 | millnames = ['', 'k', 'M', 'G', 'T', 'P'] 32 | n = float(n) 33 | millidx = max(0, min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)))) 34 | 35 | return '{:.1f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx]) 36 | 37 | 38 | def count_parameters(model, print_func=print, verbose=False): 39 | for n, p in model.named_parameters(): 40 | if p.requires_grad and verbose: 41 | print_func(n, millify(p.numel())) 42 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 43 | 44 | 45 | def place(canvas, image, x, y): 46 | im_size = image.shape[2] 47 | if len(image.shape) == 4: 48 | image = image[0] 49 | canvas[:, y: y + im_size, x: x + im_size] = image * 0.5 + 0.5 50 | 51 | 52 | def save_sample(model, sample, i): 53 | os.makedirs('results', exist_ok=True) 54 | 55 | with torch.no_grad(): 56 | model.eval() 57 | x_rec = model.generate(model.generator.layer_count - 1, 1, z=sample) 58 | 59 | def save_pic(x_rec): 60 | resultsample = x_rec * 0.5 + 0.5 61 | resultsample = resultsample.cpu() 62 | save_image(resultsample, 63 | 'sample_%i_lr.png' % i, nrow=16) 64 | 65 | save_pic(x_rec) 66 | 67 | 68 | def sample(cfg, logger): 69 | torch.cuda.set_device(0) 70 | model = SoftIntroVAEModelTL( 71 | startf=cfg.MODEL.START_CHANNEL_COUNT, 72 | layer_count=cfg.MODEL.LAYER_COUNT, 73 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 74 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 75 | dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, 76 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 77 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 78 | channels=cfg.MODEL.CHANNELS, 79 | generator=cfg.MODEL.GENERATOR, 80 | encoder=cfg.MODEL.ENCODER, 81 | beta_kl=cfg.MODEL.BETA_KL, 82 | beta_rec=cfg.MODEL.BETA_REC, 83 | beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], 84 | scale=cfg.MODEL.SCALE 85 | ) 86 | model.cuda(0) 87 | model.eval() 88 | model.requires_grad_(False) 89 | 90 | decoder = model.decoder 91 | encoder = model.encoder 92 | mapping_tl = model.mapping_tl 93 | mapping_fl = model.mapping_fl 94 | dlatent_avg = model.dlatent_avg 95 | 96 | logger.info("Trainable parameters decoder:") 97 | print(count_parameters(decoder)) 98 | 99 | logger.info("Trainable parameters encoder:") 100 | print(count_parameters(encoder)) 101 | 102 | arguments = dict() 103 | arguments["iteration"] = 0 104 | 105 | model_dict = { 106 | 'discriminator_s': encoder, 107 | 'generator_s': decoder, 108 | 'mapping_tl_s': mapping_tl, 109 | 'mapping_fl_s': mapping_fl, 110 | 'dlatent_avg': dlatent_avg 111 | } 112 | 113 | checkpointer = Checkpointer(cfg, 114 | model_dict, 115 | {}, 116 | logger=logger, 117 | save=False) 118 | 119 | extra_checkpoint_data = checkpointer.load() 120 | 121 | model.eval() 122 | 123 | layer_count = cfg.MODEL.LAYER_COUNT 124 | 125 | def encode(x): 126 | z, mu, _ = model.encode(x, layer_count - 1, 1) 127 | styles = model.mapping_fl(mu) 128 | return styles 129 | 130 | def decode(x): 131 | return model.decoder(x, layer_count - 1, 1, noise=True) 132 | 133 | rnd = np.random.RandomState(4) 134 | 135 | path = cfg.DATASET.SAMPLES_PATH 136 | im_size = 2 ** 
(cfg.MODEL.LAYER_COUNT + 1) 137 | 138 | pathA = '00001.png' 139 | pathB = '00022.png' 140 | pathC = '00077.png' 141 | pathD = '00016.png' 142 | 143 | def open_image(filename): 144 | img = np.asarray(Image.open(path + '/' + filename)) 145 | if img.shape[2] == 4: 146 | img = img[:, :, :3] 147 | im = img.transpose((2, 0, 1)) 148 | x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1. 149 | if x.shape[0] == 4: 150 | x = x[:3] 151 | factor = x.shape[2] // im_size 152 | if factor != 1: 153 | x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0] 154 | assert x.shape[2] == im_size 155 | _latents = encode(x[None, ...].cuda()) 156 | latents = _latents[0, 0] 157 | return latents 158 | 159 | def make(w): 160 | with torch.no_grad(): 161 | w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) 162 | x_rec = decode(w) 163 | return x_rec 164 | 165 | wa = open_image(pathA) 166 | wb = open_image(pathB) 167 | wc = open_image(pathC) 168 | wd = open_image(pathD) 169 | 170 | height = 7 171 | width = 7 172 | 173 | images = [] 174 | 175 | for i in range(height): 176 | for j in range(width): 177 | kv = i / (height - 1.0) 178 | kh = j / (width - 1.0) 179 | 180 | ka = (1.0 - kh) * (1.0 - kv) 181 | kb = kh * (1.0 - kv) 182 | kc = (1.0 - kh) * kv 183 | kd = kh * kv 184 | 185 | w = ka * wa + kb * wb + kc * wc + kd * wd 186 | 187 | interpolated = make(w) 188 | images.append(interpolated) 189 | 190 | images = torch.cat(images) 191 | 192 | path = './make_figures/output' 193 | os.makedirs(path, exist_ok=True) 194 | os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True) 195 | 196 | save_image(images * 0.5 + 0.5, './make_figures/output/%s/interpolations.png' % cfg.NAME, nrow=width) 197 | save_image(images * 0.5 + 0.5, './make_figures/output/%s/interpolations.jpg' % cfg.NAME, nrow=width) 198 | 199 | 200 | if __name__ == "__main__": 201 | gpu_count = 1 202 | run(sample, get_cfg_defaults(), description='SoftIntroVAE-interpolations', default_config='./configs/ffhq256.yaml', 203 | world_size=gpu_count, write_log=False) 204 | -------------------------------------------------------------------------------- /style_soft_intro_vae/README.md: -------------------------------------------------------------------------------- 1 | # style-soft-intro-vae-pytorch 2 | 3 | Implementation of Style Soft-IntroVAE for image data. 4 | 5 | This codes builds upon the original Adversarial Latent Autoencoders (ALAE) implementation by Stanislav Pidhorskyi. 6 | Please see the [official repository](https://github.com/podgorskiy/ALAE) for a more detailed explanation of the files and how to get the datasets. 7 | The authors would like to thank Stanislav Pidhorskyi, Donald A. Adjeroh and Gianfranco Doretto for their great work which inspired this implementation. 8 | 9 |

17 | 18 | - [style-soft-intro-vae-pytorch](#style-soft-intro-vae-pytorch) 19 | * [Requirements](#requirements) 20 | * [Training](#training) 21 | * [Datasets](#datasets) 22 | * [Pretrained models](#pretrained-models) 23 | * [Recommended hyperparameters](#recommended-hyperparameters) 24 | * [What to expect](#what-to-expect) 25 | * [Files and directories in the repository](#files-and-directories-in-the-repository) 26 | * [Credits](#credits) 27 | 28 | ## Requirements 29 | 30 | * Please see ALAE's repository for explanation of the requirements and how to get them. 31 | * We provide an `environment.yml` file for `conda`, which installs all that is needed to run the files, in an environment called `tf_torch`. 32 | * `conda env create -f environment.yml` 33 | * The required packages are located in the `requirements.txt` file. 34 | * `pip install -r requirements.txt` 35 | * If you installed the environment using the `environment.yml` file, there is no need to use the `requirements.txt` file. 36 | * As in the original ALAE repository, the code is organized in such a way that all scripts must be run from the root of the repository. 37 | * If you use an IDE (e.g. PyCharm or Visual Studio Code), just set Working Directory to point to the root of the repository. 38 | * If you want to run from the command line, then you also need to set PYTHONPATH variable to point to the root of the repository. 39 | * Run `$ export PYTHONPATH=$PYTHONPATH:$(pwd)` in the root directory. 40 | 41 | ## Training 42 | 43 | * This implementation uses the [DareBlopy](https://github.com/podgorskiy/DareBlopy) package to load the data. 44 | * `pip install dareblopy` (in addition to the `pip install -r requirements.txt`) 45 | * TL;DR: read TFRecords files in PyTorch, for better utilization of the data loading, train faster. 46 | 47 | To run training: 48 | 49 | `python train_style_soft_intro_vae.py -c ` 50 | 51 | * Configs are located in the `configs` directory, edit them to change the hyperparameters. 52 | * It will run multi-GPU training on all available GPUs. It uses DistributedDataParallel for parallelism. If only one GPU available, it will run on single GPU, no special care is needed. 53 | * To modify the visible GPUs, edit the line: `os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"` in the `train_style_soft_intro_vae.py` file. 54 | 55 | Examples: 56 | 57 | `python train_style_soft_intro_vae.py` (FFHQ) 58 | 59 | `python train_style_soft_intro_vae.py -c ./configs/celeba-hq256` 60 | 61 | 62 | ## Datasets 63 | * CelebHQ: please follow [ALAE](https://github.com/podgorskiy/ALAE#datasets) instructions. 64 | * FFHQ: please see the following repository [Flickr-Faces-HQ Dataset (FFHQ)](https://github.com/NVlabs/ffhq-dataset). 65 | 66 | ## Pretrained models 67 | |Dataset | Filename | Where to Put| Links| 68 | |------------|------|----|---| 69 | |CelebA-HQ (256x256)|`celebahq_fid_18.63_epoch_230.pth` |`training_artifacts/celeba-hq256`|[MEGA.co.nz](https://mega.nz/file/sJkS2BAC#aGFwJIPvOTIP147GwBGHOJgwRMC_NYKHT_QK7abb0VE), [Mediafire](https://www.mediafire.com/file/fgf0a85z5d0jtu5/celebahq_fid_18.63_epoch_230.pth/file) | 70 | |FFHQ (256x256)|`ffhq_fid_17.55_epoch_270.pth` |`training_artifacts/ffhq` | [MEGA.co.nz](https://mega.nz/file/YJ1SkBwI#9t0ZEZTC0WWG0NUsJg2OwZujOuUXKn_ehP6fba1pV7o), [Mediafire](https://www.mediafire.com/file/x6jkyg4rlkqc4hl/ffhq_fid_17.55_epoch_270.pth/file) | 71 | 72 | * In config files, `OUTPUT_DIR` points to where weights are saved to and read from. For example: `OUTPUT_DIR: training_artifacts/celeba-hq256`. 
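* A config value such as `OUTPUT_DIR` can usually also be overridden from the command line instead of editing the YAML, assuming the training entry point forwards trailing command-line arguments into the yacs config via `merge_from_list`, as `dataset_preparation/split_tfrecords_ffhq.py` does. Illustrative example only: `python train_style_soft_intro_vae.py -c ./configs/celeba-hq256.yaml OUTPUT_DIR training_artifacts/celeba-hq256`.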
73 | 74 | * In `OUTPUT_DIR` it saves a file `last_checkpoint` which contains the path to the actual `.pth` pickle with the model weights. If you want to test the model with a specific weight file, you can simply modify the `last_checkpoint` file. 75 | 76 | ## Recommended hyperparameters 77 | 78 | |Dataset | `beta_kl` | `beta_rec`| `beta_neg`|`z_dim`| 79 | |------------|------|----|---|----| 80 | |FFHQ (256x256)|0.2|0.05| 512|512| 81 | |CelebA-HQ (256x256)|0.2|0.1| 512|512| 82 | 83 | 84 | ## What to expect 85 | 86 | * During training, figures of samples and reconstructions are saved locally. 87 | * During training, statistics are printed (reconstruction error, KLD, expELBO). 88 | * In the final resolution stage (256x256), FID will be calculated every 10 epochs. 89 | * Tips: 90 | * KL of fake/rec samples should be > KL of real data (by a fair margin). 91 | * It is usually better to choose `beta_kl` >= `beta_rec`. 92 | * We stick to ALAE's original architecture hyperparameters, and mostly didn't change their configs. 93 | 94 | 95 | ## Files and directories in the repository 96 | 97 | * For a full description, please see ALAE's repository. 98 | 99 | |File name | Purpose | 100 | |----------------------|------| 101 | |`train_style_soft_intro_vae.py`| main training function| 102 | |`checkpointer.py`| module for saving/restoring model weights, optimizer state and loss history.| 103 | |`custom_adam.py`| customized Adam optimizer for learning rate equalization and zero second beta.| 104 | |`dataloader.py`| module with dataset classes, loaders, iterators, etc.| 105 | |`defaults.py`| definition for config variables with default values.| 106 | |`launcher.py`| helper for running multi-GPU, multiprocess training. Sets up config and logging.| 107 | |`lod_driver.py`| helper class for managing the growing/stabilizing network.| 108 | |`lreq.py`| custom `Linear`, `Conv2d` and `ConvTranspose2d` modules for learning rate equalization.| 109 | |`model.py`| module with high-level model definition.| 110 | |`net.py`| definition of all network blocks for multiple architectures.| 111 | |`registry.py`| registry of network blocks for selecting from the config file.| 112 | |`scheduler.py`| custom schedulers with warm start and aggregating several optimizers.| 113 | |`tracker.py`| module for plotting losses.| 114 | |`utils.py`| decorator for async call, decorator for caching, registry for network blocks.| 115 | |`configs/celeba-hq256.yaml`, `configs/ffhq256.yaml`| config files for the CelebA-HQ and FFHQ datasets at 256x256 resolution.| 116 | |`dataset_preparation/`| folder with scripts for dataset preparation (creating and splitting TFRecords files).| 117 | |`make_figures/`| scripts for making various figures.| 118 | |`metrics/fid_score.py`, `metrics/inception.py`| functions for FID calculation from datasets, using the pretrained Inception network| 119 | |`training_artifacts/`| default folder for saving checkpoints/sample outputs/plots.| 120 | 121 | ## Credits 122 | * Adversarial Latent Autoencoders, Pidhorskyi et al., CVPR 2020 - [Code](https://github.com/podgorskiy/ALAE), [Paper](https://arxiv.org/abs/2004.04467). 123 | 124 | -------------------------------------------------------------------------------- /soft_intro_vae_3d/render/render_mitsuba2_pc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Render Point Clouds with Mitsuba Renderer.
3 | Adpated from: https://github.com/tolgabirdal/Mitsuba2PointCloudRenderer 4 | """ 5 | 6 | import numpy as np 7 | import sys, os, subprocess 8 | import OpenEXR 9 | import Imath 10 | from PIL import Image 11 | from plyfile import PlyData, PlyElement 12 | 13 | # mitsuba exectuable - EDIT WITH YOUR OWN INSTALLATION PATH 14 | PATH_TO_MITSUBA2 = "/home/tal/mitsuba2/mitsuba2/build/dist/mitsuba" 15 | 16 | # replaced by command line arguments 17 | # PATH_TO_NPY = 'pcl_ex.npy' # the tensor to load 18 | 19 | # note that sampler is changed to 'independent' and the ldrfilm is changed to hdrfilm 20 | xml_head = \ 21 | """ 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | """ 51 | 52 | # I also use a smaller point size 53 | xml_ball_segment = \ 54 | """ 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | """ 65 | 66 | xml_tail = \ 67 | """ 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | """ 87 | 88 | 89 | def colormap(x, y, z): 90 | vec = np.array([x, y, z]) 91 | vec = np.clip(vec, 0.001, 1.0) 92 | norm = np.sqrt(np.sum(vec ** 2)) 93 | vec /= norm 94 | return [vec[0], vec[1], vec[2]] 95 | 96 | 97 | def standardize_bbox(pcl, points_per_object): 98 | pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False) 99 | np.random.shuffle(pt_indices) 100 | pcl = pcl[pt_indices] # n by 3 101 | mins = np.amin(pcl, axis=0) 102 | maxs = np.amax(pcl, axis=0) 103 | center = (mins + maxs) / 2. 104 | scale = np.amax(maxs - mins) 105 | print("Center: {}, Scale: {}".format(center, scale)) 106 | result = ((pcl - center) / scale).astype(np.float32) # [-0.5, 0.5] 107 | return result 108 | 109 | 110 | # only for debugging reasons 111 | def writeply(vertices, ply_file): 112 | sv = np.shape(vertices) 113 | points = [] 114 | for v in range(sv[0]): 115 | vertex = vertices[v] 116 | points.append("%f %f %f\n" % (vertex[0], vertex[1], vertex[2])) 117 | print(np.shape(points)) 118 | file = open(ply_file, "w") 119 | file.write('''ply 120 | format ascii 1.0 121 | element vertex %d 122 | property float x 123 | property float y 124 | property float z 125 | end_header 126 | %s 127 | ''' % (len(vertices), "".join(points))) 128 | file.close() 129 | 130 | 131 | # as done in https://gist.github.com/drakeguan/6303065 132 | def ConvertEXRToJPG(exrfile, jpgfile): 133 | File = OpenEXR.InputFile(exrfile) 134 | PixType = Imath.PixelType(Imath.PixelType.FLOAT) 135 | DW = File.header()['dataWindow'] 136 | Size = (DW.max.x - DW.min.x + 1, DW.max.y - DW.min.y + 1) 137 | 138 | rgb = [np.fromstring(File.channel(c, PixType), dtype=np.float32) for c in 'RGB'] 139 | for i in range(3): 140 | rgb[i] = np.where(rgb[i] <= 0.0031308, 141 | (rgb[i] * 12.92) * 255.0, 142 | (1.055 * (np.sign(rgb[i]) * np.abs(rgb[i]) ** (1.0 / 2.4)) - 0.055) * 255.0) 143 | 144 | rgb8 = [Image.frombytes("F", Size, c.tostring()).convert("L") for c in rgb] 145 | # rgb8 = [Image.fromarray(c.astype(int)) for c in rgb] 146 | Image.merge("RGB", rgb8).save(jpgfile, "JPEG", quality=95) 147 | 148 | 149 | def main(argv): 150 | if len(argv) < 2: 151 | print('filename to npy/ply is not passed as argument. 
terminated.') 152 | return 153 | 154 | pathToFile = argv[1] 155 | 156 | filename, file_extension = os.path.splitext(pathToFile) 157 | folder = os.path.dirname(pathToFile) 158 | print(folder) 159 | filename = os.path.basename(pathToFile) 160 | 161 | # for the moment supports npy and ply 162 | if file_extension == '.npy': 163 | pclTime = np.load(pathToFile) 164 | pclTimeSize = np.shape(pclTime) 165 | elif file_extension == '.npz': 166 | pclTime = np.load(pathToFile) 167 | pclTime = pclTime['pred'] 168 | pclTimeSize = np.shape(pclTime) 169 | elif file_extension == '.ply': 170 | ply = PlyData.read(pathToFile) 171 | vertex = ply['vertex'] 172 | (x, y, z) = (vertex[t] for t in ('x', 'y', 'z')) 173 | pclTime = np.column_stack((x, y, z)) 174 | else: 175 | print('unsupported file format.') 176 | return 177 | 178 | if len(np.shape(pclTime)) < 3: 179 | pclTimeSize = [1, np.shape(pclTime)[0], np.shape(pclTime)[1]] 180 | pclTime.resize(pclTimeSize) 181 | 182 | for pcli in range(0, pclTimeSize[0]): 183 | pcl = pclTime[pcli, :, :] 184 | 185 | pcl = standardize_bbox(pcl, 2048) 186 | # pcl = pcl[:, [2, 0, 1]] 187 | # pcl[:, 0] *= -1 188 | # pcl[:, 2] += 0.0125 189 | pcl[:, 1] += 0.0125 190 | 191 | xml_segments = [xml_head] 192 | for i in range(pcl.shape[0]): 193 | color = colormap(pcl[i, 0] + 0.5, pcl[i, 1] + 0.5, pcl[i, 2] + 0.5 - 0.0125) 194 | xml_segments.append(xml_ball_segment.format(pcl[i, 0], pcl[i, 1], pcl[i, 2], *color)) 195 | xml_segments.append(xml_tail) 196 | 197 | xml_content = str.join('', xml_segments) 198 | 199 | xmlFile = ("%s/%s_%02d.xml" % (folder, filename, pcli)) 200 | 201 | with open(xmlFile, 'w') as f: 202 | f.write(xml_content) 203 | f.close() 204 | 205 | exrFile = ("%s/%s_%02d.exr" % (folder, filename, pcli)) 206 | if not os.path.exists(exrFile): 207 | print(['Running Mitsuba, writing to: ', xmlFile]) 208 | subprocess.run([PATH_TO_MITSUBA2, xmlFile]) 209 | else: 210 | print('skipping rendering because the EXR file already exists') 211 | 212 | png = ("%s/%s_%02d.jpg" % (folder, filename, pcli)) 213 | 214 | print(['Converting EXR to JPG...']) 215 | ConvertEXRToJPG(exrFile, png) 216 | 217 | 218 | if __name__ == "__main__": 219 | main(sys.argv) 220 | --------------------------------------------------------------------------------
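# Usage sketch for render_mitsuba2_pc.py (an illustration based on main() above, not a
# documented CLI): first edit PATH_TO_MITSUBA2 at the top of the script to point to a
# local Mitsuba 2 build, then pass a single .npy/.npz/.ply file as the only argument, e.g.
#     python soft_intro_vae_3d/render/render_mitsuba2_pc.py path/to/point_cloud.npy
# For each point cloud it writes an .xml scene next to the input, renders it to .exr with
# Mitsuba, and converts the result to a .jpg (for .npz inputs, the array is read from the
# 'pred' key).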