├── .gitignore ├── README.md ├── assets │   ├── board-2017-04-04.png │   └── model.png ├── config.py ├── data_loader.py ├── download.py ├── folder.py ├── main.py ├── models.py ├── test.png ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data/hand 3 | data/gaze 4 | data/* 5 | samples 6 | outputs 7 | 8 | # ipython checkpoints 9 | .ipynb_checkpoints 10 | 11 | # Log 12 | logs 13 | 14 | # ETC 15 | paper.pdf 16 | .DS_Store 17 | 18 | # Created by https://www.gitignore.io/api/python,vim 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | env/ 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *,cover 67 | .hypothesis/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # dotenv 100 | .env 101 | 102 | # virtualenv 103 | .venv/ 104 | venv/ 105 | ENV/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | 114 | ### Vim ### 115 | # swap 116 | [._]*.s[a-v][a-z] 117 | [._]*.sw[a-p] 118 | [._]s[a-v][a-z] 119 | [._]sw[a-p] 120 | # session 121 | Session.vim 122 | # temporary 123 | .netrwhist 124 | *~ 125 | # auto-generated tag files 126 | tags 127 | 128 | # End of https://www.gitignore.io/api/python,vim 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEGAN in PyTorch 2 | 3 | **This project is still in progress. If you are looking for the working code, use** [BEGAN-tensorflow](https://github.com/carpedm20/BEGAN-tensorflow).
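For reference, the objective implemented in `trainer.py` is the paper's equilibrium formulation, summarized here for convenience (`L` is the pixel-wise L1 autoencoder reconstruction loss; `gamma` and `lambda_k` correspond to the `--gamma` and `--lambda_k` flags in `config.py`):

    L_D     = L(x) - k_t * L(G(z))                       # discriminator (autoencoder) loss
    L_G     = L(G(z))                                    # generator loss
    k_{t+1} = k_t + lambda_k * (gamma * L(x) - L(G(z)))  # k_t is clipped to [0, 1]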
4 | 5 | 6 | ## Requirements 7 | 8 | - Python 2.7 9 | - [Pillow](https://pillow.readthedocs.io/en/4.0.x/) 10 | - [tqdm](https://github.com/tqdm/tqdm) 11 | - [PyTorch](https://github.com/pytorch/pytorch) 12 | - [torchvision](https://github.com/pytorch/vision) 13 | - [requests](https://github.com/kennethreitz/requests) (only used for downloading the CelebA dataset) 14 | - [TensorFlow](https://github.com/tensorflow/tensorflow) (only used for TensorBoard logging) 15 | 16 | 17 | ## Usage 18 | 19 | First, download the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset with: 20 | 21 | $ apt-get install p7zip-full # Ubuntu 22 | $ brew install p7zip # macOS 23 | $ python download.py 24 | 25 | or use your own dataset by placing images like: 26 | 27 | data 28 | └── YOUR_DATASET_NAME 29 | ├── xxx.jpg (name doesn't matter) 30 | ├── yyy.jpg 31 | └── ... 32 | 33 | To train a model: 34 | 35 | $ python main.py --dataset=CelebA --num_gpu=1 36 | $ python main.py --dataset=YOUR_DATASET_NAME --num_gpu=4 --use_tensorboard=True 37 | 38 | To test a model (use your own `load_path`): 39 | 40 | $ python main.py --dataset=CelebA --load_path=./logs/CelebA_0405_124806 --num_gpu=0 --is_train=False --split valid 41 | 42 | 43 | ## Results 44 | 45 | - [BEGAN-tensorflow](https://github.com/carpedm20/began-tensorflow) can at least generate human faces, but [BEGAN-pytorch](https://github.com/carpedm20/BEGAN-pytorch) cannot yet. 46 | - Both [BEGAN-tensorflow](https://github.com/carpedm20/began-tensorflow) and [BEGAN-pytorch](https://github.com/carpedm20/BEGAN-pytorch) show **mode collapse**, which I suspect is caused by wrong learning-rate scheduling (the paper mentions that *simply reducing the lr was sufficient to avoid them*). 47 | 48 | ![alt tag](./assets/board-2017-04-04.png) 49 | 50 | (in progress) 51 | 52 | 53 | ## Author 54 | 55 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io) 56 | -------------------------------------------------------------------------------- /assets/board-2017-04-04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-pytorch/487d7fc1403772c307f2bb16f9fb637a038175c1/assets/board-2017-04-04.png -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-pytorch/487d7fc1403772c307f2bb16f9fb637a038175c1/assets/model.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import argparse 3 | 4 | def str2bool(v): 5 | return v.lower() in ('true', '1') 6 | 7 | arg_lists = [] 8 | parser = argparse.ArgumentParser() 9 | 10 | def add_argument_group(name): 11 | arg = parser.add_argument_group(name) 12 | arg_lists.append(arg) 13 | return arg 14 | 15 | # Network 16 | net_arg = add_argument_group('Network') 17 | net_arg.add_argument('--input_scale_size', type=int, default=64, 18 | help='input image will be resized with the given value as width and height') 19 | net_arg.add_argument('--conv_hidden_num', type=int, default=128, help='n in the paper') 20 | net_arg.add_argument('--z_num', type=int, default=128) 21 | 22 | # Data 23 | data_arg = add_argument_group('Data') 24 | data_arg.add_argument('--dataset', type=str, default='CelebA') 25 | data_arg.add_argument('--split', type=str, default='train') 26 | 
data_arg.add_argument('--batch_size', type=int, default=16) 27 | data_arg.add_argument('--grayscale', type=str2bool, default=False) 28 | data_arg.add_argument('--num_worker', type=int, default=12) 29 | 30 | # Training / test parameters 31 | train_arg = add_argument_group('Training') 32 | train_arg.add_argument('--is_train', type=str2bool, default=True) 33 | train_arg.add_argument('--optimizer', type=str, default='adam') 34 | train_arg.add_argument('--max_step', type=int, default=500000) 35 | train_arg.add_argument('--lr_update_step', type=int, default=10000) 36 | train_arg.add_argument('--lr', type=float, default=0.0001) 37 | train_arg.add_argument('--beta1', type=float, default=0.5) 38 | train_arg.add_argument('--beta2', type=float, default=0.999) 39 | train_arg.add_argument('--gamma', type=float, default=0.5) 40 | train_arg.add_argument('--lambda_k', type=float, default=0.001) 41 | 42 | # Misc 43 | misc_arg = add_argument_group('Misc') 44 | misc_arg.add_argument('--load_path', type=str, default='') 45 | misc_arg.add_argument('--log_step', type=int, default=50) 46 | misc_arg.add_argument('--save_step', type=int, default=5000) 47 | misc_arg.add_argument('--num_log_samples', type=int, default=3) 48 | misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN']) 49 | misc_arg.add_argument('--log_dir', type=str, default='logs') 50 | misc_arg.add_argument('--data_dir', type=str, default='data') 51 | misc_arg.add_argument('--num_gpu', type=int, default=1) 52 | misc_arg.add_argument('--test_data_path', type=str, default=None, 53 | help='directory with images which will be used in test sample generation') 54 | misc_arg.add_argument('--sample_per_image', type=int, default=64, 55 | help='# of samples per image during test sample generation') 56 | misc_arg.add_argument('--random_seed', type=int, default=123) 57 | misc_arg.add_argument('--use_tensorboard', type=str2bool, default=False) 58 | 59 | def get_config(): 60 | config, unparsed = parser.parse_known_args() 61 | return config, unparsed 62 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from torchvision import transforms 9 | import torchvision.datasets as dset 10 | 11 | from folder import ImageFolder 12 | 13 | def get_loader(root, split, batch_size, scale_size, num_workers=2, shuffle=True): 14 | dataset_name = os.path.basename(root) 15 | image_root = os.path.join(root, 'splits', split) 16 | 17 | if dataset_name in ['CelebA']: 18 | dataset = ImageFolder(root=image_root, transform=transforms.Compose([ 19 | transforms.CenterCrop(160), 20 | transforms.Scale(scale_size), # transforms.Scale was later renamed transforms.Resize 21 | transforms.ToTensor(), 22 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 23 | ])) 24 | else: 25 | dataset = ImageFolder(root=image_root, transform=transforms.Compose([ 26 | transforms.Scale(scale_size), 27 | transforms.ToTensor(), 28 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 29 | ])) 30 | 31 | data_loader = torch.utils.data.DataLoader( 32 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=int(num_workers)) 33 | data_loader.shape = [int(num) for num in dataset[0][0].size()] # [C, H, W] of one sample; used by Trainer.build_model 34 | 35 | return data_loader 36 | -------------------------------------------------------------------------------- /download.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Modification of 3 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py 4 | - http://stackoverflow.com/a/39225039 5 | """ 6 | from __future__ import print_function 7 | import os 8 | import zipfile 9 | import requests 10 | import subprocess 11 | from tqdm import tqdm 12 | from collections import OrderedDict 13 | 14 | def download_file_from_google_drive(id, destination): 15 | URL = "https://docs.google.com/uc?export=download" 16 | session = requests.Session() 17 | 18 | response = session.get(URL, params={ 'id': id }, stream=True) 19 | token = get_confirm_token(response) 20 | 21 | if token: 22 | params = { 'id' : id, 'confirm' : token } 23 | response = session.get(URL, params=params, stream=True) 24 | 25 | save_response_content(response, destination) 26 | 27 | def get_confirm_token(response): 28 | for key, value in response.cookies.items(): 29 | if key.startswith('download_warning'): 30 | return value 31 | return None 32 | 33 | def save_response_content(response, destination, chunk_size=32*1024): 34 | total_size = int(response.headers.get('content-length', 0)) 35 | with open(destination, "wb") as f: 36 | for chunk in tqdm(response.iter_content(chunk_size), total=(total_size + chunk_size - 1) // chunk_size, 37 | unit='chunk', desc=destination): # total is a chunk count, since each iteration yields one chunk 38 | if chunk: # filter out keep-alive new chunks 39 | f.write(chunk) 40 | 41 | def unzip(filepath): 42 | print("Extracting: " + filepath) 43 | base_path = os.path.dirname(filepath) 44 | with zipfile.ZipFile(filepath) as zf: 45 | zf.extractall(base_path) 46 | os.remove(filepath) 47 | 48 | def download_celeb_a(base_path): 49 | data_path = os.path.join(base_path, 'CelebA') 50 | images_path = os.path.join(data_path, 'images') 51 | if os.path.exists(data_path): 52 | print('[!] 
Found Celeb-A - skip') 53 | return 54 | 55 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 56 | save_path = os.path.join(base_path, filename) 57 | 58 | if os.path.exists(save_path): 59 | print('[*] {} already exists'.format(save_path)) 60 | else: 61 | download_file_from_google_drive(drive_id, save_path) 62 | 63 | zip_dir = '' 64 | with zipfile.ZipFile(save_path) as zf: 65 | zip_dir = zf.namelist()[0] 66 | zf.extractall(base_path) 67 | if not os.path.exists(data_path): 68 | os.mkdir(data_path) 69 | os.rename(os.path.join(base_path, "img_align_celeba"), images_path) 70 | os.remove(save_path) 71 | 72 | def prepare_data_dir(path='./data'): 73 | if not os.path.exists(path): 74 | os.mkdir(path) 75 | 76 | # if the file exists, create a relative symlink to it in out_dir 77 | def check_link(in_dir, basename, out_dir): 78 | in_file = os.path.join(in_dir, basename) 79 | if os.path.exists(in_file): 80 | link_file = os.path.join(out_dir, basename) 81 | rel_link = os.path.relpath(in_file, out_dir) 82 | os.symlink(rel_link, link_file) 83 | 84 | def add_splits(base_path): 85 | data_path = os.path.join(base_path, 'CelebA') 86 | images_path = os.path.join(data_path, 'images') 87 | train_dir = os.path.join(data_path, 'splits', 'train') 88 | valid_dir = os.path.join(data_path, 'splits', 'valid') 89 | test_dir = os.path.join(data_path, 'splits', 'test') 90 | if not os.path.exists(train_dir): 91 | os.makedirs(train_dir) 92 | if not os.path.exists(valid_dir): 93 | os.makedirs(valid_dir) 94 | if not os.path.exists(test_dir): 95 | os.makedirs(test_dir) 96 | 97 | # these constants are based on the standard CelebA splits 98 | NUM_EXAMPLES = 202599 99 | TRAIN_STOP = 162770 100 | VALID_STOP = 182637 101 | 102 | for i in range(0, TRAIN_STOP): 103 | basename = "{:06d}.jpg".format(i+1) 104 | check_link(images_path, basename, train_dir) 105 | for i in range(TRAIN_STOP, VALID_STOP): 106 | basename = "{:06d}.jpg".format(i+1) 107 | check_link(images_path, basename, valid_dir) 108 | for i in range(VALID_STOP, NUM_EXAMPLES): 109 | basename = "{:06d}.jpg".format(i+1) 110 | check_link(images_path, basename, test_dir) 111 | 112 | if __name__ == '__main__': 113 | base_path = './data' 114 | prepare_data_dir() 115 | download_celeb_a(base_path) 116 | add_splits(base_path) 117 | -------------------------------------------------------------------------------- /folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | def make_dataset(dir): 16 | images = [] 17 | for root, _, fnames in sorted(os.walk(dir)): 18 | for fname in sorted(fnames): 19 | if is_image_file(fname): 20 | path = os.path.join(root, fname) 21 | item = (path, 0) 22 | images.append(item) 23 | 24 | return images 25 | 26 | def default_loader(path): 27 | return Image.open(path).convert('RGB') 28 | 29 | class ImageFolder(data.Dataset): 30 | 31 | def __init__(self, root, transform=None, target_transform=None, 32 | loader=default_loader): 33 | imgs = make_dataset(root) 34 | if len(imgs) == 0: 35 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 36 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 37 | 38 | print("Found {} images in subfolders of: 
{}".format(len(imgs), root)) 39 | 40 | self.root = root 41 | self.imgs = imgs 42 | self.transform = transform 43 | self.target_transform = target_transform 44 | self.loader = loader 45 | 46 | def __getitem__(self, index): 47 | path, target = self.imgs[index] 48 | img = self.loader(path) 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | if self.target_transform is not None: 52 | target = self.target_transform(target) 53 | 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.imgs) 58 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from trainer import Trainer 4 | from config import get_config 5 | from data_loader import get_loader 6 | from utils import prepare_dirs_and_logger, save_config 7 | 8 | def main(config): 9 | prepare_dirs_and_logger(config) 10 | 11 | torch.manual_seed(config.random_seed) 12 | if config.num_gpu > 0: 13 | torch.cuda.manual_seed(config.random_seed) 14 | 15 | if config.is_train: 16 | data_path = config.data_path 17 | batch_size = config.batch_size 18 | do_shuffle = True 19 | else: 20 | if config.test_data_path is None: 21 | data_path = config.data_path 22 | else: 23 | data_path = config.test_data_path 24 | batch_size = config.sample_per_image 25 | do_shuffle = False 26 | 27 | data_loader = get_loader( 28 | data_path, config.split, batch_size, config.input_scale_size, config.num_worker, do_shuffle) 29 | 30 | trainer = Trainer(config, data_loader) 31 | 32 | if config.is_train: 33 | save_config(config) 34 | trainer.train() 35 | else: 36 | if not config.load_path: 37 | raise Exception("[!] You should specify `load_path` to load a pretrained model") 38 | trainer.test() 39 | 40 | if __name__ == "__main__": 41 | config, unparsed = get_config() 42 | main(config) 43 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from torch.utils.data import TensorDataset, DataLoader 7 | 8 | class BaseModel(nn.Module): 9 | def forward(self, x): 10 | gpu_ids = None 11 | if isinstance(x.data, torch.cuda.FloatTensor) and self.num_gpu > 1: 12 | gpu_ids = range(self.num_gpu) 13 | if gpu_ids: 14 | return nn.parallel.data_parallel(self.main, x, gpu_ids) 15 | else: 16 | return self.main(x) 17 | 18 | class GeneratorCNN(BaseModel): 19 | def __init__(self, input_num, initial_conv_dim, output_num, repeat_num, hidden_num, num_gpu): 20 | super(GeneratorCNN, self).__init__() 21 | self.num_gpu = num_gpu 22 | 23 | 24 | self.initial_conv_dim = initial_conv_dim 25 | self.fc = nn.Linear(input_num, np.prod(self.initial_conv_dim)) 26 | 27 | layers = [] 28 | for idx in range(repeat_num): 29 | layers.append(nn.Conv2d(hidden_num, hidden_num, 3, 1, 1)) 30 | layers.append(nn.ELU(True)) 31 | layers.append(nn.Conv2d(hidden_num, hidden_num, 3, 1, 1)) 32 | layers.append(nn.ELU(True)) 33 | 34 | if idx < repeat_num - 1: 35 | layers.append(nn.UpsamplingNearest2d(scale_factor=2)) 36 | 37 | layers.append(nn.Conv2d(hidden_num, output_num, 3, 1, 1)) 38 | #layers.append(nn.Tanh()) 39 | layers.append(nn.ELU(True)) 40 | 41 | self.conv = torch.nn.Sequential(*layers) 42 | 43 | def main(self, x): 44 | fc_out = self.fc(x).view([-1] + self.initial_conv_dim) 45 | return self.conv(fc_out)
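# A shape walk-through of GeneratorCNN for the repo's 64x64 defaults (z_num=128,
# conv_hidden_num=128, repeat_num=4), added as an illustrative note:
#   fc:   (B, 128) -> (B, 128*8*8), reshaped to (B, 128, 8, 8)
#   conv: 4 blocks of two 3x3 convs + ELU; nearest-neighbor upsampling after the
#         first 3 blocks grows the feature map 8 -> 16 -> 32 -> 64
#   out:  a final 3x3 conv maps hidden_num channels to output_num (3 for RGB)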
46 | 47 | class DiscriminatorCNN(BaseModel): 48 | def __init__(self, input_channel, z_num, repeat_num, hidden_num, num_gpu): 49 | super(DiscriminatorCNN, self).__init__() 50 | self.num_gpu = num_gpu 51 | 52 | # Encoder 53 | layers = [] 54 | layers.append(nn.Conv2d(input_channel, hidden_num, 3, 1, 1)) 55 | layers.append(nn.ELU(True)) 56 | 57 | prev_channel_num = hidden_num 58 | for idx in range(repeat_num): 59 | channel_num = hidden_num * (idx + 1) 60 | layers.append(nn.Conv2d(prev_channel_num, channel_num, 3, 1, 1)) 61 | layers.append(nn.ELU(True)) 62 | 63 | if idx < repeat_num - 1: 64 | layers.append(nn.Conv2d(channel_num, channel_num, 3, 2, 1)) 65 | #layers.append(nn.MaxPool2d(2)) 66 | #layers.append(nn.MaxPool2d(1, 2)) 67 | else: 68 | layers.append(nn.Conv2d(channel_num, channel_num, 3, 1, 1)) 69 | 70 | layers.append(nn.ELU(True)) 71 | prev_channel_num = channel_num 72 | 73 | self.conv1_output_dim = [channel_num, 8, 8] # assumes 64x64 inputs: three stride-2 convs give 64 -> 32 -> 16 -> 8 74 | 75 | self.conv1 = torch.nn.Sequential(*layers) 76 | self.fc1 = nn.Linear(8*8*channel_num, z_num) 77 | 78 | # Decoder 79 | self.conv2_input_dim = [hidden_num, 8, 8] 80 | self.fc2 = nn.Linear(z_num, np.prod(self.conv2_input_dim)) 81 | 82 | layers = [] 83 | for idx in range(repeat_num): 84 | layers.append(nn.Conv2d(hidden_num, hidden_num, 3, 1, 1)) 85 | layers.append(nn.ELU(True)) 86 | layers.append(nn.Conv2d(hidden_num, hidden_num, 3, 1, 1)) 87 | layers.append(nn.ELU(True)) 88 | 89 | if idx < repeat_num - 1: 90 | layers.append(nn.UpsamplingNearest2d(scale_factor=2)) 91 | 92 | layers.append(nn.Conv2d(hidden_num, input_channel, 3, 1, 1)) 93 | #layers.append(nn.Tanh()) 94 | layers.append(nn.ELU(True)) 95 | 96 | self.conv2 = torch.nn.Sequential(*layers) 97 | 98 | def main(self, x): 99 | conv1_out = self.conv1(x).view(-1, np.prod(self.conv1_output_dim)) 100 | fc1_out = self.fc1(conv1_out) 101 | 102 | fc2_out = self.fc2(fc1_out).view([-1] + self.conv2_input_dim) 103 | conv2_out = self.conv2(fc2_out) 104 | return conv2_out 105 | 106 | class _Loss(nn.Module): 107 | 108 | def __init__(self, size_average=True): 109 | super(_Loss, self).__init__() 110 | self.size_average = size_average 111 | 112 | def forward(self, input, target): 113 | # this still won't solve the problem: 114 | # gradient will not flow through `target` 115 | # _assert_no_grad(target) 116 | backend_fn = getattr(self._backend, type(self).__name__) # relies on legacy torch.nn backend internals 117 | return backend_fn(self.size_average)(input, target) 118 | 119 | class L1Loss(_Loss): 120 | r"""Creates a criterion that measures the mean absolute value of the 121 | element-wise difference between input `x` and target `y`: 122 | 123 | :math:`{loss}(x, y) = 1/n \sum |x_i - y_i|` 124 | 125 | `x` and `y` can have arbitrary shapes with a total of `n` elements each. 126 | 127 | The sum operation still operates over all the elements, and divides by `n`. 
128 | 129 | The division by `n` can be avoided if one sets the constructor argument `size_average=False` 130 | """ 131 | pass 132 | -------------------------------------------------------------------------------- /test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-pytorch/487d7fc1403772c307f2bb16f9fb637a038175c1/test.png -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import StringIO # Python 2 only 5 | import scipy.misc 6 | import numpy as np 7 | from glob import glob 8 | from tqdm import trange 9 | from itertools import chain 10 | from collections import deque 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.parallel 15 | import torchvision.utils as vutils 16 | from torch.autograd import Variable 17 | 18 | from models import * 19 | from data_loader import get_loader 20 | 21 | def weights_init(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('Conv') != -1: 24 | m.weight.data.normal_(0.0, 0.02) 25 | elif classname.find('BatchNorm') != -1: 26 | m.weight.data.normal_(1.0, 0.02) 27 | m.bias.data.fill_(0) 28 | 29 | def next(loader): 30 | return loader.next()[0] # take images only, discard labels 31 | 32 | class Trainer(object): 33 | def __init__(self, config, data_loader): 34 | self.config = config 35 | self.data_loader = data_loader 36 | 37 | self.num_gpu = config.num_gpu 38 | self.dataset = config.dataset 39 | 40 | self.lr = config.lr 41 | self.beta1 = config.beta1 42 | self.beta2 = config.beta2 43 | self.optimizer = config.optimizer 44 | self.batch_size = config.batch_size 45 | 46 | self.gamma = config.gamma 47 | self.lambda_k = config.lambda_k 48 | 49 | self.z_num = config.z_num 50 | self.conv_hidden_num = config.conv_hidden_num 51 | self.input_scale_size = config.input_scale_size 52 | 53 | self.model_dir = config.model_dir 54 | self.load_path = config.load_path 55 | 56 | self.start_step = 0 57 | self.log_step = config.log_step 58 | self.max_step = config.max_step 59 | self.save_step = config.save_step 60 | self.lr_update_step = config.lr_update_step 61 | 62 | self.build_model() 63 | 64 | if self.num_gpu > 0: 65 | self.G.cuda() 66 | self.D.cuda() 67 | 68 | if self.load_path: 69 | self.load_model() 70 | 71 | self.use_tensorboard = config.use_tensorboard 72 | if self.use_tensorboard: 73 | import tensorflow as tf 74 | self.summary_writer = tf.summary.FileWriter(self.model_dir) 75 | 76 | def inject_summary(summary_writer, tag, value, step): 77 | if hasattr(value, '__len__'): 78 | for idx, img in enumerate(value): 79 | summary = tf.Summary() 80 | sio = StringIO.StringIO() 81 | scipy.misc.toimage(img).save(sio, format="png") 82 | image_summary = tf.Summary.Image(encoded_image_string=sio.getvalue()) 83 | summary.value.add(tag="{}/{}".format(tag, idx), image=image_summary) 84 | summary_writer.add_summary(summary, global_step=step) 85 | else: 86 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 87 | summary_writer.add_summary(summary, global_step=step) 88 | 89 | self.inject_summary = inject_summary 90 | 91 | def build_model(self): 92 | channel, height, width = self.data_loader.shape 93 | assert height == width, "height and width should be the same" 94 | 95 | repeat_num = int(np.log2(height)) - 2 96 | self.D = DiscriminatorCNN( 97 | channel, self.z_num, repeat_num, self.conv_hidden_num, self.num_gpu) 98 | self.G = 
GeneratorCNN( 99 | self.z_num, self.D.conv2_input_dim, channel, repeat_num, self.conv_hidden_num, self.num_gpu) 100 | 101 | self.G.apply(weights_init) 102 | self.D.apply(weights_init) 103 | 104 | def train(self): 105 | l1 = L1Loss() 106 | 107 | z_D = Variable(torch.FloatTensor(self.batch_size, self.z_num)) 108 | z_G = Variable(torch.FloatTensor(self.batch_size, self.z_num)) 109 | z_fixed = Variable(torch.FloatTensor(self.batch_size, self.z_num).normal_(0, 1), volatile=True) 110 | 111 | if self.num_gpu > 0: 112 | l1.cuda() 113 | 114 | z_D = z_D.cuda() 115 | z_G = z_G.cuda() 116 | z_fixed = z_fixed.cuda() 117 | 118 | if self.optimizer == 'adam': 119 | optimizer = torch.optim.Adam 120 | else: 121 | raise Exception("[!] Caution! The paper only used the Adam optimizer, not {}".format(self.optimizer)) 122 | 123 | def get_optimizer(lr): 124 | return optimizer(self.G.parameters(), lr=lr, betas=(self.beta1, self.beta2)), \ 125 | optimizer(self.D.parameters(), lr=lr, betas=(self.beta1, self.beta2)) 126 | 127 | g_optim, d_optim = get_optimizer(self.lr) 128 | 129 | data_loader = iter(self.data_loader) 130 | x_fixed = self._get_variable(next(data_loader)) 131 | vutils.save_image(x_fixed.data, '{}/x_fixed.png'.format(self.model_dir)) 132 | 133 | k_t = 0 134 | prev_measure = 1 135 | measure_history = deque([0]*self.lr_update_step, self.lr_update_step) 136 | 137 | for step in trange(self.start_step, self.max_step): 138 | try: 139 | x = next(data_loader) 140 | except StopIteration: 141 | data_loader = iter(self.data_loader) 142 | x = next(data_loader) 143 | 144 | x = self._get_variable(x) 145 | batch_size = x.size(0) 146 | 147 | self.D.zero_grad() 148 | self.G.zero_grad() 149 | 150 | z_D.data.normal_(0, 1) 151 | z_G.data.normal_(0, 1) 152 | 153 | #sample_z_D = self.G(z_D) 154 | sample_z_G = self.G(z_G) 155 | 156 | AE_x = self.D(x) 157 | AE_G_d = self.D(sample_z_G.detach()) 158 | AE_G_g = self.D(sample_z_G) 159 | 160 | d_loss_real = l1(AE_x, x) 161 | d_loss_fake = l1(AE_G_d, sample_z_G.detach()) 162 | 163 | d_loss = d_loss_real - k_t * d_loss_fake 164 | g_loss = l1(sample_z_G, AE_G_g) # this still won't solve the problem 165 | 166 | loss = d_loss + g_loss 167 | loss.backward() 168 | 169 | g_optim.step() 170 | d_optim.step() 171 | 172 | g_d_balance = (self.gamma * d_loss_real - d_loss_fake).data[0] # gamma * L(x) - L(G(z)) 173 | k_t += self.lambda_k * g_d_balance # paper's update for k_{t+1} 174 | k_t = max(min(1, k_t), 0) 175 | 176 | measure = d_loss_real.data[0] + abs(g_d_balance) 177 | measure_history.append(measure) 178 | 179 | if step % self.log_step == 0: 180 | print("[{}/{}] Loss_D: {:.4f} L_x: {:.4f} Loss_G: {:.4f} " 181 | "measure: {:.4f}, k_t: {:.4f}, lr: {:.7f}". 
\ 182 | format(step, self.max_step, d_loss.data[0], d_loss_real.data[0], 183 | g_loss.data[0], measure, k_t, self.lr)) 184 | x_fake = self.generate(z_fixed, self.model_dir, idx=step) 185 | self.autoencode(x_fixed, self.model_dir, idx=step, x_fake=x_fake) 186 | 187 | if self.use_tensorboard: 188 | info = { 189 | 'loss/loss_D': d_loss.data[0], 190 | 'loss/L_x': d_loss_real.data[0], 191 | 'loss/Loss_G': g_loss.data[0], 192 | 'misc/measure': measure, 193 | 'misc/k_t': k_t, 194 | 'misc/lr': self.lr, 195 | 'misc/balance': g_d_balance, 196 | } 197 | for tag, value in info.items(): 198 | self.inject_summary(self.summary_writer, tag, value, step) 199 | 200 | self.inject_summary( 201 | self.summary_writer, "AE_G", AE_G_g.data.cpu().numpy(), step) 202 | self.inject_summary( 203 | self.summary_writer, "AE_x", AE_x.data.cpu().numpy(), step) 204 | self.inject_summary( 205 | self.summary_writer, "z_G", sample_z_G.data.cpu().numpy(), step) 206 | 207 | self.summary_writer.flush() 208 | 209 | if step % self.save_step == self.save_step - 1: 210 | self.save_model(step) 211 | 212 | if step % self.lr_update_step == self.lr_update_step - 1: 213 | cur_measure = np.mean(measure_history) 214 | if cur_measure > prev_measure * 0.9999: 215 | self.lr *= 0.5 216 | g_optim, d_optim = get_optimizer(self.lr) 217 | prev_measure = cur_measure 218 | 219 | def generate(self, inputs, path, idx=None): 220 | path = '{}/{}_G.png'.format(path, idx) 221 | x = self.G(inputs) 222 | vutils.save_image(x.data, path) 223 | print("[*] Samples saved: {}".format(path)) 224 | return x 225 | 226 | def autoencode(self, inputs, path, idx=None, x_fake=None): 227 | x_path = '{}/{}_D.png'.format(path, idx) 228 | x = self.D(inputs) 229 | vutils.save_image(x.data, x_path) 230 | print("[*] Samples saved: {}".format(x_path)) 231 | 232 | if x_fake is not None: 233 | x_fake_path = '{}/{}_D_fake.png'.format(path, idx) 234 | x = self.D(x_fake) 235 | vutils.save_image(x.data, x_fake_path) 236 | print("[*] Samples saved: {}".format(x_fake_path)) 237 | 238 | def test(self): 239 | data_loader = iter(self.data_loader) 240 | x_fixed = self._get_variable(next(data_loader)) 241 | vutils.save_image(x_fixed.data, '{}/x_fixed_test.png'.format(self.model_dir)) 242 | self.autoencode(x_fixed, self.model_dir, idx="test", x_fake=None) 243 | 244 | def save_model(self, step): 245 | print("[*] Save models to {}...".format(self.model_dir)) 246 | 247 | torch.save(self.G.state_dict(), '{}/G_{}.pth'.format(self.model_dir, step)) 248 | torch.save(self.D.state_dict(), '{}/D_{}.pth'.format(self.model_dir, step)) 249 | 250 | def load_model(self): 251 | print("[*] Load models from {}...".format(self.load_path)) 252 | 253 | paths = glob(os.path.join(self.load_path, 'G_*.pth')) 254 | paths.sort() 255 | 256 | if len(paths) == 0: 257 | print("[!] 
No checkpoint found in {}...".format(self.load_path)) 258 | return 259 | 260 | idxes = [int(os.path.basename(path).split('.')[0].split('_')[-1]) for path in paths] # step numbers from 'G_<step>.pth' 261 | self.start_step = max(idxes) 262 | 263 | if self.num_gpu == 0: 264 | map_location = lambda storage, loc: storage 265 | else: 266 | map_location = None 267 | 268 | G_filename = '{}/G_{}.pth'.format(self.load_path, self.start_step) 269 | self.G.load_state_dict( 270 | torch.load(G_filename, map_location=map_location)) 271 | print("[*] G network loaded: {}".format(G_filename)) 272 | 273 | D_filename = '{}/D_{}.pth'.format(self.load_path, self.start_step) 274 | self.D.load_state_dict( 275 | torch.load(D_filename, map_location=map_location)) 276 | print("[*] D network loaded: {}".format(D_filename)) 277 | 278 | def _get_variable(self, inputs): 279 | if self.num_gpu > 0: 280 | out = Variable(inputs.cuda()) 281 | else: 282 | out = Variable(inputs) 283 | return out 284 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import json 5 | import logging 6 | import numpy as np 7 | from datetime import datetime 8 | 9 | import torchvision.utils as vutils 10 | 11 | def prepare_dirs_and_logger(config): 12 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 13 | logger = logging.getLogger() 14 | 15 | for hdlr in logger.handlers: 16 | logger.removeHandler(hdlr) 17 | 18 | handler = logging.StreamHandler() 19 | handler.setFormatter(formatter) 20 | 21 | logger.addHandler(handler) 22 | 23 | if config.load_path: 24 | if config.load_path.startswith(config.log_dir): 25 | config.model_dir = config.load_path 26 | else: 27 | if config.load_path.startswith(config.dataset): 28 | config.model_name = config.load_path 29 | else: 30 | config.model_name = "{}_{}".format(config.dataset, config.load_path) 31 | else: 32 | config.model_name = "{}_{}".format(config.dataset, get_time()) 33 | 34 | if not hasattr(config, 'model_dir'): 35 | config.model_dir = os.path.join(config.log_dir, config.model_name) 36 | config.data_path = os.path.join(config.data_dir, config.dataset) 37 | 38 | for path in [config.log_dir, config.data_dir, config.model_dir]: 39 | if not os.path.exists(path): 40 | os.makedirs(path) 41 | 42 | def get_time(): 43 | return datetime.now().strftime("%m%d_%H%M%S") 44 | 45 | def save_config(config): 46 | param_path = os.path.join(config.model_dir, "params.json") 47 | 48 | print("[*] MODEL dir: %s" % config.model_dir) 49 | print("[*] PARAM path: %s" % param_path) 50 | 51 | with open(param_path, 'w') as fp: 52 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 53 | 54 | def save_image(tensor, filename, nrow=8, padding=2, 55 | normalize=False, range=None, scale_each=False): # note: unused here; trainer.py calls torchvision's vutils.save_image directly 56 | from PIL import Image 57 | tensor = tensor.cpu() 58 | grid = vutils.make_grid(tensor, nrow=nrow, padding=padding, 59 | normalize=normalize, range=range, scale_each=scale_each) 60 | #ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 61 | ndarr = grid.byte().permute(1, 2, 0).numpy() 62 | im = Image.fromarray(ndarr) 63 | im.save(filename) 64 | --------------------------------------------------------------------------------
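A minimal sampling sketch (not part of the repository) showing how a generator checkpoint written by `Trainer.save_model` could be reloaded and sampled on CPU. It assumes the CelebA defaults from `config.py` (`z_num=128`, `conv_hidden_num=128`, `input_scale_size=64`, RGB); the checkpoint path and step number below are placeholders:

    import torch
    import numpy as np
    from torch.autograd import Variable
    import torchvision.utils as vutils

    from models import GeneratorCNN

    z_num, hidden_num, scale_size, channel = 128, 128, 64, 3
    repeat_num = int(np.log2(scale_size)) - 2      # mirrors Trainer.build_model
    initial_conv_dim = [hidden_num, 8, 8]          # mirrors DiscriminatorCNN.conv2_input_dim

    G = GeneratorCNN(z_num, initial_conv_dim, channel, repeat_num, hidden_num, num_gpu=0)
    # Placeholder checkpoint path; point this at a file saved by Trainer.save_model.
    G.load_state_dict(torch.load('logs/CelebA_0405_124806/G_4999.pth',
                                 map_location=lambda storage, loc: storage))

    z = Variable(torch.FloatTensor(16, z_num).normal_(0, 1), volatile=True)
    vutils.save_image(G(z).data, 'sample.png')
    print('[*] Samples saved: sample.png')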