├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   └── celeba.py
├── infer.py
├── misc
│   ├── lr_scheduler.py
│   ├── ops.py
│   ├── test.png
│   ├── test_celeba.png
│   └── util.py
├── network
│   ├── __init__.py
│   ├── builder.py
│   ├── inferer.py
│   ├── model.py
│   ├── module.py
│   └── trainer.py
├── profile
│   ├── celeba.json
│   └── test.json
├── requirements.txt
├── result
│   ├── interpolated_Attractive.png
│   ├── interpolated_Black_Hair.png
│   ├── interpolated_Blurry.png
│   ├── interpolated_Mouth_Slightly_Open.png
│   └── reconstructed.png
├── setting.py
├── test
│   ├── __init__.py
│   ├── test_dataset.py
│   ├── test_model.py
│   ├── test_module.py
│   ├── test_ops.py
│   └── test_util.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .idea
106 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yusu Pan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-glow
2 | PyTorch implementation of ["Glow: Generative Flow with Invertible 1x1 Convolutions"](https://arxiv.org/abs/1807.03039)
3 |
4 | ## Usage
5 | First, install all requirements:
6 | ```shell
7 | pip3 install -r requirements.txt
8 | ```
9 |
10 | ### Training
11 | 1. Prepare the dataset and a corresponding profile file (like `profile/celeba.json`)
12 | 2. Start training:
13 | ```shell
14 | python3 train.py [profile]
15 | ```
16 |
17 | ### Inference
18 | ```
19 | Usage: python3 infer.py [OPTIONS] COMMAND [ARGS]...
20 |
21 | Options:
22 |   --profile PATH
23 |   --snapshot PATH
24 |   --help           Show this message and exit.
25 |
26 | Commands:
27 |   compute_deltaz
28 |   interpolate
29 |   reconstruct
30 |   sample
31 | ```
32 |
33 | ## Result
34 | ### Reconstruction
35 | The upper row shows the reconstructed images; the lower row shows the originals.
36 | ![reconstructed result](result/reconstructed.png)
37 |
38 | ### Interpolation
39 | From left to right, the attribute offset ranges from `-1.0` to `1.0` in steps of `0.25`.
40 |
41 | - Attractive
42 |
43 | ![interpolation_attractive](result/interpolated_Attractive.png)
44 |
45 | - Black hair
46 |
47 | ![interpolation_black_hair](result/interpolated_Black_Hair.png)
48 |
49 | - Blurry
50 |
51 | ![interpolation_blurry](result/interpolated_Blurry.png)
52 |
53 | - Mouth slightly open
54 |
55 | ![interpolation_mouth_slightly_open](result/interpolated_Mouth_Slightly_Open.png)
56 |
57 | ## Acknowledgement
58 | This implementation references:
59 | - [openai/glow](https://github.com/openai/glow) (official implementation)
60 | - [chaiyujin/glow-pytorch](https://github.com/chaiyujin/glow-pytorch)
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .celeba import CelebA
2 |
3 | __all__ = (
4 |     'CelebA',
5 | )
6 |
--------------------------------------------------------------------------------
/dataset/celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 | from misc import util
8 |
9 |
10 | class CelebA(Dataset):
11 |     def __init__(self, root,
12 |                  image_dir='images',
13 |                  anno_file='list_attr_celeba.txt',
14 |                  transform=None):
15 |         """
16 |         CelebA dataset
17 |
18 |         :param root: path to dataset root
19 |         :type root: str
20 |         :param image_dir: name of the image directory, relative to root
21 |         :type image_dir: str
22 |         :param anno_file: filename of annotation file
23 |         :type anno_file: str
24 |         :param transform: desired transformation for image
25 |         """
26 |         super().__init__()
27 |         assert os.path.isdir(root), 'Dataset directory does not exist: {}'.format(root)
28 |         self.root = root
29 |         self.image_dir = image_dir
30 |         self.anno_file = anno_file
31 |         self.transform = transform
32 |
33 |         self.data, self.attrs = self.parse_anno_file()
34 |
35 |     def parse_anno_file(self):
36 |         """
37 |         Parse annotation file
38 |
39 |         :return: image data and attributes
40 |         :rtype: tuple(list, list)
41 |         """
42 |         if os.path.exists(self.anno_file):
43 |             anno_path =
self.anno_file
44 |         elif os.path.exists(os.path.join(self.root, self.anno_file)):
45 |             anno_path = os.path.join(self.root, self.anno_file)
46 |         else:
47 |             raise FileNotFoundError('Annotation file of dataset does not exist: {}'.format(self.anno_file))
48 |
49 |         data = []
50 |         attrs = None
51 |         num_images = 0
52 |         with open(anno_path, 'r') as f:
53 |             for idx, line in enumerate(f):
54 |                 line = line.strip()
55 |                 if idx == 0:
56 |                     num_images = int(line)
57 |                 elif idx == 1:
58 |                     attrs = line.split()
59 |                 else:
60 |                     elements = [e for e in line.split(' ') if e]
61 |                     image_path = os.path.join(self.root, self.image_dir, elements[0])
62 |                     image_attr = elements[1:]
63 |                     if not os.path.exists(image_path) or not util.is_image(image_path):
64 |                         continue
65 |                     # map annotation labels -1/1 to 0/1
66 |                     image_onehot = [int(int(attr) > 0) for attr in image_attr]
67 |                     data.append({
68 |                         'path': image_path,
69 |                         'attr': image_onehot
70 |                     })
71 |         print('[Dataset] CelebA: Expected {} images with {} attributes.'.format(num_images, len(attrs)))
72 |         print('[Dataset] CelebA: Found {} images with {} attributes.'.format(len(data), len(data[-1]['attr'])))
73 |
74 |         return data, attrs
75 |
76 |     def __getitem__(self, index):
77 |         data = self.data[index]
78 |         image_path = data['path']
79 |         image_attr = data['attr']
80 |
81 |         image = Image.open(image_path).convert('RGB')
82 |         if self.transform is not None:
83 |             image = self.transform(image)
84 |
85 |         return {
86 |             'x': image,
87 |             'y_onehot': np.array(image_attr, dtype='float32')
88 |         }
89 |
90 |     def __len__(self):
91 |         return len(self.data)
92 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import click
4 | import signal
5 | import torch
6 |
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from torchvision import transforms
10 | from torchvision.utils import make_grid
11 |
12 | from misc import util
13 | from network import Builder, Inferer
14 | from dataset import CelebA
15 |
16 |
17 | @click.group(name='Inference for glow model')
18 | @click.option('--profile', type=click.Path(exists=True))
19 | @click.option('--snapshot', type=click.Path(exists=True))
20 | @click.pass_context
21 | def cli(ctx, profile, snapshot):
22 |     # load hyper-parameters
23 |     hps = util.load_profile(profile)
24 |     util.manual_seed(hps.ablation.seed)
25 |     if snapshot is not None:
26 |         hps.general.warm_start = True
27 |         hps.general.pre_trained = snapshot
28 |
29 |     # build graph
30 |     builder = Builder(hps)
31 |     state = builder.build(training=False)
32 |
33 |     # load dataset
34 |     dataset = CelebA(root=hps.dataset.root,
35 |                      transform=transforms.Compose((
36 |                          transforms.CenterCrop(160),
37 |                          transforms.Resize(64),
38 |                          transforms.ToTensor()
39 |                      )))
40 |
41 |     # start inference
42 |     inferer = Inferer(
43 |         hps=hps,
44 |         graph=state['graph'],
45 |         devices=state['devices'],
46 |         data_device=state['data_device']
47 |     )
48 |     ctx.obj['hps'] = hps
49 |     ctx.obj['dataset'] = dataset
50 |     ctx.obj['inferer'] = inferer
51 |
52 |
53 | @cli.command()
54 | @click.pass_context
55 | def sample(ctx):
56 |     hps = ctx.obj['hps']
57 |     inferer = ctx.obj['inferer']
58 |     # sample
59 |     img = inferer.sample(z=None, y_onehot=None, eps_std=0.5)
60 |     # save result
61 |     result_subdir = util.create_result_subdir(hps.general.result_dir,
62 |                                               desc='sample',
63 |                                               profile=hps)
64 |     util.tensor_to_pil(img).save(os.path.join(result_subdir, 'sample.png'))
65 |
66 |
67 | @cli.command()
68 | @click.pass_context
69 | def compute_deltaz(ctx):
70 |     hps = ctx.obj['hps']
71 |     inferer = ctx.obj['inferer']
72 |     dataset = ctx.obj['dataset']
73 |
74 |     # compute delta
75 |     deltaz = inferer.compute_attribute_delta(dataset)
76 |
77 |     # save result
78 |     result_subdir = util.create_result_subdir(hps.general.result_dir,
79 |                                               desc='deltaz',
80 |                                               profile=hps)
81 |     util.save_deltaz(deltaz, result_subdir)
82 |
83 |
84 | @cli.command()
85 | @click.argument('image_path', type=click.Path(exists=True))
86 | @click.pass_context
87 | def reconstruct(ctx, image_path):
88 |     hps = ctx.obj['hps']
89 |     inferer = ctx.obj['inferer']
90 |
91 |     # get image list
92 |     img_list = []
93 |     if os.path.isfile(image_path) and util.is_image(image_path):
94 |         img_list = [image_path]
95 |     elif os.path.isdir(image_path):
96 |         img_list = [os.path.join(image_path, f)
97 |                     for f in os.listdir(image_path)
98 |                     if util.is_image(os.path.join(image_path, f))]
99 |
100 |     # reconstruct images
101 |     img_grid_list = []
102 |     # util.check_path('reconstructed')
103 |     for img_path in img_list:
104 |         img = Image.open(img_path).convert('RGB')
105 |         x = util.pil_to_tensor(img,
106 |                                transform=transforms.Compose((
107 |                                    transforms.CenterCrop(160),
108 |                                    transforms.Resize(64),
109 |                                    transforms.ToTensor()
110 |                                )))
111 |         z = inferer.encode(x)
112 |         x_ = inferer.decode(z)
113 |         img_grid = torch.cat((x, x_.cpu()), dim=1)
114 |         img_grid_list.append(img_grid)
115 |         # util.tensor_to_pil(img_grid).save('reconstructed/{}'.format(os.path.basename(img_path)))
116 |
117 |     # generate grid of reconstructed images
118 |     imgs_grid = make_grid(torch.stack(img_grid_list))
119 |
120 |     # save result
121 |     result_subdir = util.create_result_subdir(hps.general.result_dir,
122 |                                               desc='reconstruct',
123 |                                               profile=hps)
124 |     util.tensor_to_pil(imgs_grid).save(os.path.join(result_subdir, 'grid.png'))
125 |
126 |
127 | @cli.command()
128 | @click.argument('delta_file', type=click.Path(exists=True))
129 | @click.argument('image_file', type=click.Path(exists=True, dir_okay=False))
130 | @click.option('--batch/--no-batch', default=True)
131 | @click.pass_context
132 | def interpolate(ctx, delta_file, image_file, batch):
133 |     hps = ctx.obj['hps']
134 |     inferer = ctx.obj['inferer']
135 |     dataset = ctx.obj['dataset']
136 |
137 |     img = Image.open(image_file).convert('RGB')
138 |     deltaz = util.load_deltaz(delta_file)
139 |     result_subdir = util.create_result_subdir(hps.general.result_dir,
140 |                                               desc='interpolation',
141 |                                               profile=hps)
142 |
143 |     if batch:
144 |         interpolation_vector = util.make_interpolation_vector(hps.dataset.num_classes)
145 |         for cls in range(interpolation_vector.shape[0]):
146 |             print('[Inferer] interpolating class "{}"'.format(dataset.attrs[cls]))
147 |             imgs_interpolated = []
148 |             progress = tqdm(range(interpolation_vector.shape[1]))
149 |             for lv in progress:
150 |                 img_interpolated = inferer.apply_attribute_delta(
151 |                     img, deltaz,
152 |                     interpolation_vector[cls, lv, :])
153 |                 imgs_interpolated.append(img_interpolated)
154 |                 # img_interpolated = util.tensor_to_pil(img_interpolated)
155 |                 # img_interpolated.save('interpolation/interpolated_{:s}_{:0.2f}.png'.format(
156 |                 #     dataset.attrs[cls],
157 |                 #     interpolation_vector[cls, lv, cls]))
158 |             imgs_stacked = torch.stack(imgs_interpolated)
159 |             imgs_grid = make_grid(imgs_stacked, nrow=interpolation_vector.shape[1])
160 |             imgs = util.tensor_to_pil(imgs_grid)
161 |             imgs.save(os.path.join(result_subdir,
162 |                                    'interpolated_{:s}.png'.format(dataset.attrs[cls])))
163 |     else:
164 |         interpolation = [0.]
* hps.dataset.num_classes 165 | interpolation[0] = 1. 166 | img_interpolated = inferer.apply_attribute_delta(img, deltaz, interpolation) 167 | img_interpolated = util.tensor_to_pil(img_interpolated) 168 | img_interpolated.save(os.path.join(result_subdir, 169 | 'interpolated.png')) 170 | 171 | 172 | if __name__ == '__main__': 173 | # this enables a Ctrl-C without triggering errors 174 | signal.signal(signal.SIGINT, lambda x, y: sys.exit(0)) 175 | 176 | # initialize logging 177 | util.init_output_logging() 178 | 179 | # command group 180 | cli(obj={}) 181 | -------------------------------------------------------------------------------- /misc/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def constant(base_lr, global_step): 5 | """ 6 | Return constant learning rate 7 | 8 | :param base_lr: base learning rate 9 | :type base_lr: float 10 | :param global_step: global training steps 11 | :type global_step: int 12 | :return: scheduled learning rate 13 | :rtype: float 14 | """ 15 | return base_lr 16 | 17 | 18 | def noam_decay(base_lr, global_step, warmup_steps=4000, min_lr=1e-4): 19 | """ 20 | Noam learning rate decay (from section 5.3 of Attention is all you need) 21 | 22 | :param base_lr: base learning rate 23 | :type base_lr: float 24 | :param global_step: global training steps 25 | :type global_step: int 26 | :param warmup_steps: number of steps for warming up 27 | :type warmup_steps: int 28 | :param min_lr: minimum learning rate 29 | :type min_lr: float 30 | :return: scheduled learning rate 31 | :rtype: float 32 | """ 33 | step_num = global_step + 1. 34 | lr = base_lr * warmup_steps ** 0.5 * np.minimum(step_num ** -0.5, step_num * float(warmup_steps) ** -1.5) 35 | if global_step >= warmup_steps: 36 | lr = max(min_lr, lr) 37 | return lr 38 | 39 | 40 | def linear_anneal(base_lr, global_step, num_train, warmup_steps=10): 41 | """ 42 | Linearly annealed learning rate from 0 in the first warming up epochs. 
43 |
44 |     :param base_lr: base learning rate
45 |     :type base_lr: float
46 |     :param global_step: global training steps
47 |     :type global_step: int
48 |     :param num_train: number of training steps (batches) per epoch
49 |     :type num_train: int
50 |     :param warmup_steps: number of warm-up epochs
51 |     :type warmup_steps: int
52 |     :return: scheduled learning rate
53 |     :rtype: float
54 |     """
55 |     lr = base_lr * np.minimum(1., global_step / (num_train * warmup_steps))
56 |     return lr
57 |
58 |
59 | def step_anneal(base_lr, global_step,
60 |                 anneal_rate=0.98,
61 |                 anneal_interval=30000):
62 |     """
63 |     Annealing learning rate by steps
64 |
65 |     :param base_lr: base learning rate
66 |     :type base_lr: float
67 |     :param global_step: global training steps
68 |     :type global_step: int
69 |     :param anneal_rate: rate of annealing
70 |     :type anneal_rate: float
71 |     :param anneal_interval: interval steps of annealing
72 |     :type anneal_interval: int
73 |     :return: scheduled learning rate
74 |     :rtype: float
75 |     """
76 |
77 |     lr = base_lr * anneal_rate ** (global_step // anneal_interval)
78 |     return lr
79 |
80 |
81 | def cyclic_cosine_anneal(base_lr, global_step, t, m):
82 |     """
83 |     Cyclic cosine annealing (from section 3 of "Snapshot Ensembles: Train 1, Get M for Free")
84 |
85 |     :param base_lr: base learning rate
86 |     :type base_lr: float
87 |     :param global_step: global training steps
88 |     :type global_step: int
89 |     :param t: total number of training steps
90 |     :type t: int
91 |     :param m: number of snapshots (annealing cycles) we want
92 |     :type m: int
93 |     :return: scheduled learning rate
94 |     :rtype: float
95 |     """
96 |     lr = (base_lr / 2.) * (np.cos(np.pi * ((global_step - 1) % (t // m)) / (t // m)) + 1.)
97 |     return lr
98 |
--------------------------------------------------------------------------------
/misc/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def reduce_mean(tensor, dim=None, keepdim=False, out=None):
5 |     """
6 |     Returns the mean value of each row of the input tensor in the given dimension dim.
7 |
8 |     Support multi-dim mean
9 |
10 |     :param tensor: the input tensor
11 |     :type tensor: torch.Tensor
12 |     :param dim: the dimension to reduce
13 |     :type dim: int or list[int]
14 |     :param keepdim: whether the output tensor has dim retained or not
15 |     :type keepdim: bool
16 |     :param out: the output tensor
17 |     :type out: torch.Tensor
18 |     :return: mean result
19 |     :rtype: torch.Tensor
20 |     """
21 |     # mean all dims
22 |     if dim is None:
23 |         return torch.mean(tensor)
24 |     # prepare dim
25 |     if isinstance(dim, int):
26 |         dim = [dim]
27 |     dim = sorted(dim)
28 |     # get mean dim by dim
29 |     for d in dim:
30 |         tensor = tensor.mean(dim=d, keepdim=True)
31 |     # squeeze reduced dimensions if not keeping dim
32 |     if not keepdim:
33 |         for cnt, d in enumerate(dim):
34 |             tensor.squeeze_(d - cnt)
35 |     if out is not None:
36 |         out.copy_(tensor)
37 |     return tensor
38 |
39 |
40 | def reduce_sum(tensor, dim=None, keepdim=False, out=None):
41 |     """
42 |     Returns the sum of all elements in the input tensor.
43 | 44 | Support multi-dim sum 45 | 46 | :param tensor: the input tensor 47 | :type tensor: torch.Tensor 48 | :param dim: the dimension to reduce 49 | :type dim: int or list[int] 50 | :param keepdim: whether the output tensor has dim retained or not 51 | :type keepdim: bool 52 | :param out: the output tensor 53 | :type out: torch.Tensor 54 | :return: sum result 55 | :rtype: torch.Tensor 56 | """ 57 | # summarize all dims 58 | if dim is None: 59 | return torch.sum(tensor) 60 | # prepare dim 61 | if isinstance(dim, int): 62 | dim = [dim] 63 | dim = sorted(dim) 64 | # get summary dim by dim 65 | for d in dim: 66 | tensor = tensor.sum(dim=d, keepdim=True) 67 | # squeeze reduced dimensions if not keeping dim 68 | if not keepdim: 69 | for cnt, d in enumerate(dim): 70 | tensor.squeeze_(d - cnt) 71 | if out is not None: 72 | out.copy_(tensor) 73 | return tensor 74 | 75 | 76 | def tensor_equal(a, b, eps=1e-6): 77 | """ 78 | Compare two tensors 79 | 80 | :param a: input tensor a 81 | :type a: torch.Tensor 82 | :param b: input tensor b 83 | :type b: torch.Tensor 84 | :param eps: epsilon 85 | :type eps: float 86 | :return: whether two tensors are equal 87 | :rtype: bool 88 | """ 89 | if a.shape != b.shape: 90 | return False 91 | 92 | return 0 <= float(torch.max(torch.abs(a - b))) <= eps 93 | 94 | 95 | def split_channel(tensor, split_type='simple'): 96 | """ 97 | Split channels of tensor 98 | 99 | :param tensor: input tensor 100 | :type tensor: torch.Tensor 101 | :param split_type: type of splitting 102 | :type split_type: str 103 | :return: split tensor 104 | :rtype: tuple(torch.Tensor, torch.Tensor) 105 | """ 106 | assert len(tensor.shape) == 4 107 | assert split_type in ['simple', 'cross'] 108 | 109 | nc = tensor.shape[1] 110 | if split_type == 'simple': 111 | return tensor[:, :nc // 2, ...], tensor[:, nc // 2:, ...] 112 | elif split_type == 'cross': 113 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
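# (Editor's sketch, not part of the original file: both reducers accept a
# single dim or a list of dims, and split_channel halves the channel axis of
# an NCHW tensor. Expected results, assuming the functions above:
#
#     t = torch.ones(2, 4, 8, 8)
#     reduce_sum(t, dim=[1, 2, 3])                    # tensor([256., 256.])
#     reduce_mean(t, dim=[2, 3], keepdim=True).shape  # torch.Size([2, 4, 1, 1])
#     a, b = split_channel(t, 'simple')               # first/second half: (2, 2, 8, 8) each
#     a, b = split_channel(t, 'cross')                # even / odd channels: (2, 2, 8, 8) each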
114 |
115 |
116 | def cat_channel(a, b):
117 |     """
118 |     Concatenates channels of tensors
119 |
120 |     :param a: input tensor a
121 |     :type a: torch.Tensor
122 |     :param b: input tensor b
123 |     :type b: torch.Tensor
124 |     :return: concatenated tensor
125 |     :rtype: torch.Tensor
126 |     """
127 |     return torch.cat((a, b), dim=1)
128 |
129 |
130 | def count_pixels(tensor):
131 |     """
132 |     Count number of pixels in given tensor
133 |
134 |     :param tensor: input tensor
135 |     :type tensor: torch.Tensor
136 |     :return: number of pixels
137 |     :rtype: int
138 |     """
139 |     assert len(tensor.shape) == 4
140 |     return int(tensor.shape[2] * tensor.shape[3])
141 |
142 |
143 | def onehot(y, num_classes):
144 |     """
145 |     Generate one-hot vector
146 |
147 |     :param y: ground truth labels
148 |     :type y: torch.Tensor
149 |     :param num_classes: number of classes
150 |     :type num_classes: int
151 |     :return: one-hot vector generated from labels
152 |     :rtype: torch.Tensor
153 |     """
154 |     assert len(y.shape) in [1, 2], "Label y should be 1D or 2D vector"
155 |     y_onehot = torch.zeros(y.shape[0], num_classes, device=y.device)
156 |     if len(y.shape) == 1:
157 |         y_onehot = y_onehot.scatter_(1, y.unsqueeze(-1), 1)
158 |     else:
159 |         y_onehot = y_onehot.scatter_(1, y, 1)
160 |     return y_onehot
161 |
--------------------------------------------------------------------------------
/misc/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/misc/test.png
--------------------------------------------------------------------------------
/misc/test_celeba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/misc/test_celeba.png
--------------------------------------------------------------------------------
/misc/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import cv2
4 | import sys
5 | import glob
6 | import json
7 | import shutil
8 | import numpy as np
9 | import torch
10 |
11 | from PIL import Image
12 | from easydict import EasyDict
13 | from torchvision.transforms import transforms
14 |
15 |
16 | # Profile
17 |
18 | def load_profile(filepath):
19 |     """
20 |     Load experiment profile as EasyDict
21 |
22 |     :param filepath: path to profile
23 |     :type filepath: str
24 |     :return: hyper-parameters
25 |     :rtype: EasyDict
26 |     """
27 |     if os.path.exists(filepath):
28 |         with open(filepath) as f:
29 |             return EasyDict(json.load(f))
30 |
31 |
32 | # Device
33 |
34 | def get_devices(devices, verbose=True):
35 |     """
36 |     Get devices for running model
37 |
38 |     :param devices: list of devices from profile
39 |     :type devices: list
40 |     :param verbose: print log
41 |     :type verbose: bool
42 |     :return: list of usable devices according to desired and available hardware
43 |     :rtype: list[str]
44 |     """
45 |
46 |     def parse_cuda_device(device):
47 |         """
48 |         Parse device into device id
49 |
50 |         :param device: given device
51 |         :type device: str or int
52 |         :return: device id
53 |         :rtype: int
54 |         """
55 |         origin = str(device)
56 |         if isinstance(device, str) and re.search(r'cuda:([\d]+)', device):
57 |             device = int(re.findall(r'cuda:([\d]+)', device)[0])
58 |         if isinstance(device, int):
59 |             if 0 <= device <= torch.cuda.device_count() - 1:
60 |                 return device
61 |         _print('[Builder] Incorrect device "{}"'.format(origin),
verbose=verbose)
62 |         return
63 |
64 |     use_cpu = any(isinstance(d, str) and 'cpu' in d for d in devices)
65 |     use_cuda = any(isinstance(d, int) or (isinstance(d, str) and 'cuda' in d) for d in devices)
66 |     assert not (use_cpu and use_cuda), 'CPU and GPU cannot be mixed.'
67 |
68 |     if use_cuda:
69 |         devices = [parse_cuda_device(d) for d in devices]
70 |         devices = [d for d in devices if d is not None]
71 |         if len(devices) == 0:
72 |             _print('[Builder] No available GPU found, use CPU only', verbose=verbose)
73 |             devices = ['cpu']
74 |
75 |     return devices
76 |
77 |
78 | # Logger
79 |
80 | class OutputLogger(object):
81 |     """Output logger"""
82 |
83 |     def __init__(self):
84 |         self.file = None
85 |         self.buffer = ''
86 |
87 |     def set_log_file(self, filename, mode='wt'):
88 |         assert self.file is None
89 |         self.file = open(filename, mode)
90 |         if self.buffer is not None:
91 |             self.file.write(self.buffer)
92 |             self.buffer = None
93 |
94 |     def write(self, data):
95 |         if self.file is not None:
96 |             self.file.write(data)
97 |         if self.buffer is not None:
98 |             self.buffer += data
99 |
100 |     def flush(self):
101 |         if self.file is not None:
102 |             self.file.flush()
103 |
104 |
105 | class TeeOutputStream(object):
106 |     """Redirect output stream"""
107 |
108 |     def __init__(self, child_streams, autoflush=False):
109 |         self.child_streams = child_streams
110 |         self.autoflush = autoflush
111 |
112 |     def write(self, data):
113 |         if isinstance(data, bytes):
114 |             data = data.decode('utf-8')
115 |         for stream in self.child_streams:
116 |             stream.write(data)
117 |         if self.autoflush:
118 |             self.flush()
119 |
120 |     def flush(self):
121 |         for stream in self.child_streams:
122 |             stream.flush()
123 |
124 |
125 | output_logger = None
126 |
127 |
128 | def init_output_logging():
129 |     """
130 |     Initialize output logger
131 |     """
132 |     global output_logger
133 |     if output_logger is None:
134 |         output_logger = OutputLogger()
135 |         sys.stdout = TeeOutputStream([sys.stdout, output_logger], autoflush=True)
136 |         sys.stderr = TeeOutputStream([sys.stderr, output_logger], autoflush=True)
137 |
138 |
139 | def set_output_log_file(filename, mode='wt'):
140 |     """
141 |     Set file name of output log
142 |
143 |     :param filename: file name of log
144 |     :type filename: str
145 |     :param mode: the mode in which the file is opened
146 |     :type mode: str
147 |     """
148 |     if output_logger is not None:
149 |         output_logger.set_log_file(filename, mode)
150 |
151 |
152 | # Result directory
153 |
154 | def create_result_subdir(result_dir, desc, profile):
155 |     """
156 |     Create and initialize result sub-directory
157 |
158 |     :param result_dir: path to root of result directory
159 |     :type result_dir: str
160 |     :param desc: description of current experiment
161 |     :type desc: str
162 |     :param profile: profile
163 |     :type profile: dict
164 |     :return: path to result sub-directory
165 |     :rtype: str
166 |     """
167 |     # determine run id
168 |     run_id = 0
169 |     for fname in glob.glob(os.path.join(result_dir, '*')):
170 |         fbase = os.path.basename(fname)
171 |         finds = re.findall(r'^([\d]+)-', fbase)
172 |         if len(finds) != 0:
173 |             ford = int(finds[0])
174 |             run_id = max(run_id, ford + 1)
175 |
176 |     # create result sub-directory
177 |     result_subdir = os.path.join(result_dir, '{:03d}-{:s}'.format(run_id, desc))
178 |     if not os.path.exists(result_subdir):
179 |         os.makedirs(result_subdir)
180 |     set_output_log_file(os.path.join(result_subdir, 'log.txt'))
181 |     print("[Builder] Saving results to {}".format(result_subdir))
182 |
183 |     # export profile
184 |     with open(os.path.join(result_subdir,
'config.json'), 'w') as f:
185 |         json.dump(profile, f)
186 |
187 |     return result_subdir
188 |
189 |
190 | def locate_result_subdir(result_dir, run_id_or_result_subdir):
191 |     """
192 |     Locate result subdir by given run id or path
193 |
194 |     :param result_dir: path to root of result directory
195 |     :type result_dir: str
196 |     :param run_id_or_result_subdir: run id or subdir path
197 |     :type run_id_or_result_subdir: int or str
198 |     :return: located result subdir
199 |     :rtype: str
200 |     """
201 |     if isinstance(run_id_or_result_subdir, str) and os.path.isdir(run_id_or_result_subdir):
202 |         return run_id_or_result_subdir
203 |
204 |     searchdirs = ['', 'results', 'networks']
205 |
206 |     for searchdir in searchdirs:
207 |         d = result_dir if searchdir == '' else os.path.join(result_dir, searchdir)
208 |         # search directly by name
209 |         d = os.path.join(d, str(run_id_or_result_subdir))
210 |         if os.path.isdir(d):
211 |             return d
212 |         # search by prefix
213 |         if isinstance(run_id_or_result_subdir, int):
214 |             prefix = '{:03d}'.format(run_id_or_result_subdir)
215 |         else:
216 |             prefix = str(run_id_or_result_subdir)
217 |         dirs = sorted(glob.glob(os.path.join(result_dir, searchdir, prefix + '-*')))
218 |         dirs = [d for d in dirs if os.path.isdir(d)]
219 |         if len(dirs) == 1:
220 |             return dirs[0]
221 |     print('[Builder] Cannot locate result subdir for run: {}'.format(run_id_or_result_subdir))
222 |     return None
223 |
224 |
225 | def format_time(seconds):
226 |     """
227 |     Format seconds into desired format
228 |
229 |     :param seconds: number of seconds
230 |     :type seconds: float
231 |     :return: formatted time
232 |     :rtype: str
233 |     """
234 |     s = int(np.rint(seconds))
235 |     if s < 60:
236 |         return '{:d}s'.format(s)
237 |     elif s < 60 * 60:
238 |         return '{:d}m {:02d}s'.format(s // 60, s % 60)
239 |     elif s < 24 * 60 * 60:
240 |         return '{:d}h {:02d}m {:02d}s'.format(s // (60 * 60), (s // 60) % 60, s % 60)
241 |     else:
242 |         return '{:d}d {:02d}h {:02d}m'.format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
243 |
244 |
245 | # Model
246 |
247 | def get_model_name(step):
248 |     """
249 |     Return filename of model snapshot by step
250 |
251 |     :param step: global step of model
252 |     :type step: int
253 |     :return: model snapshot file name
254 |     :rtype: str
255 |     """
256 |     return 'network-snapshot-{:06d}.pth'.format(step)
257 |
258 |
259 | def get_best_model_name():
260 |     """
261 |     Return filename of the best model snapshot
262 |
263 |     :return: filename of best model snapshot
264 |     :rtype: str
265 |     """
266 |     return 'network-snapshot-best.pth'
267 |
268 |
269 | def get_last_model_name(result_subdir):
270 |     """
271 |     Return filename of the latest model snapshot in a result sub-directory
272 |
273 |     :param result_subdir: path to result sub-directory
274 |     :type result_subdir: str
275 |     :return: filename of last model snapshot
276 |     :rtype: str
277 |     """
278 |     latest = -1
279 |     for f in os.listdir(result_subdir):
280 |         if os.path.isfile(os.path.join(result_subdir, f)) and \
281 |                 re.search(r'network-snapshot-([\d]+)\.pth', f):
282 |             f_step = int(re.findall(r'network-snapshot-([\d]+)\.pth', f)[0])
283 |             if latest < f_step:
284 |                 latest = f_step
285 |
286 |     return get_model_name(latest)
287 |
288 |
289 | def save_model(result_subdir, step, graph, optimizer, seconds, is_best, criterion_dict=None):
290 |     """
291 |     Save model snapshot to result subdir
292 |
293 |     :param result_subdir: path to result sub-directory
294 |     :type result_subdir: str
295 |     :param step: global step of model
296 |     :type step: int
297 |     :param graph: model graph
298 |     :type graph: torch.nn.Module
299 |     :param optimizer: optimizer
300 |     :type optimizer: torch.optim.Optimizer
301 |     :param seconds: seconds of running time
302 |     :type seconds: float
303 |     :param is_best: whether this model is best
304 |     :type is_best: bool
305 |     :param criterion_dict: dict of criterion
306 |     :type criterion_dict: dict
307 |     """
308 |     # construct state
309 |     state = {
310 |         'step': step,
311 |         # DataParallel wraps model in `module` attribute.
312 |         'graph': graph.module.state_dict() if hasattr(graph, "module") else graph.state_dict(),
313 |         'optimizer': optimizer.state_dict(),
314 |         'criterion': {},
315 |         'seconds': seconds
316 |     }
317 |     if criterion_dict is not None:
318 |         state['criterion'] = {k: v.state_dict() for k, v in criterion_dict.items()}
319 |
320 |     # save current state
321 |     save_path = os.path.join(result_subdir, get_model_name(step))
322 |     torch.save(state, save_path)
323 |
324 |     # save best state
325 |     if is_best:
326 |         best_path = os.path.join(result_subdir, get_best_model_name())
327 |         shutil.copy(save_path, best_path)
328 |
329 |
330 | def load_model(result_subdir, step_or_model_path, graph, optimizer=None, criterion_dict=None, device=None):
331 |     """
332 |     Load model snapshot from result subdir
333 |
334 |     :param result_subdir: path to result sub-directory
335 |     :type result_subdir: str
336 |     :param step_or_model_path: step or model path
337 |     :type step_or_model_path: int or str
338 |     :param graph: model graph
339 |     :type graph: torch.nn.Module
340 |     :param optimizer: optimizer
341 |     :type optimizer: torch.optim.Optimizer
342 |     :param criterion_dict: dict of criterion
343 |     :type criterion_dict: dict
344 |     :param device: device to run model
345 |     :type device: str
346 |     :return: state
347 |     :rtype: dict
348 |     """
349 |     # check existence of model file
350 |     model_path = step_or_model_path
351 |     if isinstance(step_or_model_path, int):
352 |         model_path = get_model_name(step_or_model_path)
353 |     if step_or_model_path == 'best':
354 |         model_path = get_best_model_name()
355 |     if step_or_model_path == 'latest':
356 |         model_path = get_last_model_name(result_subdir)
357 |     if not os.path.exists(model_path):
358 |         model_path = os.path.join(result_subdir, model_path)
359 |         if not os.path.exists(model_path):
360 |             raise FileNotFoundError('Failed to find model snapshot with {}'.format(step_or_model_path))
361 |
362 |     # load model snapshot
363 |     if isinstance(device, int):
364 |         device = 'cuda:{}'.format(device)
365 |     state = torch.load(model_path, map_location=device)
366 |     step = state['step']
367 |     graph.load_state_dict(state['graph'])
368 |     graph.set_actnorm_inited()
369 |     if optimizer is not None:
370 |         optimizer.load_state_dict(state['optimizer'])
371 |     if criterion_dict is not None:
372 |         for k in criterion_dict.keys():
373 |             criterion_dict[k].load_state_dict(state['criterion'][k])
374 |     print('[Builder] Load model snapshot successfully from {}'.format(model_path))
375 |
376 |     return state
377 |
378 |
379 | # Dataset
380 |
381 | def is_image(filepath):
382 |     """
383 |     Determine whether file is an image or not
384 |
385 |     :param filepath: file path
386 |     :type filepath: str
387 |     :return: whether file is an image
388 |     :rtype: bool
389 |     """
390 |     image_extensions = ['.png', '.jpg', '.jpeg']
391 |     basename = os.path.basename(filepath)
392 |     _, extension = os.path.splitext(basename)
393 |     return extension.lower() in image_extensions
394 |
395 |
396 | def tensor_to_ndarray(tensor):
397 |     """
398 |     Convert float tensor into numpy image
399 |
400 |     :param tensor: input tensor
401 |     :type tensor: torch.Tensor
402 |     :return: numpy image
403 |     :rtype: np.ndarray
404 |     """
405 |     tensor_np = tensor.permute(1, 2, 0).cpu().numpy()
406 |     tensor_np = tensor_np.astype(np.float32)
407 |     tensor_np = (tensor_np * 255).astype(np.uint8)
408 |     return tensor_np
409 |
410 |
411 | def tensor_to_pil(tensor):
412 |     """
413 |     Convert float tensor into PIL image
414 |
415 |     :param tensor: input tensor
416 |     :type tensor: torch.Tensor
417 |     :return: PIL image
418 |     :rtype: Image.Image
419 |     """
420 |     transform = transforms.ToPILImage()
421 |     tensor = tensor.cpu()
422 |     return transform(tensor)
423 |
424 |
425 | def ndarray_to_tensor(img, shape=(64, 64, 3), bgr2rgb=True):
426 |     """
427 |     Convert numpy image to float tensor
428 |
429 |     :param img: numpy image
430 |     :type img: np.ndarray
431 |     :param shape: image shape in (H, W, C)
432 |     :type shape: tuple or list
433 |     :param bgr2rgb: convert color space from BGR to RGB
434 |     :type bgr2rgb: bool
435 |     :return: tensor
436 |     :rtype: torch.Tensor
437 |     """
438 |     if bgr2rgb:
439 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
440 |     img = cv2.resize(img, (shape[1], shape[0]))  # cv2.resize expects (width, height)
441 |     img = (img / 255.0).astype(np.float32)
442 |     img = torch.Tensor(img).permute(2, 0, 1)
443 |     return img
444 |
445 |
446 | def pil_to_tensor(img, shape=(64, 64, 3), transform=None):
447 |     """
448 |     Convert PIL image to float tensor
449 |
450 |     :param img: PIL image
451 |     :type img: Image.Image
452 |     :param shape: image shape in (H, W, C)
453 |     :type shape: tuple or list
454 |     :param transform: image transform
455 |     :return: tensor
456 |     :rtype: torch.Tensor
457 |     """
458 |     if transform is None:
459 |         transform = transforms.Compose((
460 |             transforms.Resize(shape[0]),
461 |             transforms.ToTensor()
462 |         ))
463 |     return transform(img)
464 |
465 |
466 | def image_to_tensor(img, shape=(64, 64, 3), bgr2rgb=True):
467 |     """
468 |     Convert image to torch tensor
469 |
470 |     :param img: image
471 |     :type img: Image.Image or np.ndarray
472 |     :param shape: image shape in (H, W, C)
473 |     :type shape: tuple or list
474 |     :param bgr2rgb: convert color space from BGR to RGB
475 |     :type bgr2rgb: bool
476 |     :return: image tensor
477 |     :rtype: torch.Tensor
478 |     """
479 |     if isinstance(img, Image.Image):
480 |         return pil_to_tensor(img, shape)
481 |     elif isinstance(img, np.ndarray):
482 |         return ndarray_to_tensor(img, shape, bgr2rgb)
483 |     else:
484 |         raise NotImplementedError('Unsupported image type: {}'.format(type(img)))
485 |
486 |
487 | def save_deltaz(deltaz, save_dir):
488 |     """
489 |     Save deltaz as numpy
490 |
491 |     :param deltaz: delta vector of attributes in latent space
492 |     :type deltaz: np.ndarray
493 |     :param save_dir: directory to save
494 |     :type save_dir: str
495 |     """
496 |     check_path(save_dir)
497 |     np.save(os.path.join(save_dir, 'deltaz.npy'), deltaz)
498 |
499 |
500 | def load_deltaz(path):
501 |     """
502 |     Load deltaz as numpy
503 |
504 |     :param path: path to numpy file
505 |     :type path: str
506 |     :return: delta vector of attributes in latent space
507 |     :rtype: np.ndarray
508 |     """
509 |     if os.path.exists(path):
510 |         return np.load(path)
511 |
512 |
513 | # Misc
514 |
515 | def manual_seed(seed):
516 |     """
517 |     Set manual random seed
518 |
519 |     :param seed: random seed
520 |     :type seed: int
521 |     """
522 |     np.random.seed(seed)
523 |     torch.manual_seed(seed)
524 |     # torch.cuda.manual_seed_all(seed)
525 |
526 |
527 | def _print(*args, verbose=True, **kwargs):
528 |     """
529 |     Print with condition
530 |
531 |     :param verbose: whether to verbose or not
532 |     :type verbose: bool
533 |     """
534 |     if verbose:
535 |         print(*args, **kwargs)
536 |
537 |
538 | def check_path(path):
539 |     """
540 |     Check existence of directory path. If not, then create it.
541 |
542 |     :param path: path to directory
543 |     :type path: str
544 |     """
545 |     if not os.path.exists(path):
546 |         os.makedirs(path)
547 |
548 |
549 | def make_batch(tensor, batch_size):
550 |     """
551 |     Generate fake batch
552 |
553 |     :param tensor: input tensor
554 |     :type tensor: torch.Tensor
555 |     :param batch_size: batch size
556 |     :type batch_size: int
557 |     :return: fake batch
558 |     :rtype: torch.Tensor
559 |     """
560 |     assert len(tensor.shape) == 3, 'Assume 3D input tensor'
561 |     return tensor.unsqueeze(0).repeat(batch_size, 1, 1, 1)
562 |
563 |
564 | def make_interpolation_vector(num_classes, step=0.25,
565 |                               minimum=-1., maximum=1.):
566 |     """
567 |     Generate interpolation vector
568 |
569 |     :param num_classes: number of classes
570 |     :type num_classes: int
571 |     :param step: increasing step
572 |     :type step: float
573 |     :param minimum: minimum value
574 |     :type minimum: float
575 |     :param maximum: maximum value
576 |     :type maximum: float
577 |     :return: interpolation vector
578 |     :rtype: np.ndarray
579 |     """
580 |     num_levels = int((maximum - minimum) / step) + 1
581 |     levels = [minimum + step * i for i in range(num_levels)]
582 |
583 |     interpolation_vector = np.zeros([num_classes, num_levels, num_classes])
584 |     for cls in range(num_classes):
585 |         for lv in range(num_levels):
586 |             interpolation_vector[cls, lv, cls] = levels[lv]
587 |
588 |     return interpolation_vector
589 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import FlowStep, FlowModel, Glow
2 | from .module import (ActNorm, LinearZeros, Conv2d, Conv2dZeros,
3 |                      f, Invertible1x1Conv, Permutation2d, GaussianDiag,
4 |                      Split2d, Squeeze2d)
5 |
6 | from .builder import Builder
7 | from .trainer import Trainer
8 | from .inferer import Inferer
9 |
10 | __all__ = (
11 |     'FlowStep', 'FlowModel', 'Glow',
12 |     'ActNorm', 'LinearZeros', 'Conv2d', 'Conv2dZeros',
13 |     'f', 'Invertible1x1Conv', 'Permutation2d', 'GaussianDiag',
14 |     'Split2d', 'Squeeze2d',
15 |     'Builder', 'Trainer', 'Inferer'
16 | )
17 |
--------------------------------------------------------------------------------
/network/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from functools import partial
5 | from network import model
6 | from misc import util, lr_scheduler
7 |
8 |
9 | class Builder:
10 |     optimizer_dict = {
11 |         'adam': lambda params, **kwargs: torch.optim.Adam(params, **kwargs),
12 |         'adamax': lambda params, **kwargs: torch.optim.Adamax(params, **kwargs)
13 |     }
14 |     lr_scheduler_dict = {
15 |         'constant': lambda **kwargs: lr_scheduler.constant(**kwargs),
16 |         'noam': lambda **kwargs: lr_scheduler.noam_decay(**kwargs),
17 |         'linear': lambda **kwargs: lr_scheduler.linear_anneal(**kwargs),
18 |         'step': lambda **kwargs: lr_scheduler.step_anneal(**kwargs),
19 |         'cyclic_cosine': lambda **kwargs: lr_scheduler.cyclic_cosine_anneal(**kwargs),
20 |     }
21 |
22 |     def __init__(self, hps):
23 |         """
24 |         Network builder
25 |
26 |         :param hps: hyper-parameters for this network
27 |         :type hps: dict
28 |         """
29 |         super().__init__()
30 |         self.hps = hps
31 |
32 |     def build(self, training=True):
33 |         """
34 |         Build network
35 |
36 |         :param training: whether to build the graph for training
37 |         :type training: bool
38 |         :return: built state, including step, graph, optimizer, scheduler and devices
39 |         :rtype: dict
40 |         """
41 |         # initialize all
variables 42 | step = 0 43 | state = None 44 | result_subdir = None 45 | graph, optimizer, scheduler, criterion_dict = None, None, None, None 46 | devices = util.get_devices(self.hps.device.graph) 47 | data_device = util.get_devices(self.hps.device.data)[0] 48 | 49 | # build graph 50 | graph = model.Glow(self.hps) 51 | graph.to('cpu') 52 | 53 | # load model 54 | if graph is not None: 55 | # locate or create result subdir 56 | if self.hps.general.warm_start and self.hps.general.resume_run_id != "": 57 | result_subdir = util.locate_result_subdir(self.hps.general.result_dir, 58 | self.hps.general.resume_run_id) 59 | 60 | if training and result_subdir is None: 61 | result_subdir = util.create_result_subdir(self.hps.general.result_dir, 62 | desc=self.hps.profile, 63 | profile=self.hps) 64 | # load pre-trained model on first device 65 | if self.hps.general.warm_start: 66 | step_or_model_path = None 67 | if os.path.exists(self.hps.general.pre_trained): 68 | step_or_model_path = self.hps.general.pre_trained 69 | elif self.hps.general.resume_step not in ['', 'best', 'latest']: 70 | step_or_model_path = int(self.hps.general.resume_step) 71 | if step_or_model_path is not None: 72 | state = util.load_model(result_subdir, step_or_model_path, graph, 73 | device=devices[0]) 74 | if not training and state is None: 75 | raise RuntimeError('No pre-trained model for inference') 76 | # move graph to devices 77 | if 'cpu' in devices: 78 | graph = graph.cpu() 79 | data_device = 'cpu' 80 | else: 81 | graph = graph.to(devices[0]) 82 | print('[Builder] Use {} for model running and {} for data loading'.format(devices[0], data_device)) 83 | 84 | # setup optimizer and lr scheduler 85 | if training and graph is not None: 86 | # get optimizer 87 | optimizer_name = self.hps.optim.optimizer.lower() 88 | assert optimizer_name in self.optimizer_dict.keys(), \ 89 | "Unsupported optimizer: {}".format(optimizer_name) 90 | # If you need to move a model to GPU via .cuda(), please do so before constructing optimizers for it. 
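            # (Editor's note: the graph was already moved to devices[0] in the
            # block above, so constructing the optimizer from graph.parameters()
            # below satisfies that requirement.)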
91 |             optimizer = self.optimizer_dict[optimizer_name](
92 |                 graph.parameters(),
93 |                 **self.hps.optim.optimizer_args)
94 |             if state is not None:
95 |                 optimizer.load_state_dict(state['optimizer'])
96 |             # get lr scheduler
97 |             scheduler_name = self.hps.optim.lr_scheduler.lower()
98 |             scheduler_args = self.hps.optim.lr_scheduler_args
99 |             assert scheduler_name in self.lr_scheduler_dict.keys(), \
100 |                 "Unsupported lr scheduler: {}".format(scheduler_name)
101 |             if 'base_lr' not in scheduler_args:
102 |                 scheduler_args['base_lr'] = self.hps.optim.optimizer_args['lr']
103 |             scheduler = partial(self.lr_scheduler_dict[scheduler_name], **scheduler_args)
104 |
105 |         return {
106 |             'step': state['step'] if state is not None else step,  # resume global step on warm start
107 |             'graph': graph,
108 |             'optimizer': optimizer,
109 |             'scheduler': scheduler,
110 |             'devices': devices,
111 |             'data_device': data_device,
112 |             'result_subdir': result_subdir
113 |         }
114 |
--------------------------------------------------------------------------------
/network/inferer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from tqdm import tqdm
5 | from torchvision.utils import make_grid
6 | from torch.utils.data import DataLoader
7 |
8 | from misc import util
9 |
10 |
11 | class Inferer:
12 |
13 |     def __init__(self, hps, graph, devices, data_device):
14 |         """
15 |         Network inferer
16 |
17 |         :param hps: hyper-parameters for this network
18 |         :type hps: dict
19 |         :param graph: model graph
20 |         :type graph: torch.nn.Module
21 |         :param devices: list of usable devices for model running
22 |         :type devices: list
23 |         :param data_device: device for data loading
24 |         :type data_device: str or int
25 |         """
26 |         # general
27 |         self.hps = hps
28 |         # state
29 |         self.graph = graph
30 |         self.graph.eval()
31 |         self.devices = devices
32 |         self.use_cuda = 'cpu' not in self.devices
33 |         # data
34 |         self.data_device = data_device
35 |         self.batch_size = self.graph.h_top.shape[0]
36 |         self.num_classes = self.hps.dataset.num_classes
37 |         # ablation
38 |         self.y_condition = self.hps.ablation.y_condition
39 |
40 |     def sample(self, z, y_onehot, eps_std=0.5):
41 |         """
42 |         Sample image
43 |
44 |         :param z: latent feature vector
45 |         :type z: torch.Tensor or None
46 |         :param y_onehot: one-hot vector of label
47 |         :type y_onehot: torch.Tensor or None
48 |         :param eps_std: standard deviation of eps
49 |         :type eps_std: float
50 |         :return: generated image
51 |         :rtype: torch.Tensor
52 |         """
53 |         with torch.no_grad():
54 |             # generate sample from model
55 |             img = self.graph(z=z, y_onehot=y_onehot, eps_std=eps_std, reverse=True)
56 |
57 |             # create image grid
58 |             grid = make_grid(img)
59 |
60 |             return grid
61 |
62 |     def encode(self, img):
63 |         """
64 |         Encode input image to latent features
65 |
66 |         :param img: input image
67 |         :type img: torch.Tensor or np.ndarray or Image.Image
68 |         :return: latent features
69 |         :rtype: torch.Tensor
70 |         """
71 |         with torch.no_grad():
72 |             if not torch.is_tensor(img):
73 |                 img = util.image_to_tensor(
74 |                     img,
75 |                     shape=self.hps.model.image_shape)
76 |                 img = util.make_batch(img, self.batch_size)
77 |             elif len(img.shape) == 3:
78 |                 img = util.make_batch(img, self.batch_size)
79 |             if self.use_cuda:
80 |                 img = img.cuda()
81 |             z, _, _ = self.graph(img)
82 |             return z[0, :, :, :]
83 |
84 |     def decode(self, z):
85 |         """
86 |         Decode latent feature vector into image
87 |
88 |         :param z: input latent feature vector
89 |         :type z: torch.Tensor
90 |         :return: decoded image
91 |         :rtype: torch.Tensor
92 |         """
93 |         with torch.no_grad():
94 |             if len(z.shape) == 3:
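                # (Editor's note: the graph is built for a fixed batch size,
                # taken from the shape of its `h_top` parameter, so a single
                # latent is tiled into a fake batch here; only the first item
                # of the decoded result is returned below.)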
95 |                 z = util.make_batch(z, self.batch_size)
96 |             if self.use_cuda:
97 |                 z = z.cuda()
98 |
99 |             img = self.graph(z=z, y_onehot=None, reverse=True)[0, :, :, :]
100 |             return img
101 |
102 |     def compute_attribute_delta(self, dataset):
103 |         """
104 |         Compute feature vector deltaz of different attributes
105 |
106 |         :param dataset: dataset for training model
107 |         :type dataset: torch.utils.data.Dataset
108 |         :return: delta vector of attributes in latent space
109 |         :rtype: np.ndarray
110 |         """
111 |         with torch.no_grad():
112 |             # initialize variables
113 |             attrs_z_pos = np.zeros([self.num_classes, *self.graph.flow.output_shapes[-1][1:]])
114 |             attrs_z_neg = np.zeros([self.num_classes, *self.graph.flow.output_shapes[-1][1:]])
115 |             num_z_pos = np.zeros(self.num_classes)
116 |             num_z_neg = np.zeros(self.num_classes)
117 |             deltaz = np.zeros([self.num_classes, *self.graph.flow.output_shapes[-1][1:]])
118 |
119 |             data_loader = DataLoader(dataset, batch_size=self.batch_size,
120 |                                      num_workers=self.hps.dataset.num_workers,
121 |                                      shuffle=True,
122 |                                      drop_last=True)
123 |
124 |             progress = tqdm(data_loader)
125 |             for idx, batch in enumerate(progress):
126 |                 # extract batch data
127 |                 assert 'y_onehot' in batch.keys(), 'Compute attribute deltaz needs "y_onehot" in batch data'
128 |                 for i in batch:
129 |                     batch[i] = batch[i].to(self.data_device)
130 |                 x = batch['x']
131 |                 y_onehot = batch['y_onehot']
132 |
133 |                 # encode images into latent features
134 |                 z, _, _ = self.graph(x)
135 |
136 |                 # accumulate latent features by attribute
137 |                 for i in range(x.shape[0]):  # batch is a dict, so iterate over the batch dimension of x
138 |                     for cls in range(self.num_classes):
139 |                         if y_onehot[i, cls] > 0:
140 |                             attrs_z_pos[cls] += z[i].cpu().numpy()
141 |                             num_z_pos[cls] += 1
142 |                         else:
143 |                             attrs_z_neg[cls] += z[i].cpu().numpy()
144 |                             num_z_neg[cls] += 1
145 |
146 |             # compute deltaz
147 |             num_z_pos = [max(1., float(num)) for num in num_z_pos]
148 |             num_z_neg = [max(1., float(num)) for num in num_z_neg]
149 |             for cls in range(self.num_classes):
150 |                 mean_z_pos = attrs_z_pos[cls] / num_z_pos[cls]
151 |                 mean_z_neg = attrs_z_neg[cls] / num_z_neg[cls]
152 |                 deltaz[cls] = mean_z_pos - mean_z_neg
153 |
154 |             return deltaz
155 |
156 |     def apply_attribute_delta(self, img, deltaz, interpolation):
157 |         """
158 |         Apply attribute delta to image by given interpolation vector
159 |
160 |         :param img: given image
161 |         :type img: torch.Tensor or np.ndarray or Image.Image
162 |         :param deltaz: delta vector of attributes in latent space
163 |         :type deltaz: np.ndarray
164 |         :param interpolation: interpolation vector
165 |         :type interpolation: torch.Tensor or np.ndarray or list[float]
166 |         :return: processed image
167 |         :rtype: torch.Tensor
168 |         """
169 |         if isinstance(deltaz, np.ndarray):
170 |             deltaz = torch.Tensor(deltaz)
171 |         assert len(interpolation) == self.num_classes
172 |         assert deltaz.shape == torch.Size([self.num_classes,
173 |                                            *self.graph.flow.output_shapes[-1][1:]])
174 |
175 |         # encode
176 |         z = self.encode(img)
177 |
178 |         # perform interpolation
179 |         z_interpolated = z.clone()
180 |         for i in range(len(interpolation)):
181 |             z_delta = deltaz[i].mul(interpolation[i])
182 |             if self.use_cuda:
183 |                 z_delta = z_delta.cuda()
184 |             z_interpolated += z_delta
185 |
186 |         # decode
187 |         img = self.decode(z_interpolated)
188 |
189 |         return img
190 |
--------------------------------------------------------------------------------
/network/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from network import module
7 | from misc
import ops, util 8 | 9 | 10 | class FlowStep(nn.Module): 11 | flow_permutation_list = ['invconv', 'reverse', 'shuffle'] 12 | flow_coupling_list = ['additive', 'affine'] 13 | 14 | def __init__(self, 15 | in_channels, 16 | hidden_channels, 17 | permutation='invconv', 18 | coupling='additive', 19 | actnorm_scale=1., 20 | lu_decomposition=False): 21 | """ 22 | One step of flow described in paper 23 | 24 | ▲ 25 | │ 26 | ┌─────────────┼─────────────┐ 27 | │ ┌──────────┴──────────┐ │ 28 | │ │ flow coupling layer │ │ 29 | │ └──────────▲──────────┘ │ 30 | │ │ │ 31 | │ ┌──────────┴──────────┐ │ 32 | │ │ flow permutation │ │ 33 | │ │ layer │ │ 34 | │ └──────────▲──────────┘ │ 35 | │ │ │ 36 | │ ┌──────────┴──────────┐ │ 37 | │ │ activation │ │ 38 | │ │ normalization layer │ │ 39 | │ └──────────▲──────────┘ │ 40 | └─────────────┼─────────────┘ 41 | │ 42 | │ 43 | 44 | :param in_channels: number of input channels 45 | :type in_channels: int 46 | :param hidden_channels: number of hidden channels 47 | :type hidden_channels: int 48 | :param permutation: type of flow permutation 49 | :type permutation: str 50 | :param coupling: type of flow coupling 51 | :type coupling: str 52 | :param actnorm_scale: scale factor of actnorm layer 53 | :type actnorm_scale: float 54 | :param lu_decomposition: whether to use LU decomposition or not 55 | :type lu_decomposition: bool 56 | """ 57 | super().__init__() 58 | # permutation and coupling 59 | assert permutation in self.flow_permutation_list, 'Unsupported flow permutation: {}'.format(permutation) 60 | assert coupling in self.flow_coupling_list, 'Unsupported flow coupling: {}'.format(coupling) 61 | self.permutation = permutation 62 | self.coupling = coupling 63 | 64 | # activation normalization layer 65 | self.actnorm = module.ActNorm(num_channels=in_channels, scale=actnorm_scale) 66 | 67 | # flow permutation layer 68 | if permutation == 'invconv': 69 | self.invconv = module.Invertible1x1Conv(num_channels=in_channels, 70 | lu_decomposition=lu_decomposition) 71 | elif permutation == 'reverse': 72 | self.reverse = module.Permutation2d(num_channels=in_channels, shuffle=False) 73 | else: 74 | self.shuffle = module.Permutation2d(num_channels=in_channels, shuffle=True) 75 | 76 | # flow coupling layer 77 | if coupling == 'additive': 78 | self.f = module.f(in_channels // 2, hidden_channels, in_channels // 2) 79 | else: 80 | self.f = module.f(in_channels // 2, hidden_channels, in_channels) 81 | 82 | def normal_flow(self, x, logdet=None): 83 | """ 84 | Normal flow 85 | 86 | :param x: input tensor 87 | :type x: torch.Tensor 88 | :param logdet: log determinant 89 | :type logdet: torch.Tensor 90 | :return: output and logdet 91 | :rtype: tuple(torch.Tensor, torch.Tensor) 92 | """ 93 | # activation normalization layer 94 | z, logdet = self.actnorm(x, logdet=logdet, reverse=False) 95 | 96 | # flow permutation layer 97 | if self.permutation == 'invconv': 98 | z, logdet = self.invconv(z, logdet, reverse=False) 99 | elif self.permutation == 'reverse': 100 | z = self.reverse(z, reverse=False) 101 | else: 102 | z = self.shuffle(z, reverse=False) 103 | 104 | # flow coupling layer 105 | z1, z2 = ops.split_channel(z, 'simple') 106 | if self.coupling == 'additive': 107 | z2 += self.f(z1) 108 | else: 109 | h = self.f(z1) 110 | shift, scale = ops.split_channel(h, 'cross') 111 | scale = F.sigmoid(scale + 2.) 
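            # (Editor's note: sigmoid(scale + 2.) keeps the multiplier close
            # to 1 at initialization, since sigmoid(2) ≈ 0.88, which stabilizes
            # early training; the log-determinant contribution of this affine
            # coupling, sum(log(scale)) over (C, H, W), is added just below.)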
112 |             z2 += shift
113 |             z2 *= scale
114 |             logdet = ops.reduce_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
115 |         z = ops.cat_channel(z1, z2)
116 |
117 |         return z, logdet
118 |
119 |     def reverse_flow(self, x, logdet=None):
120 |         """
121 |         Reverse flow
122 |
123 |         :param x: input tensor
124 |         :type x: torch.Tensor
125 |         :param logdet: log determinant
126 |         :type logdet: torch.Tensor
127 |         :return: output and logdet
128 |         :rtype: tuple(torch.Tensor, torch.Tensor)
129 |         """
130 |         # flow coupling layer
131 |         z1, z2 = ops.split_channel(x, 'simple')
132 |         if self.coupling == 'additive':
133 |             z2 -= self.f(z1)
134 |         else:
135 |             h = self.f(z1)
136 |             shift, scale = ops.split_channel(h, 'cross')
137 |             scale = F.sigmoid(scale + 2.)
138 |             z2 /= scale
139 |             z2 -= shift
140 |             logdet = -ops.reduce_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
141 |         z = ops.cat_channel(z1, z2)
142 |
143 |         # flow permutation layer
144 |         if self.permutation == 'invconv':
145 |             z, logdet = self.invconv(z, logdet, reverse=True)
146 |         elif self.permutation == 'reverse':
147 |             z = self.reverse(z, reverse=True)
148 |         else:
149 |             z = self.shuffle(z, reverse=True)
150 |
151 |         # activation normalization layer
152 |         z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
153 |
154 |         return z, logdet
155 |
156 |     def forward(self, x, logdet=None, reverse=False):
157 |         """
158 |         Forward one step of flow
159 |
160 |         :param x: input tensor
161 |         :type x: torch.Tensor
162 |         :param logdet: log determinant
163 |         :type logdet: torch.Tensor
164 |         :param reverse: whether to reverse flow
165 |         :type reverse: bool
166 |         :return: output and logdet
167 |         :rtype: tuple(torch.Tensor, torch.Tensor)
168 |         """
169 |         assert x.shape[1] % 2 == 0
170 |         if not reverse:
171 |             return self.normal_flow(x, logdet)
172 |         else:
173 |             return self.reverse_flow(x, logdet)
174 |
175 |
176 | class FlowModel(nn.Module):
177 |     def __init__(self,
178 |                  in_shape,
179 |                  hidden_channels,
180 |                  K, L,
181 |                  permutation='invconv',
182 |                  coupling='additive',
183 |                  actnorm_scale=1.,
184 |                  lu_decomposition=False):
185 |         """
186 |         Flow model with multi-scale architecture
187 |
188 |                          ┏━━━┓
189 |                          ┃z_L┃
190 |                          ┗━▲━┛
191 |                            │
192 |                 ┌──────────┴──────────┐
193 |                 │    step of flow     │* K
194 |                 └──────────▲──────────┘
195 |                 ┌──────────┴──────────┐
196 |                 │       squeeze       │
197 |                 └──────────▲──────────┘
198 |                            ├──────────────┐
199 |         ┏━━━┓   ┌──────────┴──────────┐   │
200 |         ┃z_i┃◀──┤        split        │   │
201 |         ┗━━━┛   └──────────▲──────────┘   │
202 |                 ┌──────────┴──────────┐   │
203 |                 │    step of flow     │* K│ * (L-1)
204 |                 └──────────▲──────────┘   │
205 |                 ┌──────────┴──────────┐   │
206 |                 │       squeeze       │   │
207 |                 └──────────▲──────────┘   │
208 |                            │◀─────────────┘
209 |                          ┏━┻━┓
210 |                          ┃ x ┃
211 |                          ┗━━━┛
212 |
213 |         :param in_shape: shape of input image in (H, W, C)
214 |         :type in_shape: torch.Size or tuple(int) or list(int)
215 |         :param hidden_channels: number of hidden channels
216 |         :type hidden_channels: int
217 |         :param K: depth of flow
218 |         :type K: int
219 |         :param L: number of levels
220 |         :type L: int
221 |         :param permutation: type of flow permutation
222 |         :type permutation: str
223 |         :param coupling: type of flow coupling
224 |         :type coupling: str
225 |         :param actnorm_scale: scale factor of actnorm layer
226 |         :type actnorm_scale: float
227 |         :param lu_decomposition: whether to use LU decomposition or not
228 |         :type lu_decomposition: bool
229 |         """
230 |         super().__init__()
231 |         self.K = K
232 |         self.L = L
233 |
234 |         # image shape
235 |         assert len(in_shape) == 3
236 |         assert in_shape[2] == 1 or in_shape[2] == 3
237 |         nh, nw, nc =
in_shape 238 | 239 | # initialize layers 240 | self.layers = nn.ModuleList() 241 | self.output_shapes = [] 242 | for i in range(L): 243 | # squeeze 244 | self.layers.append(module.Squeeze2d(factor=2)) 245 | nc, nh, nw = nc * 4, nh // 2, nw // 2 246 | self.output_shapes.append([-1, nc, nh, nw]) 247 | # flow step * K 248 | for _ in range(K): 249 | self.layers.append(FlowStep( 250 | in_channels=nc, 251 | hidden_channels=hidden_channels, 252 | permutation=permutation, 253 | coupling=coupling, 254 | actnorm_scale=actnorm_scale, 255 | lu_decomposition=lu_decomposition)) 256 | self.output_shapes.append([-1, nc, nh, nw]) 257 | # split 258 | if i < L - 1: 259 | self.layers.append(module.Split2d(num_channels=nc)) 260 | nc = nc // 2 261 | self.output_shapes.append([-1, nc, nh, nw]) 262 | 263 | def encode(self, z, logdet=0.): 264 | """ 265 | Encode input 266 | 267 | :param z: input tensor 268 | :type z: torch.Tensor 269 | :param logdet: log determinant 270 | :type logdet: torch.Tensor 271 | :return: encoded tensor 272 | :rtype: torch.Tensor 273 | """ 274 | for layer in self.layers: 275 | z, logdet = layer(z, logdet, reverse=False) 276 | return z, logdet 277 | 278 | def decode(self, z, eps_std=None): 279 | """ 280 | Decode input 281 | 282 | :param z: input tensor 283 | :type z: torch.Tensor 284 | :param eps_std: standard deviation of eps 285 | :type eps_std: float 286 | :return: decoded tensor 287 | :rtype: torch.Tensor 288 | """ 289 | for layer in reversed(self.layers): 290 | if isinstance(layer, module.Split2d): 291 | z, logdet = layer(z, logdet=0., reverse=True, eps_std=eps_std) 292 | else: 293 | z, logdet = layer(z, logdet=0., reverse=True) 294 | return z 295 | 296 | def forward(self, z, logdet=0., eps_std=None, reverse=False): 297 | """ 298 | Forward flow model 299 | 300 | :param z: input tensor 301 | :type z: torch.Tensor 302 | :param logdet: log determinant 303 | :type logdet: torch.Tensor 304 | :param eps_std: standard deviation of eps 305 | :type eps_std: float 306 | :param reverse: whether to reverse flow 307 | :type reverse: bool 308 | :return: output tensor 309 | :rtype: torch.Tensor 310 | """ 311 | if not reverse: 312 | return self.encode(z, logdet) 313 | else: 314 | return self.decode(z, eps_std) 315 | 316 | 317 | class Glow(nn.Module): 318 | bce_criterion = nn.BCEWithLogitsLoss() 319 | ce_criterion = nn.CrossEntropyLoss() 320 | 321 | def __init__(self, hps): 322 | """ 323 | Glow network 324 | 325 | :param hps: hyper-parameters for this network 326 | :type hps: dict 327 | """ 328 | super().__init__() 329 | 330 | self.hps = hps 331 | self.flow = FlowModel( 332 | in_shape=hps.model.image_shape, 333 | hidden_channels=hps.model.hidden_channels, 334 | K=hps.model.K, 335 | L=hps.model.L, 336 | permutation=hps.ablation.flow_permutation, 337 | coupling=hps.ablation.flow_coupling, 338 | actnorm_scale=hps.model.actnorm_scale, 339 | lu_decomposition=hps.ablation.lu_decomposition) 340 | 341 | if hps.ablation.learn_top: 342 | nc = self.flow.output_shapes[-1][1] 343 | self.learn_top = module.Conv2dZeros(in_channels=2 * nc, 344 | out_channels=2 * nc) 345 | if hps.ablation.y_condition: 346 | nc = self.flow.output_shapes[-1][1] 347 | self.y_emb = module.LinearZeros(hps.dataset.num_classes, nc * 2) 348 | self.classifier = module.LinearZeros(nc, hps.dataset.num_classes) 349 | 350 | num_device = len(util.get_devices(self.hps.device.graph, verbose=False)) 351 | assert hps.optim.num_batch_train % num_device == 0 352 | self.register_parameter('h_top', 353 | nn.Parameter(torch.zeros([hps.optim.num_batch_train // 
num_device,
354 |                                                   self.flow.output_shapes[-1][1] * 2,
355 |                                                   self.flow.output_shapes[-1][2],
356 |                                                   self.flow.output_shapes[-1][3]])))
357 | 
358 |     @property
359 |     def batch_h_top(self):
360 |         return self.h_top.shape[0]
361 | 
362 |     def prior(self, y_onehot=None):
363 |         """
364 |         Prior
365 | 
366 |         :param y_onehot: one-hot vector of label
367 |         :type y_onehot: torch.Tensor
368 |         :return: hidden output
369 |         :rtype: torch.Tensor
370 |         """
371 |         nc = self.h_top.shape[1]
372 |         h = self.h_top.detach().clone()
373 |         assert torch.sum(h) == 0.
374 |         if self.hps.ablation.learn_top:
375 |             h = self.learn_top(h)
376 |         if self.hps.ablation.y_condition:
377 |             assert y_onehot is not None
378 |             h += self.y_emb(y_onehot).view(-1, nc, 1, 1)
379 |         return ops.split_channel(h, 'simple')
380 | 
381 |     # def preprocess(self, x):
382 |     #     """
383 |     #     Pre-process for input
384 |     #
385 |     #     :param x: input
386 |     #     :type x: torch.Tensor
387 |     #     :return: processed input
388 |     #     :rtype: torch.Tensor
389 |     #     """
390 |     #     n_bins = 2 ** self.hps.model.n_bits_x
391 |     #     if self.hps.model.n_bits_x < 8:
392 |     #         x = torch.floor(x / 2 ** (8 - self.hps.model.n_bits_x))
393 |     #     x = x / n_bins - .5
394 |     #     return x
395 |     #
396 |     # def postprocess(self, x):
397 |     #     """
398 |     #     Post-process for output
399 |     #
400 |     #     :param x: input
401 |     #     :type x: torch.Tensor
402 |     #     :return: processed output
403 |     #     :rtype: torch.Tensor
404 |     #     """
405 |     #     n_bins = 2 ** self.hps.model.n_bits_x
406 |     #     x = torch.clamp(torch.floor((x + .5) * n_bins) * (256. / n_bins), min=0, max=255)
407 |     #     return x
408 | 
409 |     def normal_flow(self, x, y_onehot):
410 |         """
411 |         Normal flow
412 | 
413 |         :param x: input tensor
414 |         :type x: torch.Tensor
415 |         :param y_onehot: one-hot vector of label
416 |         :type y_onehot: torch.Tensor
417 |         """
418 |         # Pre-process for z: dequantize by adding uniform noise u ~ U(0, 1/n_bins)
419 |         n_bins = 2 ** self.hps.model.n_bits_x
420 |         # z = self.preprocess(x)
421 |         z = x + torch.nn.init.uniform_(torch.empty(*x.shape, device=x.device), 0, 1. / n_bins)
422 |         # z = x + module.GaussianDiag.eps(x, eps_std=1. / n_bins)
423 | 
424 |         # Initialize logdet
425 |         logdet_factor = x.shape[1] * ops.count_pixels(x)  # M = C * H * W
426 |         objective = torch.zeros_like(x[:, 0, 0, 0])
427 |         # c = M * log(a), where a is determined by the discretization level
428 |         # of the data and M is the dimensionality of x
429 |         objective += float(-np.log(n_bins)) * logdet_factor
430 | 
431 |         # Encode
432 |         z, objective = self.flow(z, logdet=objective, reverse=False)
433 | 
434 |         # Prior
435 |         mean, logs = self.prior(y_onehot)
436 |         # x_tilde(i) = x(i) + u
437 |         # u ~ U(0,a), where a is determined by the discretization level of the data
438 |         objective += module.GaussianDiag.logp(mean, logs, z)
439 | 
440 |         # Prediction loss
441 |         if self.hps.ablation.y_condition and self.hps.model.weight_y > 0:
442 |             h_y = ops.reduce_mean(z, dim=[2, 3])
443 |             y_logits = self.classifier(h_y)
444 |         else:
445 |             y_logits = None
446 | 
447 |         # Generative loss
448 |         nobj = -objective
449 |         # negative log-likelihood
450 |         nll = nobj / float(np.log(2.) * logdet_factor)
451 | 
452 |         return z, nll, y_logits
453 | 
454 |     def reverse_flow(self, z, y_onehot, eps_std=None):
455 |         """
456 |         Reverse flow
457 | 
458 |         :param z: latent vector
459 |         :type z: torch.Tensor
460 |         :param y_onehot: one-hot vector of label
461 |         :type y_onehot: torch.Tensor
462 |         :param eps_std: standard deviation of eps
463 |         :type eps_std: float
464 |         """
465 |         with torch.no_grad():
466 |             mean, logs = self.prior(y_onehot)
467 |             if z is None:
468 |                 z = module.GaussianDiag.sample(mean, logs, eps_std)
469 |             x = self.flow(z, eps_std=eps_std, reverse=True)
470 |             # x = self.postprocess(x)
471 |             return x
472 | 
473 |     def forward(self,
474 |                 x=None, y_onehot=None,
475 |                 z=None, eps_std=None,
476 |                 reverse=False):
477 |         """
478 |         Forward Glow model
479 | 
480 |         :param x: input tensor
481 |         :type x: torch.Tensor
482 |         :param y_onehot: one-hot vector of label
483 |         :type y_onehot: torch.Tensor
484 |         :param z: latent vector
485 |         :type z: torch.Tensor
486 |         :param eps_std: standard deviation of eps
487 |         :type eps_std: float
488 |         :param reverse: whether to reverse flow
489 |         :type reverse: bool
490 |         """
491 |         if not reverse:
492 |             return self.normal_flow(x, y_onehot)
493 |         else:
494 |             return self.reverse_flow(z, y_onehot, eps_std)
495 | 
496 |     @staticmethod
497 |     def generative_loss(nll):
498 |         """
499 |         Loss for generation
500 | 
501 |         :param nll: negative log-likelihood
502 |         :type nll: torch.Tensor
503 |         :return: generative loss
504 |         :rtype: torch.Tensor
505 |         """
506 |         return torch.mean(nll)
507 | 
508 |     @staticmethod
509 |     def single_class_loss(y_logits, y):
510 |         """
511 |         Classification loss for single target class problem
512 | 
513 |         :param y_logits: prediction in the shape of (N, Classes)
514 |         :type y_logits: torch.Tensor
515 |         :param y: target in the shape of (N)
516 |         :type y: torch.Tensor
517 |         :return: classification loss
518 |         :rtype: torch.Tensor
519 |         """
520 |         if y_logits is None:
521 |             return 0
522 |         return Glow.ce_criterion(y_logits, y.long())
523 | 
524 |     @staticmethod
525 |     def multi_class_loss(y_logits, y_onehot):
526 |         """
527 |         Classification loss for multiple target class problem
528 | 
529 |         :param y_logits: prediction in the shape of (N, Classes)
530 |         :type y_logits: torch.Tensor
531 |         :param y_onehot: one-hot target vector in the shape of (N, Classes)
532 |         :type y_onehot: torch.Tensor
533 |         :return: classification loss
534 |         :rtype: torch.Tensor
535 |         """
536 |         if y_logits is None:
537 |             return 0
538 |         return Glow.bce_criterion(y_logits, y_onehot.float())
539 | 
540 |     def set_actnorm_inited(self, inited=True):
541 |         """
542 |         Mark bias and logs of every ActNorm layer as initialized
543 | 
544 |         :param inited: initialization state
545 |         :type inited: bool
546 |         """
547 |         for name, m in self.named_modules():
548 |             if m.__class__.__name__.find("ActNorm") >= 0:
549 |                 m.bias_inited = inited
550 |                 m.logs_inited = inited
551 | 
--------------------------------------------------------------------------------
/network/module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from misc import ops
7 | 
8 | 
9 | class ActNorm(nn.Module):
10 |     def __init__(self, num_channels, scale=1., logscale_factor=3., batch_variance=False):
11 |         """
12 |         Activation normalization layer
13 | 
14 |         :param num_channels: number of channels
15 |         :type num_channels: int
16 |         :param scale: initial scale
17 |         :type scale: float
18 |         :param logscale_factor: factor for logscale
19 |         :type logscale_factor: float
20 |         :param batch_variance: use batch variance
21 |         :type batch_variance: bool
22 |         """
23 |         super().__init__()
24 |         self.num_channels = num_channels
25 |         self.scale = scale
26 |         self.logscale_factor = logscale_factor
27 |         self.batch_variance = batch_variance
28 | 
29 |         self.bias_inited = False
30 |         self.logs_inited = False
31 |         self.register_parameter('bias', nn.Parameter(torch.zeros(1, self.num_channels, 1, 1)))
32 |         self.register_parameter('logs', nn.Parameter(torch.zeros(1, self.num_channels, 1, 1)))
33 | 
34 |     def actnorm_center(self, x, reverse=False):
35 |         """
36 |         center operation of activation normalization
37 | 
38 |         :param x: input
39 |         :type x: torch.Tensor
40 |         :param reverse: whether to reverse bias
41 |         :type reverse: bool
42 |         :return: centered input
43 |         :rtype: torch.Tensor
44 |         """
45 |         if not self.bias_inited:
46 |             self.initialize_bias(x)
47 |         if not reverse:
48 |             return x + self.bias
49 |         else:
50 |             return x - self.bias
51 | 
52 |     def actnorm_scale(self, x, logdet, reverse=False):
53 |         """
54 |         scale operation of activation normalization
55 | 
56 |         :param x: input
57 |         :type x: torch.Tensor
58 |         :param logdet: log determinant
59 |         :type logdet: torch.Tensor
60 |         :param reverse: whether to reverse scaling
61 |         :type reverse: bool
62 |         :return: scaled input and logdet
63 |         :rtype: tuple(torch.Tensor, torch.Tensor)
64 |         """
65 | 
66 |         if not self.logs_inited:
67 |             self.initialize_logs(x)
68 | 
69 |         # TODO condition for non 4-dims input
70 |         logs = self.logs * self.logscale_factor
71 | 
72 |         if not reverse:
73 |             x *= torch.exp(logs)
74 |         else:
75 |             x *= torch.exp(-logs)
76 | 
77 |         if logdet is not None:
78 |             logdet_factor = ops.count_pixels(x)  # H * W
79 |             dlogdet = torch.sum(logs) * logdet_factor
80 |             if reverse:
81 |                 dlogdet *= -1
82 |             logdet += dlogdet
83 | 
84 |         return x, logdet
85 | 
86 |     def initialize_bias(self, x):
87 |         """
88 |         Initialize bias
89 | 
90 |         :param x: input
91 |         :type x: torch.Tensor
92 |         """
93 |         if not self.training:
94 |             return
95 |         with torch.no_grad():
96 |             # Compute initial value
97 |             x_mean = -1. 
* ops.reduce_mean(x, dim=[0, 2, 3], keepdim=True) 98 | # Copy to parameters 99 | self.bias.data.copy_(x_mean.data) 100 | self.bias_inited = True 101 | 102 | def initialize_logs(self, x): 103 | """ 104 | Initialize logs 105 | 106 | :param x: input 107 | :type x: torch.Tensor 108 | """ 109 | if not self.training: 110 | return 111 | with torch.no_grad(): 112 | if self.batch_variance: 113 | x_var = ops.reduce_mean(x ** 2, keepdim=True) 114 | else: 115 | x_var = ops.reduce_mean(x ** 2, dim=[0, 2, 3], keepdim=True) 116 | logs = torch.log(self.scale / (torch.sqrt(x_var) + 1e-6)) / self.logscale_factor 117 | 118 | # Copy to parameters 119 | self.logs.data.copy_(logs.data) 120 | self.logs_inited = True 121 | 122 | def forward(self, x, logdet=None, reverse=False): 123 | """ 124 | Forward activation normalization layer 125 | 126 | :param x: input 127 | :type x: torch.Tensor 128 | :param logdet: log determinant 129 | :type logdet: 130 | :param reverse: whether to reverse bias 131 | :type reverse: bool 132 | :return: normalized input and logdet 133 | :rtype: tuple(torch.Tensor, torch.Tensor) 134 | """ 135 | assert len(x.shape) == 4 136 | assert x.shape[1] == self.num_channels, \ 137 | 'Input shape should be NxCxHxW, however channels are {} instead of {}'.format(x.shape[1], self.num_channels) 138 | assert x.device == self.bias.device and x.device == self.logs.device, \ 139 | 'Expect input device {} instead of {}'.format(self.bias.device, x.device) 140 | 141 | if not reverse: 142 | # center and scale 143 | x = self.actnorm_center(x, reverse=False) 144 | x, logdet = self.actnorm_scale(x, logdet, reverse=False) 145 | else: 146 | # scale and center 147 | x, logdet = self.actnorm_scale(x, logdet, reverse=True) 148 | x = self.actnorm_center(x, reverse=True) 149 | return x, logdet 150 | 151 | 152 | class LinearZeros(nn.Linear): 153 | def __init__(self, in_features, out_features, bias=True, logscale_factor=3.): 154 | """ 155 | Linear layer with zero initialization 156 | 157 | :param in_features: size of each input sample 158 | :type in_features: int 159 | :param out_features: size of each output sample 160 | :type out_features: int 161 | :param bias: whether to learn an additive bias. 162 | :type bias: bool 163 | :param logscale_factor: factor of logscale 164 | :type logscale_factor: float 165 | """ 166 | super().__init__(in_features, out_features, bias) 167 | self.logscale_factor = logscale_factor 168 | # zero initialization 169 | self.weight.data.zero_() 170 | self.bias.data.zero_() 171 | # register parameter 172 | self.register_parameter('logs', nn.Parameter(torch.zeros(out_features))) 173 | 174 | def forward(self, x): 175 | """ 176 | Forward linear zero layer 177 | 178 | :param x: input 179 | :type x: torch.Tensor 180 | :return: output 181 | :rtype: torch.Tensor 182 | """ 183 | output = super().forward(x) 184 | output *= torch.exp(self.logs * self.logscale_factor) 185 | return output 186 | 187 | 188 | class Conv2d(nn.Conv2d): 189 | @staticmethod 190 | def get_padding(padding_type, kernel_size, stride): 191 | """ 192 | Get padding size. 
193 | 
194 |         mentioned in https://github.com/pytorch/pytorch/issues/3867#issuecomment-361775080
195 |         behaves as 'SAME' padding in TensorFlow
196 |         independent of input size when stride is 1
197 | 
198 |         :param padding_type: type of padding in ['SAME', 'VALID']
199 |         :type padding_type: str
200 |         :param kernel_size: kernel size
201 |         :type kernel_size: tuple(int) or int
202 |         :param stride: stride
203 |         :type stride: int
204 |         :return: padding size
205 |         :rtype: tuple(int)
206 |         """
207 |         assert padding_type in ['SAME', 'VALID'], "Unsupported padding type: {}".format(padding_type)
208 |         if isinstance(kernel_size, int):
209 |             kernel_size = [kernel_size, kernel_size]
210 |         if padding_type == 'SAME':
211 |             assert stride == 1, "'SAME' padding only supports stride=1"
212 |             return tuple((k - 1) // 2 for k in kernel_size)
213 |         return tuple(0 for _ in kernel_size)
214 | 
215 |     def __init__(self, in_channels, out_channels,
216 |                  kernel_size=(3, 3), stride=1, padding_type='SAME',
217 |                  do_weightnorm=False, do_actnorm=True,
218 |                  dilation=1, groups=1):
219 |         """
220 |         Wrapper of nn.Conv2d with weight normalization and activation normalization
221 | 
222 |         :param padding_type: type of padding in ['SAME', 'VALID']
223 |         :type padding_type: str
224 |         :param do_weightnorm: whether to do weight normalization after convolution
225 |         :type do_weightnorm: bool
226 |         :param do_actnorm: whether to do activation normalization after convolution
227 |         :type do_actnorm: bool
228 |         """
229 |         padding = self.get_padding(padding_type, kernel_size, stride)
230 |         super().__init__(in_channels, out_channels,
231 |                          kernel_size, stride, padding,
232 |                          dilation, groups,
233 |                          bias=(not do_actnorm))
234 |         self.do_weight_norm = do_weightnorm
235 |         self.do_actnorm = do_actnorm
236 | 
237 |         self.weight.data.normal_(mean=0.0, std=0.05)
238 |         if self.do_actnorm:
239 |             self.actnorm = ActNorm(out_channels)
240 |         else:
241 |             self.bias.data.zero_()
242 | 
243 |     def forward(self, x):
244 |         """
245 |         Forward wrapped Conv2d layer
246 | 
247 |         :param x: input
248 |         :type x: torch.Tensor
249 |         :return: output
250 |         :rtype: torch.Tensor
251 |         """
252 |         x = super().forward(x)
253 |         # if self.do_weight_norm:
254 |         #     # normalize N, H and W dims
255 |         #     x = F.normalize(x, p=2, dim=0)
256 |         #     x = F.normalize(x, p=2, dim=2)
257 |         #     x = F.normalize(x, p=2, dim=3)
258 |         if self.do_actnorm:
259 |             x, _ = self.actnorm(x)
260 |         return x
261 | 
262 | 
263 | class Conv2dZeros(nn.Conv2d):
264 | 
265 |     def __init__(self, in_channels, out_channels,
266 |                  kernel_size=(3, 3), stride=1, padding_type='SAME',
267 |                  logscale_factor=3,
268 |                  dilation=1, groups=1, bias=True):
269 |         """
270 |         Wrapper of nn.Conv2d with zero initialization and logs
271 | 
272 |         :param padding_type: type of padding in ['SAME', 'VALID']
273 |         :type padding_type: str
274 |         :param logscale_factor: factor for logscale
275 |         :type logscale_factor: float
276 |         """
277 |         padding = Conv2d.get_padding(padding_type, kernel_size, stride)
278 |         super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
279 | 
280 |         self.logscale_factor = logscale_factor
281 |         # initialize variables with zero
282 |         self.bias.data.zero_()
283 |         self.weight.data.zero_()
284 |         self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
285 | 
286 |     def forward(self, x):
287 |         """
288 |         Forward wrapped Conv2d layer
289 | 
290 |         :param x: input
291 |         :type x: torch.Tensor
292 |         :return: output
293 |         :rtype: torch.Tensor
294 |         """
295 |         x = super().forward(x)
296 |         x *= torch.exp(self.logs * self.logscale_factor)
297 |         return x
298 | 
299 | 
300 | def f(in_channels, hidden_channels, out_channels):
301 |     """
302 |     Convolution block
303 | 
304 |     :param in_channels: number of input channels
305 |     :type in_channels: int
306 |     :param hidden_channels: number of hidden channels
307 |     :type hidden_channels: int
308 |     :param out_channels: number of output channels
309 |     :type out_channels: int
310 |     :return: desired convolution block
311 |     :rtype: nn.Module
312 |     """
313 |     return nn.Sequential(
314 |         Conv2d(in_channels, hidden_channels),
315 |         nn.ReLU(inplace=True),
316 |         Conv2d(hidden_channels, hidden_channels, kernel_size=1),
317 |         nn.ReLU(inplace=True),
318 |         Conv2dZeros(hidden_channels, out_channels)
319 |     )
320 | 
321 | 
322 | class Invertible1x1Conv(nn.Module):
323 | 
324 |     def __init__(self, num_channels, lu_decomposition=False):
325 |         """
326 |         Invertible 1x1 convolution layer
327 | 
328 |         :param num_channels: number of channels
329 |         :type num_channels: int
330 |         :param lu_decomposition: whether to use LU decomposition
331 |         :type lu_decomposition: bool
332 |         """
333 |         super().__init__()
334 |         self.num_channels = num_channels
335 |         self.lu_decomposition = lu_decomposition
336 |         if self.lu_decomposition:
337 |             raise NotImplementedError()
338 |         else:
339 |             w_shape = [num_channels, num_channels]
340 |             # Sample a random orthogonal matrix
341 |             w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
342 |             self.register_parameter('weight', nn.Parameter(torch.Tensor(w_init)))
343 | 
344 |     def forward(self, x, logdet=None, reverse=False):
345 |         """
346 |         Forward invertible 1x1 convolution layer
347 |         :param x: input
348 |         :type x: torch.Tensor
349 |         :param logdet: log determinant
350 |         :type logdet: torch.Tensor
351 |         :param reverse: whether to reverse convolution
352 |         :type reverse: bool
353 |         :return: output and logdet
354 |         :rtype: tuple(torch.Tensor, torch.Tensor)
355 |         """
356 |         logdet_factor = ops.count_pixels(x)  # H * W
357 |         dlogdet = torch.log(torch.abs(torch.det(self.weight))) * logdet_factor
358 |         if not reverse:
359 |             weight = self.weight.view(*self.weight.shape, 1, 1)
360 |             z = F.conv2d(x, weight)
361 |             if logdet is not None:
362 |                 logdet = logdet + dlogdet
363 |             return z, logdet
364 |         else:
365 |             weight = self.weight.inverse().view(*self.weight.shape, 1, 1)
366 |             z = F.conv2d(x, weight)
367 |             if logdet is not None:
368 |                 logdet = logdet - dlogdet
369 |             return z, logdet
370 | 
371 | 
372 | class Permutation2d(nn.Module):
373 | 
374 |     def __init__(self, num_channels, shuffle=False):
375 |         """
376 |         Perform permutation on channel dimension
377 | 
378 |         :param num_channels: number of channels
379 |         :type num_channels: int
380 |         :param shuffle: whether to shuffle channels randomly (otherwise reverse their order)
381 |         :type shuffle: bool
382 |         """
383 |         super().__init__()
384 |         self.num_channels = num_channels
385 |         self.indices = np.arange(self.num_channels - 1, -1, -1, dtype=np.int64)
386 |         if shuffle:
387 |             np.random.shuffle(self.indices)
388 |         self.indices_inverse = np.zeros(self.num_channels, dtype=np.int64)
389 |         for i in range(self.num_channels):
390 |             self.indices_inverse[self.indices[i]] = i
391 | 
392 |     def forward(self, x, reverse=False):
393 |         assert len(x.shape) == 4
394 |         if not reverse:
395 |             return x[:, self.indices, :, :]
396 |         else:
397 |             return x[:, self.indices_inverse, :, :]
398 | 
399 | 
400 | class GaussianDiag:
401 |     """
402 |     Helpers for diagonal Gaussian distributions
403 |     """
404 | 
405 |     log_2pi = float(np.log(2 * np.pi))
406 | 
407 |     @staticmethod
408 |     def eps(shape_tensor, eps_std=None):
409 |         """
410 |         Returns a tensor filled with random numbers from a standard normal distribution
411 | 
412 |         :param shape_tensor: tensor whose shape and device the output matches
413 |         :type shape_tensor: torch.Tensor
414 |         :param eps_std: standard deviation of eps
415 |         :type eps_std: float
416 |         :return: a tensor filled with random numbers from a standard normal distribution
417 |         :rtype: torch.Tensor
418 |         """
419 |         eps_std = eps_std or 1.
420 |         return torch.normal(mean=torch.zeros_like(shape_tensor),
421 |                             std=torch.ones_like(shape_tensor) * eps_std)
422 | 
423 |     @staticmethod
424 |     def flatten_sum(tensor):
425 |         """
426 |         Sum over all dimensions except the first (batch) dimension
427 | 
428 |         :param tensor: input tensor
429 |         :type tensor: torch.Tensor
430 |         :return: summed tensor
431 |         :rtype: torch.Tensor
432 |         """
433 |         assert len(tensor.shape) == 4
434 |         return ops.reduce_sum(tensor, dim=[1, 2, 3])
435 | 
436 |     @staticmethod
437 |     def logps(mean, logs, x):
438 |         """
439 |         Element-wise log-likelihood
440 | 
441 |         :param mean: mean of the Gaussian
442 |         :type mean: torch.Tensor
443 |         :param logs: log standard deviation of the Gaussian
444 |         :type logs: torch.Tensor
445 |         :param x: input tensor
446 |         :type x: torch.Tensor
447 |         :return: log-likelihood
448 |         :rtype: torch.Tensor
449 |         """
450 |         return -0.5 * (GaussianDiag.log_2pi + 2. * logs + ((x - mean) ** 2) / torch.exp(2. * logs))
451 | 
452 |     @staticmethod
453 |     def logp(mean, logs, x):
454 |         """
455 |         Log-likelihood summed over all dimensions except the batch dimension
456 | 
457 |         :param mean: mean of the Gaussian
458 |         :type mean: torch.Tensor
459 |         :param logs: log standard deviation of the Gaussian
460 |         :type logs: torch.Tensor
461 |         :param x: input tensor
462 |         :type x: torch.Tensor
463 |         :return: summed log-likelihood
464 |         :rtype: torch.Tensor
465 |         """
466 |         s = GaussianDiag.logps(mean, logs, x)
467 |         return GaussianDiag.flatten_sum(s)
468 | 
469 |     @staticmethod
470 |     def sample(mean, logs, eps_std=None):
471 |         """
472 |         Generate sample
473 | 
474 |         :param mean: mean of the Gaussian
475 |         :param logs: log standard deviation of the Gaussian
476 |         :type logs: torch.Tensor
477 |         :param eps_std: standard deviation of eps
478 |         :type eps_std: float
479 |         :return: sample
480 |         :rtype: torch.Tensor
481 |         """
482 |         eps = GaussianDiag.eps(mean, eps_std)
483 |         return mean + torch.exp(logs) * eps
484 | 
485 | 
486 | class Split2d(nn.Module):
487 |     def __init__(self, num_channels):
488 |         """
489 |         Split2d layer
490 | 
491 |         :param num_channels: number of channels
492 |         :type num_channels: int
493 |         """
494 |         super().__init__()
495 |         self.num_channels = num_channels
496 |         self.conv2d_zeros = Conv2dZeros(num_channels // 2, num_channels)
497 | 
498 |     def prior(self, z):
499 |         """
500 |         Compute mean and logs of the prior from z
501 | 
502 |         :param z: input tensor
503 |         :type z: torch.Tensor
504 |         :return: mean and logs
505 |         :rtype: tuple(torch.Tensor, torch.Tensor)
506 |         """
507 |         h = self.conv2d_zeros(z)
508 |         mean, logs = ops.split_channel(h, 'cross')
509 |         return mean, logs
510 | 
511 |     def forward(self, x, logdet=0., reverse=False, eps_std=None):
512 |         """
513 |         Forward Split2d layer
514 | 
515 |         :param x: input tensor
516 |         :type x: torch.Tensor
517 |         :param logdet: log determinant
518 |         :type logdet: float
519 |         :param reverse: whether to reverse flow
520 |         :type reverse: bool
521 |         :param eps_std: standard deviation of eps
522 |         :type eps_std: float
523 |         :return: output and logdet
524 |         :rtype: tuple(torch.Tensor, torch.Tensor)
525 |         """
526 |         if not reverse:
527 |             z1, z2 = ops.split_channel(x, 'simple')
528 |             mean, logs = self.prior(z1)
529 |             logdet = GaussianDiag.logp(mean, logs, z2) + logdet
530 |             return z1, logdet
531 |         else:
532 |             z1 = x
533 |             mean, logs = self.prior(z1)
534 |             z2 = GaussianDiag.sample(mean, logs, eps_std)
535 |             z = ops.cat_channel(z1, z2)
536 |             return z, logdet
537 | 
538 | 
539 | class Squeeze2d(nn.Module):
540 |     def __init__(self, factor=2):
541 |         """
542 |         Squeeze2d layer
543 | 
544 |         :param factor: squeeze factor
545 |         :type factor: int
546 |         """
547 |         super().__init__()
548 |         self.factor = factor
549 | 
550 |     @staticmethod
551 |     def unsqueeze(x, factor=2):
552 |         """
553 |         Unsqueeze tensor
554 | 
555 |         :param x: input tensor
556 |         :type x: torch.Tensor
557 |         :param factor: unsqueeze factor
558 |         :type factor: int
559 |         :return: unsqueezed tensor
560 |         :rtype: torch.Tensor
561 |         """
562 |         assert factor >= 1
563 |         if factor == 1:
564 |             return x
565 |         _, nc, nh, nw = x.shape
566 |         assert nc >= factor ** 2 and nc % factor ** 2 == 0
567 |         x = x.view(-1, nc // factor ** 2, factor, factor, nh, nw)
568 |         x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
569 |         x = x.view(-1, nc // factor ** 2, nh * factor, nw * factor)
570 |         return x
571 | 
572 |     @staticmethod
573 |     def squeeze(x, factor=2):
574 |         """
575 |         Squeeze tensor
576 | 
577 |         :param x: input tensor
578 |         :type x: torch.Tensor
579 |         :param factor: squeeze factor
580 |         :type factor: int
581 |         :return: squeezed tensor
582 |         :rtype: torch.Tensor
583 |         """
584 |         assert factor >= 1
585 |         if factor == 1:
586 |             return x
587 |         _, nc, nh, nw = x.shape
588 |         assert nh % factor == 0 and nw % factor == 0
589 |         x = x.view(-1, nc, nh // factor, factor, nw // factor, factor)
590 |         x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
591 |         x = x.view(-1, nc * factor * factor, nh // factor, nw // factor)
592 |         return x
593 | 
594 |     def forward(self, x, logdet=None, reverse=False):
595 |         """
596 |         Forward Squeeze2d layer
597 | 
598 |         :param x: input tensor
599 |         :type x: torch.Tensor
600 |         :param logdet: log determinant
601 |         :type logdet: torch.Tensor
602 |         :param reverse: whether to reverse flow
603 |         :type reverse: bool
604 |         :return: output and logdet
605 |         :rtype: tuple(torch.Tensor, torch.Tensor)
606 |         """
607 |         if not reverse:
608 |             output = self.squeeze(x, self.factor)
609 |         else:
610 |             output = self.unsqueeze(x, self.factor)
611 | 
612 |         return output, logdet
613 | 
--------------------------------------------------------------------------------
/network/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | 
5 | from tqdm import tqdm
6 | from tensorboardX import SummaryWriter
7 | from torch.utils.data import DataLoader
8 | 
9 | from misc import util, ops
10 | from network.model import Glow
11 | 
12 | 
13 | class Trainer:
14 |     criterion_dict = {
15 |         'single_class': lambda y_logits, y: Glow.single_class_loss(y_logits, y),
16 |         'multi_class': lambda y_logits, y_onehot: Glow.multi_class_loss(y_logits, y_onehot)
17 |     }
18 | 
19 |     def __init__(self, hps, result_subdir,
20 |                  step, graph, optimizer, scheduler, devices,
21 |                  dataset, data_device):
22 |         """
23 |         Network trainer
24 | 
25 |         :param hps: hyper-parameters for this network
26 |         :type hps: dict
27 |         :param result_subdir: path to result sub-directory
28 |         :type result_subdir: str
29 |         :param step: global step of model
30 |         :type step: int
31 |         :param graph: model graph
32 |         :type graph: torch.nn.Module
33 |         :param optimizer: optimizer
34 |         :type optimizer: torch.optim.Optimizer
35 |         :param scheduler: learning rate scheduler
36 |         :type scheduler: function
37 |         :param devices: list of usable devices for model running
38 |         :type devices: list
39 |         :param dataset: dataset for training model
40 |         :type dataset: torch.utils.data.Dataset
41 |         :param data_device: device for loading batch data
42 |         :type data_device: str or torch.device
43 |         """
44 |         # general
45 |         self.hps = hps
46 |         self.result_subdir = result_subdir
47 |         self.start_time = time.time()
48 |         # state
49 |         self.step 
= step 50 | self.graph = graph 51 | self.optimizer = optimizer 52 | self.scheduler = scheduler 53 | self.devices = devices 54 | # data 55 | self.data_device = data_device 56 | self.batch_size = self.hps.optim.num_batch_train 57 | self.num_classes = self.hps.dataset.num_classes 58 | self.data_loader = DataLoader(dataset, batch_size=self.batch_size, 59 | num_workers=self.hps.dataset.num_workers, 60 | shuffle=True, 61 | drop_last=True) 62 | self.num_epochs = (self.hps.optim.num_epochs + len(self.data_loader) - 1) // len(self.data_loader) 63 | # ablation 64 | self.y_condition = self.hps.ablation.y_condition 65 | if self.y_condition: 66 | self.y_criterion = self.hps.ablation.y_criterion 67 | assert self.y_criterion in self.criterion_dict.keys(), "Unsupported criterion: {}".format(self.y_criterion) 68 | self.max_grad_clip = self.hps.ablation.max_grad_clip 69 | self.max_grad_norm = self.hps.ablation.max_grad_norm 70 | # logging 71 | self.writer = SummaryWriter(log_dir=self.result_subdir) 72 | self.interval_scalar = self.hps.optim.interval_scalar 73 | self.interval_snapshot = self.hps.optim.interval_snapshot 74 | self.interval_valid = self.hps.optim.interval_valid 75 | self.interval_sample = self.hps.optim.interval_sample 76 | self.num_sample = self.hps.optim.num_sample 77 | 78 | def train(self): 79 | """ 80 | Train network 81 | """ 82 | self.graph.train() 83 | 84 | for epoch in range(self.num_epochs): 85 | print('[Trainer] Epoch ({}/{})'.format(epoch, self.num_epochs)) 86 | progress = tqdm(self.data_loader) 87 | for idx, batch in enumerate(progress): 88 | # update learning rate 89 | lr = self.scheduler(global_step=self.step) 90 | for param_group in self.optimizer.param_groups: 91 | param_group['lr'] = lr 92 | self.optimizer.zero_grad() 93 | if self.step % self.interval_scalar == 0 and self.step > 0: 94 | self.writer.add_scalar('lr/lr', lr, self.step) 95 | 96 | # extract batch data 97 | for i in batch: 98 | batch[i] = batch[i].to(self.data_device) 99 | x = batch['x'] 100 | y = None 101 | y_onehot = None 102 | if self.y_condition: 103 | if self.y_criterion == 'single_class': 104 | assert 'y' in batch.keys(), 'Single-class criterion needs "y" in batch data' 105 | y = batch['y'] 106 | y_onehot = ops.onehot(y, self.num_classes) 107 | else: 108 | assert 'y_onehot' in batch.keys(), 'Multi-class criterion needs "y_onehot" in batch data' 109 | y_onehot = batch['y_onehot'] 110 | 111 | # initialize actnorm layer at first 112 | if self.step == 0: 113 | self.graph(x=x[:self.batch_size // len(self.devices), ...], 114 | y_onehot=y_onehot[:self.batch_size // len(self.devices), ...] 
115 | if y_onehot is not None else None) 116 | # data parallel 117 | if len(self.devices) > 1 and not hasattr(self.graph, 'module'): 118 | self.graph = torch.nn.parallel.DataParallel(module=self.graph, 119 | device_ids=self.devices, 120 | output_device=self.devices[0]) 121 | 122 | # forward model 123 | z, nll, y_logits = self.graph(x=x, y_onehot=y_onehot) 124 | 125 | # compute loss 126 | generative_loss = Glow.generative_loss(nll) 127 | classification_loss = 0 128 | if self.y_condition: 129 | classification_loss = self.criterion_dict[self.y_criterion](y_logits, 130 | y if self.y_criterion == 'single_class' else y_onehot) 131 | loss = generative_loss + classification_loss * self.hps.model.weight_y 132 | if self.step % self.interval_scalar == 0 and self.step > 0: 133 | self.writer.add_scalar('loss/generative_loss', generative_loss, self.step) 134 | if self.y_condition: 135 | self.writer.add_scalar('loss/classification_loss', classification_loss, self.step) 136 | 137 | # backward model 138 | self.graph.zero_grad() 139 | self.optimizer.zero_grad() 140 | loss.backward() 141 | # gradient operation 142 | if self.max_grad_clip is not None and self.max_grad_clip > 0: 143 | torch.nn.utils.clip_grad_value_(self.graph.parameters(), self.max_grad_clip) 144 | if self.max_grad_norm is not None and self.max_grad_norm > 0: 145 | grad_norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), self.max_grad_norm) 146 | if self.step % self.interval_scalar == 0 and self.step > 0: 147 | self.writer.add_scalar("grad_norm/grad_norm", grad_norm, self.step) 148 | 149 | # optimize 150 | self.optimizer.step() 151 | 152 | # snapshot 153 | if self.step % self.interval_snapshot == 0 and self.step > 0: 154 | util.save_model(result_subdir=self.result_subdir, 155 | step=self.step, 156 | graph=self.graph, 157 | optimizer=self.optimizer, 158 | seconds=time.time() - self.start_time, 159 | is_best=True) 160 | 161 | # valid 162 | if self.step % self.interval_valid == 0 and self.step > 0: 163 | img = self.graph(z=z, y_onehot=y_onehot, reverse=True) 164 | for i in range(min(self.num_sample, img.shape[0])): 165 | self.writer.add_image("reconstructed/{}".format(i), 166 | ops.cat_channel(img[i], batch["x"][i]), 167 | self.step) 168 | 169 | # sample 170 | if self.step % self.interval_sample == 0 and self.step > 0: 171 | img = self.graph(z=None, y_onehot=y_onehot, eps_std=0.5, reverse=True) 172 | for i in range(min(self.num_sample, img.shape[0])): 173 | self.writer.add_image("sample/{}".format(i), 174 | img[i], self.step) 175 | 176 | self.step += 1 177 | 178 | self.writer.export_scalars_to_json(os.path.join(self.result_subdir, "all_scalars.json")) 179 | self.writer.close() 180 | -------------------------------------------------------------------------------- /profile/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "profile": "celeba_64x64_8bit", 3 | "general": { 4 | "verbose": false, 5 | "result_dir": "/Data/glow", 6 | "warm_start": false, 7 | "pre_trained": "", 8 | "resume_run_id": 1, 9 | "resume_step": "latest" 10 | }, 11 | "dataset": { 12 | "problem": "celeba", 13 | "root": "/Data/CelebA", 14 | "num_classes": 40, 15 | "num_workers": 8, 16 | "argument": "standard" 17 | }, 18 | "optim": { 19 | "num_epochs": 1000000, 20 | "num_train": 50000, 21 | "num_test": -1, 22 | "num_sample": 4, 23 | "interval_scalar": 10, 24 | "interval_snapshot": 5000, 25 | "interval_valid": 10, 26 | "interval_sample": 10, 27 | "num_batch_train": 50, 28 | "num_batch_test": 50, 29 | "num_batch_init": 
256,
30 |     "optimizer": "adam",
31 |     "optimizer_args": {
32 |       "lr": 1e-3,
33 |       "betas": [
34 |         0.9,
35 |         0.9999
36 |       ],
37 |       "eps": 1e-8,
38 |       "weight_decay": 0
39 |     },
40 |     "lr_scheduler": "noam",
41 |     "lr_scheduler_args": {
42 |       "warmup_steps": 4000,
43 |       "min_lr": 1e-4
44 |     },
45 |     "gradient_checkpointing": true
46 |   },
47 |   "model": {
48 |     "image_shape": [
49 |       64,
50 |       64,
51 |       3
52 |     ],
53 |     "anchor_size": 32,
54 |     "hidden_channels": 512,
55 |     "actnorm_scale": 1.0,
56 |     "K": 32,
57 |     "L": 3,
58 |     "weight_y": 0.0,
59 |     "n_bits_x": 8
60 |   },
61 |   "ablation": {
62 |     "learn_top": false,
63 |     "y_condition": false,
64 |     "y_criterion": "multi_class",
65 |     "lu_decomposition": false,
66 |     "seed": 2384,
67 |     "flow_permutation": "invconv",
68 |     "flow_coupling": "affine",
69 |     "max_grad_clip": 5,
70 |     "max_grad_norm": 100
71 |   },
72 |   "device": {
73 |     "graph": [
74 |       "cuda:0",
75 |       "cuda:1"
76 |     ],
77 |     "data": [
78 |       "cuda:0"
79 |     ]
80 |   }
81 | }
82 | 
83 | 
--------------------------------------------------------------------------------
/profile/test.json:
--------------------------------------------------------------------------------
1 | {
2 |   "profile": "celebahq_256x256_5bit",
3 |   "general": {
4 |     "verbose": false,
5 |     "result_dir": ".",
6 |     "warm_start": true,
7 |     "pre_trained": "",
8 |     "resume_run_id": 0,
9 |     "resume_step": "latest"
10 |   },
11 |   "dataset": {
12 |     "problem": "celeba",
13 |     "root": "/Data/CelebA",
14 |     "num_classes": 1,
15 |     "num_workers": 8,
16 |     "argument": "standard"
17 |   },
18 |   "optim": {
19 |     "num_epochs": 1000000,
20 |     "num_train": 50000,
21 |     "num_test": -1,
22 |     "num_sample": 1,
23 |     "interval_scalar": 50,
24 |     "interval_snapshot": 50,
25 |     "interval_valid": 50,
26 |     "interval_sample": 50,
27 |     "num_batch_train": 16,
28 |     "num_batch_test": 50,
29 |     "num_batch_init": 256,
30 |     "optimizer": "adamax",
31 |     "optimizer_args": {
32 |       "lr": 1e-3,
33 |       "betas": [
34 |         0.9,
35 |         0.99
36 |       ],
37 |       "eps": 1e-8,
38 |       "weight_decay": 0
39 |     },
40 |     "lr_scheduler": "noam",
41 |     "lr_scheduler_args": {
42 |       "warmup_steps": 4000
43 |     },
44 |     "gradient_checkpointing": true
45 |   },
46 |   "model": {
47 |     "image_shape": [
48 |       64,
49 |       64,
50 |       3
51 |     ],
52 |     "anchor_size": 32,
53 |     "hidden_channels": 512,
54 |     "actnorm_scale": 1.0,
55 |     "K": 32,
56 |     "L": 3,
57 |     "weight_y": 0.0,
58 |     "n_bits_x": 8
59 |   },
60 |   "ablation": {
61 |     "learn_top": false,
62 |     "y_condition": false,
63 |     "y_criterion": "",
64 |     "lu_decomposition": false,
65 |     "seed": 0,
66 |     "flow_permutation": "invconv",
67 |     "flow_coupling": "additive",
68 |     "max_grad_clip": 5,
69 |     "max_grad_norm": 100
70 |   },
71 |   "device": {
72 |     "graph": [
73 |       "cuda:0",
74 |       "cuda:1"
75 |     ],
76 |     "data": [
77 |       "cpu"
78 |     ]
79 |   }
80 | }
81 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | click
3 | numpy
4 | torch
5 | easydict
6 | torchvision
7 | tensorboardX
--------------------------------------------------------------------------------
/result/interpolated_Attractive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/result/interpolated_Attractive.png
--------------------------------------------------------------------------------
/result/interpolated_Black_Hair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/result/interpolated_Black_Hair.png -------------------------------------------------------------------------------- /result/interpolated_Blurry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/result/interpolated_Blurry.png -------------------------------------------------------------------------------- /result/interpolated_Mouth_Slightly_Open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/result/interpolated_Mouth_Slightly_Open.png -------------------------------------------------------------------------------- /result/reconstructed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/result/reconstructed.png -------------------------------------------------------------------------------- /setting.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/setting.py -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corenel/pytorch-glow/dfd6db37da093d7741585ce1d046b2d3c3ac99bd/test/__init__.py -------------------------------------------------------------------------------- /test/test_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from dataset import CelebA 4 | 5 | 6 | class TestDataset(unittest.TestCase): 7 | def test_celeba(self): 8 | dataset = CelebA(root='/Data/CelebA') 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /test/test_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import unittest 4 | 5 | from PIL import Image 6 | 7 | from network.model import FlowStep, FlowModel, Glow 8 | from misc import ops, util 9 | 10 | 11 | class TestModel(unittest.TestCase): 12 | def test_flow_step(self): 13 | flow_permutation = ['invconv', 'reverse', 'shuffle'] 14 | flow_coupling = ['additive', 'affine'] 15 | 16 | for permutation in flow_permutation: 17 | for coupling in flow_coupling: 18 | # initial variables 19 | x = torch.Tensor(np.random.rand(2, 16, 4, 4)) 20 | flow_step = FlowStep( 21 | in_channels=16, 22 | hidden_channels=256, 23 | permutation=permutation, 24 | coupling=coupling, 25 | actnorm_scale=1., 26 | lu_decomposition=False 27 | ) 28 | # forward and reverse flow 29 | y, det = flow_step(x, 0, reverse=False) 30 | x_, det_ = flow_step(y, det, reverse=True) 31 | # assertion 32 | self.assertTrue(ops.tensor_equal(x, x_)) 33 | 34 | def test_flow_model(self): 35 | flow_permutation = ['invconv', 'reverse', 'shuffle'] 36 | flow_coupling = ['additive', 'affine'] 37 | 38 | for permutation in flow_permutation: 39 | for coupling in flow_coupling: 40 | # initial variables 41 | x = torch.Tensor(np.random.rand(2, 3, 16, 16)) 42 | flow_model = FlowModel( 43 | in_shape=(16, 16, 3), 44 | 
hidden_channels=256, 45 | K=16, L=3, 46 | permutation=permutation, 47 | coupling=coupling, 48 | actnorm_scale=1., 49 | lu_decomposition=False 50 | ) 51 | # forward and reverse flow 52 | y, det = flow_model(x, 0, reverse=False) 53 | x_ = flow_model(y, det, reverse=True) 54 | # assertion 55 | self.assertEqual(x.shape, x_.shape) 56 | self.assertTupleEqual((2, 48, 2, 2), tuple(y.shape)) 57 | 58 | def test_glow_model(self): 59 | # build model 60 | hps = util.load_profile('profile/test.json') 61 | glow_model = Glow(hps).cuda() 62 | image_shape = hps.model.image_shape 63 | batch_size = glow_model.h_top.shape[0] 64 | # read image 65 | img = Image.open('misc/test.png').convert('RGB') 66 | x = util.pil_to_tensor(img, shape=image_shape) 67 | x = util.make_batch(x, batch_size).cuda() 68 | y_onehot = torch.zeros((batch_size, hps.dataset.num_classes)).cuda() 69 | # forward and reverse flow 70 | z, logdet, y_logits = glow_model(x=x, y_onehot=y_onehot, reverse=False) 71 | x_ = glow_model(z=z, y_onehot=y_onehot, reverse=True) 72 | 73 | 74 | if __name__ == '__main__': 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /test/test_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import unittest 4 | 5 | from network.module import (ActNorm, LinearZeros, Conv2d, Conv2dZeros, 6 | Invertible1x1Conv, Permutation2d, Split2d, Squeeze2d) 7 | from misc import ops 8 | 9 | 10 | class TestModule(unittest.TestCase): 11 | 12 | def test_actnorm(self): 13 | # initial variables 14 | x = torch.Tensor(np.random.rand(2, 16, 4, 4)) 15 | actnorm = ActNorm(num_channels=16) 16 | # forward and reverse flow 17 | y, _ = actnorm(x) 18 | x_, _ = actnorm(y, reverse=True) 19 | # assertion 20 | self.assertTrue(ops.tensor_equal(x, x_)) 21 | 22 | def test_linear_zeros(self): 23 | # initial variables 24 | x = torch.Tensor(np.random.rand(16)) 25 | linear_zeros = LinearZeros(16, 16) 26 | # forward 27 | y = linear_zeros(x) 28 | # assertion 29 | self.assertTrue(torch.equal(y, torch.zeros(16))) 30 | 31 | def test_conv2d(self): 32 | # initial variables 33 | x = torch.Tensor(np.random.rand(2, 16, 4, 4)) 34 | conv2d = Conv2d(in_channels=16, out_channels=5) 35 | # forward and reverse flow 36 | y = conv2d(x) 37 | # assertion 38 | self.assertTupleEqual((2, 5, 4, 4), tuple(y.shape)) 39 | 40 | def test_conv2d_zeros(self): 41 | # initial variables 42 | x = torch.Tensor(np.random.rand(2, 16, 4, 4)) 43 | conv2d_zeros = Conv2dZeros(in_channels=16, out_channels=5) 44 | # forward and reverse flow 45 | y = conv2d_zeros(x) 46 | # assertion 47 | self.assertTupleEqual((5, 16), tuple(conv2d_zeros.weight.shape[:2])) 48 | self.assertTupleEqual((2, 5, 4, 4), tuple(y.shape)) 49 | 50 | def test_invertible_1x1_conv(self): 51 | # initial variables 52 | x = torch.Tensor(np.random.rand(2, 16, 4, 4)) 53 | invertible_1x1_conv = Invertible1x1Conv(num_channels=16) 54 | # forward and reverse flow 55 | y, _ = invertible_1x1_conv(x) 56 | x_, _ = invertible_1x1_conv(y, reverse=True) 57 | # assertion 58 | self.assertEqual(x.shape, y.shape) 59 | self.assertTrue(ops.tensor_equal(x, x_)) 60 | 61 | def test_permutation2d(self): 62 | # initial variables 63 | x = torch.Tensor(np.random.rand(2, 16, 4, 4)) 64 | reverse = Permutation2d(num_channels=16) 65 | shuffle = Permutation2d(num_channels=16, shuffle=True) 66 | # forward and reverse flow 67 | y_reverse = reverse(x) 68 | x_reverse = reverse(y_reverse, reverse=True) 69 | y_shuffle = shuffle(x) 70 | 
x_shuffle = shuffle(y_shuffle, reverse=True)
71 |         # assertion
72 |         self.assertTrue(ops.tensor_equal(x, x_reverse))
73 |         self.assertTrue(ops.tensor_equal(x, x_shuffle))
74 | 
75 |     def test_squeeze2d(self):
76 |         # initial variables
77 |         x = torch.Tensor(np.random.rand(2, 16, 4, 4))
78 |         squeeze = Squeeze2d(factor=2)
79 |         # forward and reverse flow
80 |         y, _ = squeeze(x)
81 |         x_, _ = squeeze(y, reverse=True)
82 |         # assertion
83 |         self.assertTrue(ops.tensor_equal(x, x_))
84 | 
85 |     def test_split2d(self):
86 |         # initial variables
87 |         x = torch.Tensor(np.random.rand(2, 16, 4, 4))
88 |         split2d = Split2d(num_channels=16)
89 |         # forward and reverse flow
90 |         y, _ = split2d(x, 0, reverse=False)
91 |         x_, _ = split2d(y, 0, reverse=True)
92 |         # assertion
93 |         self.assertTrue(ops.tensor_equal(x[:, :x.shape[1] // 2, :, :],
94 |                                          x_[:, :x_.shape[1] // 2, :, :]))
95 | 
96 | 
97 | if __name__ == '__main__':
98 |     unittest.main()
99 | 
--------------------------------------------------------------------------------
/test/test_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import unittest
4 | 
5 | from misc import ops
6 | 
7 | 
8 | class TestOps(unittest.TestCase):
9 |     def test_tensor_equal(self):
10 |         x = torch.Tensor(np.random.rand(2, 16, 4, 4))
11 |         x_ = x + 1e-4
12 |         self.assertTrue(ops.tensor_equal(x, x))
13 |         self.assertFalse(ops.tensor_equal(x, x_))
14 | 
15 |     def test_reduce_mean(self):
16 |         x = torch.ones(2, 3, 16, 16)
17 |         mean = ops.reduce_mean(x, dim=[1, 2, 3])
18 |         self.assertTrue(ops.tensor_equal(torch.ones(2), mean))
19 | 
20 |     def test_reduce_sum(self):
21 |         x = torch.ones(2, 3, 16, 16)
22 |         total = ops.reduce_sum(x, dim=[1, 2, 3])
23 |         expected = float(x.shape[1] * x.shape[2] * x.shape[3])
24 |         self.assertTrue(ops.tensor_equal(torch.Tensor([expected, expected]), total))
25 | 
26 |     def test_split_channel(self):
27 |         x = torch.ones(2, 4, 16, 16)
28 |         nc = x.shape[1]
29 |         # simple splitting
30 |         x1, x2 = ops.split_channel(x, 'simple')
31 |         for c in range(nc // 2):
32 |             self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, c, :, :]))
33 |             self.assertTrue(ops.tensor_equal(x2[:, c, :, :], x[:, nc // 2 + c, :, :]))
34 |         # cross splitting
35 |         x1, x2 = ops.split_channel(x, 'cross')
36 |         for c in range(nc // 2):
37 |             self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, 2 * c, :, :]))
38 |             self.assertTrue(ops.tensor_equal(x2[:, c, :, :], x[:, 2 * c + 1, :, :]))
39 | 
40 |     def test_cat_channel(self):
41 |         x = torch.ones(2, 4, 16, 16)
42 |         x1, x2 = ops.split_channel(x, 'simple')
43 |         self.assertTrue(ops.tensor_equal(ops.cat_channel(x1, x2), x))
44 | 
45 |     def test_count_pixels(self):
46 |         x = torch.Tensor(np.random.rand(2, 16, 4, 4))
47 |         nh, nw = x.shape[2], x.shape[3]
48 |         self.assertEqual(nh * nw, ops.count_pixels(x))
49 | 
50 | 
51 | if __name__ == '__main__':
52 |     unittest.main()
53 | 
--------------------------------------------------------------------------------
/test/test_util.py:
--------------------------------------------------------------------------------
1 | from misc import util
2 | import torch
3 | import unittest
4 | 
5 | 
6 | class TestUtil(unittest.TestCase):
7 |     def test_load_profile(self):
8 |         hps = util.load_profile('profile/test.json')
9 |         self.assertIsInstance(hps, dict)
10 | 
11 |     def test_get_devices(self):
12 |         # use cpu only
13 |         devices = ['cpu']
14 |         self.assertListEqual(devices, util.get_devices(devices))
15 |         # use gpu
16 |         cuda_available = torch.cuda.is_available()
17 |         gpu_count = torch.cuda.device_count()
18 |         if cuda_available and gpu_count >= 1:
19 |             self.assertListEqual([0], util.get_devices(['cuda:0']))
20 |             self.assertListEqual([0], util.get_devices(['cuda:0', 'cuda:1']))
21 |             self.assertListEqual(['cpu'], util.get_devices(['cuda:1']))
22 |             with self.assertRaises(AssertionError):
23 |                 util.get_devices(['cuda:1', 'cpu'])
24 |         if cuda_available and gpu_count >= 2:
25 |             devices = ['cuda:0', 'cuda:1']
26 |             self.assertListEqual([0, 1], util.get_devices(devices))
27 | 
28 |     def test_result_subdir(self):
29 |         result_dir = '/tmp'
30 |         result_subdir = util.create_result_subdir(result_dir,
31 |                                                   desc='test',
32 |                                                   profile={})
33 |         util.locate_result_subdir(result_dir, result_subdir)
34 |         util.locate_result_subdir(result_dir, 0)
35 |         util.locate_result_subdir(result_dir, '000')
36 | 
37 | 
38 | if __name__ == '__main__':
39 |     unittest.main()
40 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import signal
3 | import argparse
4 | 
5 | from torchvision import transforms
6 | 
7 | from misc import util
8 | from network import Builder, Trainer
9 | from dataset import CelebA
10 | 
11 | 
12 | def parse_args():
13 |     parser = argparse.ArgumentParser(
14 |         description='PyTorch implementation of "Glow: Generative Flow with Invertible 1x1 Convolutions"')
15 |     parser.add_argument('profile', type=str, nargs='?',
16 |                         default='profile/celeba.json',
17 |                         help='path to profile file')
18 |     return parser.parse_args()
19 | 
20 | 
21 | if __name__ == '__main__':
22 |     # this enables a Ctrl-C without triggering errors
23 |     signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))
24 | 
25 |     # parse arguments
26 |     args = parse_args()
27 | 
28 |     # initialize logging
29 |     util.init_output_logging()
30 | 
31 |     # load hyper-parameters
32 |     hps = util.load_profile(args.profile)
33 |     util.manual_seed(hps.ablation.seed)
34 | 
35 |     # build graph
36 |     builder = Builder(hps)
37 |     state = builder.build()
38 | 
39 |     # load dataset
40 |     dataset = CelebA(root=hps.dataset.root,
41 |                      transform=transforms.Compose((
42 |                          transforms.CenterCrop(160),
43 |                          transforms.Resize(64),
44 |                          transforms.ToTensor()
45 |                      )))
46 | 
47 |     # start training
48 |     trainer = Trainer(hps=hps, dataset=dataset, **state)
49 |     trainer.train()
50 | 
--------------------------------------------------------------------------------
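
For readers skimming the dump, the sketch below is a minimal sanity check of the `FlowModel` round trip defined in `network/model.py`, modeled on `test/test_model.py`. It is not part of the repository: it assumes you run it from the repo root so the packages above are importable, and the hyper-parameters (`hidden_channels=64`, `K=4`, `L=2`) are deliberately small, hypothetical values chosen so it finishes quickly on CPU.

```python
import numpy as np
import torch

from network.model import FlowModel

# a small multi-scale flow: 16x16 RGB input, 4 flow steps per level, 2 levels
flow = FlowModel(in_shape=(16, 16, 3),
                 hidden_channels=64,
                 K=4, L=2,
                 permutation='invconv',
                 coupling='affine')
flow.train()  # ActNorm layers self-initialize on the first forward pass

x = torch.Tensor(np.random.rand(2, 3, 16, 16))
z, logdet = flow(x, logdet=0., reverse=False)  # encode: x -> z, accumulating logdet
x_ = flow(z, reverse=True)                     # decode: z -> x'

# Split2d factors half the channels out at each level and re-samples them on
# the way back, so reconstruction is stochastic; shapes still round-trip.
assert x_.shape == x.shape
print(z.shape, logdet.shape)  # expect: torch.Size([2, 24, 4, 4]) and torch.Size([2])
```

Calling `flow.train()` before the first forward pass matters: `ActNorm.initialize_bias` and `initialize_logs` skip data-dependent initialization when the module is in eval mode, so bias and logs would stay at zero.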