├── .gitignore ├── README.md ├── config.py ├── data └── prepro │ └── label.csv ├── dataloader.py ├── layers.py ├── main.py ├── pixel_dt_gan.py ├── prepro.py ├── requirements.txt ├── setup.sh ├── solver.py ├── tf.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | *.tar 4 | *.tar.gz 5 | *.zip 6 | *.csv 7 | 8 | *.png 9 | *.jpg 10 | *.png 11 | *.jpeg 12 | *.bmp 13 | 14 | 15 | ./data 16 | ./data/raw 17 | ./data/prepro 18 | ./lookbook 19 | ./venv 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Pytorch Implementation of Piexl-leve Domain Transfer 3 | 4 | reference: [Pixel-Level Domain Transfer](https://arxiv.org/pdf/1603.07442.pdf) 5 | 6 | ## Note (IMPORTANT): 7 | + This code is not complete, the training fails yet. 8 | + I figure the bug is very trivial, and will be fixed as soon as I got time. 9 | + Welcome pull requests!! 10 | 11 | ## Prerequisites 12 | + virtualenv 13 | + PyTorch (tested at 0.3.0, but 0.4.0 would be ok) 14 | 15 | ## Prepare Dataset 16 | + Download LOOKBOOK dataset: [Link](https://drive.google.com/file/d/0By_p0y157GxQU1dCRUU4SFNqaTQ/view?usp=sharing) 17 | + run `sh setup.sh` 18 | 19 | 20 | ## Training 21 | ~~~ 22 | python main.py \ 23 | --gpu_id 0 \ 24 | --root_dir 'data/prepro' \ 25 | --csv_file 'data/prepro/label.csv' \ 26 | --expr 'experiment1' \ 27 | --batch_size 24 \ 28 | --load_size 64 \ 29 | --lr 0.0002 30 | ~~~ 31 | 32 | 33 | ## Visualization 34 | ~~~ 35 | tensorboard --logdir repo//tb --port 8000 36 | ~~~ 37 | 38 | ## Bug report 39 | I found the loss dies too quickly. Need to figure out the reason. 40 | Any pull requests, or bug report is welcome. :-) 41 | 42 | ![image1](https://user-images.githubusercontent.com/17468992/41200409-018cfd98-6cdf-11e8-93e7-fc85646c7c89.png) 43 | ![image2](https://user-images.githubusercontent.com/17468992/41200410-025498f8-6cdf-11e8-8997-355143a074c4.png) 44 | ![image3](https://user-images.githubusercontent.com/17468992/41200411-03075cae-6cdf-11e8-8a10-e14fdda91cc3.png) 45 | 46 | 47 | ## Author 48 | MinchulShin / [@nashory](https://github.com/nashory) 49 | __Any bug reports or questions are welcome. (min.stellastra[at]gmail.com) :-)__ 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | '''config.py 2 | ''' 3 | 4 | import time 5 | import argparse 6 | 7 | # helper func. 8 | def str2bool(v): 9 | return v.lower() in ('true', '1') 10 | 11 | # Parser 12 | parser = argparse.ArgumentParser('pixel-dt-gan') 13 | 14 | # Common options. 15 | parser.add_argument('--gpu_id', 16 | default='4', 17 | type=str, 18 | help='id(s) for CUDA_VISIBLE_DEVICES') 19 | parser.add_argument('--root_dir', 20 | default='../data/prepro', 21 | type=str) 22 | parser.add_argument('--csv_file', 23 | default='label.csv', 24 | type=str) 25 | parser.add_argument('--manualSeed', 26 | type=int, 27 | default=int(time.time()), 28 | help='manual seed') 29 | parser.add_argument('--expr', 30 | default='devel', 31 | type=str, 32 | help='experiment name') 33 | parser.add_argument('--workers', 34 | type=int, 35 | default=8) 36 | # hyperparameters 37 | parser.add_argument('--batch_size', 38 | type=int, 39 | default=24) 40 | parser.add_argument('--load_size', 41 | type=int, 42 | default=64) 43 | parser.add_argument('--epoch', 44 | type=int, 45 | default=40) 46 | parser.add_argument('--lr', 47 | type=float, 48 | default=0.0002) 49 | parser.add_argument('--optimizer', 50 | default='adam', 51 | type=str) 52 | 53 | # visualization 54 | parser.add_argument('--use_tensorboard', 55 | default=True, 56 | type=bool) 57 | parser.add_argument('--save_image_every', 58 | type=int, 59 | default=50) 60 | 61 | 62 | ## parse and save config. 63 | config, _ = parser.parse_known_args() 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | '''dataloader.py 2 | ''' 3 | 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from PIL import Image, ImageOps 10 | import torch 11 | import torchvision.transforms as transforms 12 | import torchvision.datasets as datasets 13 | 14 | from utils import adjust_pixel_range 15 | 16 | # Ignore warnings 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | class FashionDomainDataset(torch.utils.data.Dataset): 22 | def __init__(self, root_dir, csv_file, transform=None): 23 | """ 24 | Args: 25 | csv_file (string): ., . 26 | root_dir (string): Directory including all raw/clean images. 27 | transform (callable, optional): Optional transform to be applied on a sample. 28 | """ 29 | self.label_csv = pd.read_csv(os.path.join(root_dir, csv_file)) 30 | self.root_dir = root_dir 31 | self.transform = transform 32 | self.length = len(self.label_csv) 33 | 34 | def __len__(self): 35 | return len(self.label_csv) 36 | 37 | def __getitem__(self, idx): 38 | raw_id = os.path.join(self.root_dir, 'raw', self.label_csv.ix[idx, 0]) 39 | raw_im = Image.open(raw_id) 40 | clean_id = os.path.join(self.root_dir, 'clean', self.label_csv.ix[idx, 1]) 41 | clean_im = Image.open(clean_id) 42 | irre_idx = random.randrange(0, self.length) 43 | while(self.label_csv.ix[idx,0] == self.label_csv.ix[irre_idx,0]): 44 | irre_idx = random.randrange(0, self.length) 45 | irre_id = os.path.join(self.root_dir, 'clean', self.label_csv.ix[irre_idx, 1]) 46 | irre_im = Image.open(irre_id) 47 | 48 | if self.transform: 49 | raw_im = self.transform(raw_im) 50 | clean_im = self.transform(clean_im) 51 | irre_im = self.transform(irre_im) 52 | 53 | # adjust pixel range [0,255] --> [-1, 1] 54 | raw_im = adjust_pixel_range(raw_im, [0,1], [-1,1]) 55 | clean_im = adjust_pixel_range(clean_im, [0,1], [-1,1]) 56 | irre_im = adjust_pixel_range(irre_im, [0,1], [-1,1]) 57 | return {'raw':raw_im, 'clean':clean_im, 'irre':irre_im} 58 | 59 | class ResizeWithPadding(object): 60 | def __init__(self, imsize, fill=255): 61 | self.fill = fill 62 | self.imsize = imsize 63 | 64 | def __call__(self, x): 65 | return self.__add_padding__(x, self.imsize) 66 | 67 | def __add_padding__(self, x, imsize): 68 | w, h = x.size 69 | new_w = int(w / max(w,h) * imsize) 70 | new_h = int(h / max(w,h) * imsize) 71 | 72 | x = x.resize((new_w, new_h), resample=Image.BILINEAR) 73 | 74 | delta_w = imsize - new_w 75 | delta_h = imsize - new_h 76 | padding = (delta_w//2, delta_h//2, delta_w-(delta_w//2), delta_h-(delta_h//2)) 77 | x = ImageOps.expand(x, padding, fill=(self.fill, self.fill, self.fill)) 78 | return x 79 | 80 | def get_loader(config): 81 | prepro = [] 82 | #prepro.append(transforms.Resize(config.load_size)) 83 | #prepro.append(transforms.CenterCrop(config.load_size)) 84 | prepro.append(ResizeWithPadding(config.load_size)) 85 | prepro.append(transforms.ToTensor()) 86 | 87 | transform = transforms.Compose(prepro) 88 | 89 | # dataset. 90 | dataset = FashionDomainDataset( 91 | csv_file = config.csv_file, 92 | root_dir = config.root_dir, 93 | transform = transform) 94 | 95 | # dataloader. 96 | dataloader = torch.utils.data.DataLoader( 97 | dataset = dataset, 98 | batch_size = config.batch_size, 99 | shuffle = True, 100 | num_workers = config.workers) 101 | 102 | return dataloader 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | '''layers.py 2 | ''' 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class Flatten(nn.Module): 8 | '''nn.Flatten layer in Torch7. 9 | ''' 10 | def __init__(self): 11 | super(Flatten, self).__init__() 12 | 13 | def forward(self, x): 14 | return x.view(x.size(0), -1) 15 | 16 | def __repr__(self): 17 | return self.__class__.__name__ 18 | 19 | class Reshape(nn.Module): 20 | '''nn.Reshape in Torch7. 21 | ''' 22 | def __init__(self, shape): 23 | super(Reshape, self).__init__() 24 | self.shape = shape 25 | def forward(self, x): 26 | return x.view(self.shape) 27 | def __repr__(self): 28 | return self.__class__.__name__ + ' (reshape to size: {})'.format(" ".join(str(x) for x in self.shape)) 29 | 30 | class Identity(nn.Module): 31 | '''nn.Identity in Torch7. 32 | ''' 33 | def __init__(self): 34 | super(Identity, self).__init__() 35 | def forward(self, x): 36 | return x 37 | def __repr__(self): 38 | return self.__class__.__name__ + ' (skip connection)' 39 | 40 | 41 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | '''main.py 2 | ''' 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import sys 10 | import time 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | 17 | from config import config 18 | 19 | 20 | def main(): 21 | # print config. 22 | state = {k: v for k, v in config._get_kwargs()} 23 | print(state) 24 | 25 | # if use cuda. 26 | os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id 27 | use_cuda = torch.cuda.is_available() 28 | 29 | if config.manualSeed is None: 30 | config.manualSeed = random.randint(1, 10000) 31 | random.seed(config.manualSeed) 32 | torch.manual_seed(config.manualSeed) 33 | if use_cuda: 34 | torch.cuda.manual_seed_all(config.manualSeed) 35 | torch.backends.cudnn.benchmark = True # speed up training. 36 | 37 | # data loader 38 | from dataloader import get_loader 39 | dataloader = get_loader(config) 40 | 41 | # load model 42 | from pixel_dt_gan import PixelDtGan 43 | gen = PixelDtGan(mode='gen') # gen (encoder / decoder) 44 | dis = PixelDtGan(mode='dis') # real / fake discriminator 45 | dom = PixelDtGan(mode='dom') # domain discriminator 46 | 47 | print('generator:') 48 | print(gen) 49 | print('real/fake discriminator:') 50 | print(dis) 51 | print('domain discriminator:') 52 | print(dom) 53 | 54 | # solver 55 | from solver import Solver 56 | solver = Solver( 57 | config = config, 58 | dataloader = dataloader, 59 | gen = gen, 60 | dis = dis, 61 | dom = dom) 62 | 63 | # train for N-epochs 64 | for epoch in range(config.epoch): 65 | solver.solve(epoch) 66 | 67 | print('Congrats! You just finished training Pixel-Dt-Gan.') 68 | 69 | 70 | 71 | 72 | 73 | 74 | if __name__=='__main__': 75 | main() 76 | 77 | -------------------------------------------------------------------------------- /pixel_dt_gan.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | 4 | '''pixel-dt-gan.py 5 | ''' 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from layers import ( 15 | Identity, 16 | Reshape, 17 | Flatten,) 18 | 19 | '''helper func. 20 | ''' 21 | def __cuda__(x): 22 | return x.cuda() if torch.cuda.is_available() else x 23 | 24 | def __load_weights_from__(module_dict, load_dict, modulenames): 25 | for modulename in modulenames: 26 | module = module_dict[modulename] 27 | print('loaded weights from module "{}" ...'.format(modulename)) 28 | module.load_state_dict(load_dict[modulename]) 29 | 30 | 31 | '''PixelDtGan 32 | ''' 33 | class PixelDtGan(nn.Module): 34 | def __init__( 35 | self, 36 | mode): 37 | 38 | super(PixelDtGan, self).__init__() 39 | 40 | self.mode = mode 41 | self.module_list = nn.ModuleList() 42 | self.module_dict = {} 43 | self.end_points = {} 44 | 45 | if mode.lower() == 'gen': 46 | # endpoint: encoder 47 | layers = [] 48 | layers = self.__add_conv_layer__(layers, 3, 128, k_size=4, stride=2, pad=1, act='leakyrelu', bn=False) 49 | layers = self.__add_conv_layer__(layers, 128, 256, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 50 | layers = self.__add_conv_layer__(layers, 256, 512, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 51 | layers = self.__add_conv_layer__(layers, 512, 1024, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 52 | layers = self.__add_conv_layer__(layers, 1024, 64, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 53 | self.__register_module__('enc', layers) 54 | 55 | # endpoint: decoder 56 | layers = [] 57 | layers = self.__add_deconv_layer__(layers, 64, 1024, k_size=4, stride=2, pad=1, act='relu', bn=True) 58 | layers = self.__add_deconv_layer__(layers, 1024, 512, k_size=4, stride=2, pad=1, act='relu', bn=True) 59 | layers = self.__add_deconv_layer__(layers, 512, 256, k_size=4, stride=2, pad=1, act='relu', bn=True) 60 | layers = self.__add_deconv_layer__(layers, 256, 128, k_size=4, stride=2, pad=1, act='relu', bn=True) 61 | layers = self.__add_deconv_layer__(layers, 128, 3, k_size=4, stride=2, pad=1, act='tanh', bn=False) 62 | self.__register_module__('dec', layers) 63 | 64 | elif mode.lower() == 'dis': 65 | # endpoint: real/fake discriminator 66 | layers = [] 67 | layers = self.__add_conv_layer__(layers, 3, 128, k_size=4, stride=2, pad=1, act='leakyrelu', bn=False) 68 | layers = self.__add_conv_layer__(layers, 128, 256, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 69 | layers = self.__add_conv_layer__(layers, 256, 512, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 70 | layers = self.__add_conv_layer__(layers, 512, 1024, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 71 | layers = self.__add_conv_layer__(layers, 1024, 1, k_size=4, stride=1, pad=0, act='sigmoid', bn=False) 72 | self.__register_module__('dis', layers) 73 | 74 | elif mode.lower() == 'dom': 75 | # endpoint: domain discriminator 76 | layers = [] 77 | layers = self.__add_conv_layer__(layers, 6, 128, k_size=4, stride=2, pad=1, act='leakyrelu', bn=False) 78 | layers = self.__add_conv_layer__(layers, 128, 256, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 79 | layers = self.__add_conv_layer__(layers, 256, 512, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 80 | layers = self.__add_conv_layer__(layers, 512, 1024, k_size=4, stride=2, pad=1, act='leakyrelu', bn=True) 81 | layers = self.__add_conv_layer__(layers, 1024, 1, k_size=4, stride=1, pad=0, act='sigmoid', bn=False) 82 | self.__register_module__('dom', layers) 83 | 84 | def __add_deconv_layer__(self, layers, in_c, out_c, k_size, stride, pad, act, bn): 85 | layers.append(nn.ConvTranspose2d(in_c, out_c, k_size, stride, pad)) 86 | if bn: 87 | layers.append(nn.BatchNorm2d(out_c)) 88 | if act == 'leakyrelu': 89 | layers.append(nn.LeakyReLU(0.2)) 90 | elif act == 'relu': 91 | layers.append(nn.ReLU()) 92 | elif act == 'sigmoid': 93 | layers.append(nn.Sigmoid()) 94 | elif act == 'tanh': 95 | layers.append(nn.Tanh()) 96 | return layers 97 | 98 | def __add_conv_layer__(self, layers, in_c, out_c, k_size, stride, pad, act, bn): 99 | layers.append(nn.Conv2d(in_c, out_c, k_size, stride, pad)) 100 | if bn: 101 | layers.append(nn.BatchNorm2d(out_c)) 102 | if act == 'leakyrelu': 103 | layers.append(nn.LeakyReLU(0.2)) 104 | elif act == 'relu': 105 | layers.append(nn.ReLU()) 106 | elif act == 'sigmoid': 107 | layers.append(nn.Sigmoid()) 108 | elif act == 'tanh': 109 | layers.append(nn.Tanh()) 110 | return layers 111 | 112 | def __register_module__(self, modulename, module): 113 | if isinstance(module, list) or isinstance(module, tuple): 114 | module = nn.Sequential(*module) 115 | self.module_list.append(module) 116 | self.module_dict[modulename] = module 117 | 118 | def __forward_and_save__(self, x, modulename): 119 | module = self.module_dict[modulename] 120 | x = module(x) 121 | self.end_points[modulename] = x 122 | return x 123 | 124 | def forward(self, x): 125 | if self.mode.lower() == 'gen': 126 | x = self.__forward_and_save__(x, 'enc') 127 | x = self.__forward_and_save__(x, 'dec') 128 | elif self.mode.lower() == 'dis': 129 | x = self.__forward_and_save__(x, 'dis') 130 | else: 131 | assert self.mode.lower() == 'dom' 132 | x = self.__forward_and_save__(x, 'dom') 133 | return x 134 | 135 | 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from os.path import join 3 | import glob 4 | from PIL import Image 5 | from skimage import io 6 | import numpy as np 7 | import json 8 | import argparse 9 | import random 10 | import argparse 11 | import csv 12 | 13 | parser = argparse.ArgumentParser('prepare dataset') 14 | parser.add_argument('--data_root', type=str, default='data/raw') 15 | parser.add_argument('--out_dir', type=str, default='data/prepro') 16 | 17 | 18 | ## parse and save config. 19 | config, _ = parser.parse_known_args() 20 | 21 | 22 | cnt = 0 23 | valid_ext = ['.jpg', '.png'] 24 | os.system('mkdir -p {}/raw'.format(config.out_dir)) 25 | os.system('mkdir -p {}/clean'.format(config.out_dir)) 26 | 27 | csvfile = open('{}/label.csv'.format(config.out_dir), 'wb') 28 | writer = csv.writer(csvfile, delimiter=',') 29 | for filename in glob.glob(os.path.join(config.data_root, '*')): 30 | flist = os.path.splitext(filename) 31 | fname = os.path.basename(flist[0]) 32 | fext = flist[1] 33 | if fext.lower() not in valid_ext: 34 | continue 35 | 36 | image = Image.open(filename) 37 | fid = fname.split('_') 38 | if fid[1] == 'CLEAN0': 39 | image.save('{}/raw/{}{}'.format(config.out_dir, fid[2], fext)) 40 | writer.writerow([fid[2]+fext, fid[0]+fext]) 41 | 42 | elif fid[1] == 'CLEAN1': 43 | image.save('{}/clean/{}{}'.format(config.out_dir, fid[0], fext)) 44 | 45 | # logging. 46 | cnt = cnt +1 47 | print '[' + str(cnt) + '] ' + 'processed @ ' + os.path.join(config.out_dir, fname+'.jpg') 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.2.0 2 | astor==0.6.2 3 | backports.weakref==1.0rc1 4 | bleach==1.5.0 5 | gast==0.2.0 6 | grpcio==1.11.0 7 | html5lib==0.9999999 8 | Markdown==2.6.11 9 | numpy==1.14.3 10 | pandas==0.22.0 11 | Pillow==5.1.0 12 | progress==1.3 13 | protobuf==3.5.2.post1 14 | python-dateutil==2.7.2 15 | pytz==2018.4 16 | six==1.11.0 17 | tensorboard==1.8.0 18 | tensorboardX==1.2 19 | tensorflow==1.8.0 20 | tensorflow-tensorboard==1.5.1 21 | termcolor==1.1.0 22 | torch==0.4.0 23 | torchvision==0.2.1 24 | Werkzeug==0.14.1 25 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | 2 | # extract 3 | tar -xvf lookbook.tar 4 | mkdir -p data 5 | mv ./lookbook/data ./data/raw 6 | rm -rf lookbook 7 | 8 | mkdir -p data/prepro 9 | python prepro.py --out_dir 'data/prepro' --data_root 'data/raw' 10 | 11 | 12 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | '''solver.py 5 | ''' 6 | 7 | import os 8 | import time 9 | import shutil 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from torch.autograd import Variable 15 | from torch.optim import Adam 16 | import torchvision.utils as vutils 17 | import torch.nn.functional as F 18 | from progress.bar import Bar as Bar 19 | 20 | from tf import TensorBoard 21 | from utils import ( 22 | AverageMeter, 23 | adjust_pixel_range, 24 | make_image_grid, 25 | mkdir_p) 26 | 27 | 28 | '''helper functions. 29 | ''' 30 | def __cuda__(x): 31 | return x.cuda() if torch.cuda.is_available() else x 32 | 33 | def __to_var__(x, volatile=False): 34 | return Variable(x, volatile=volatile) 35 | 36 | def __to_tensor__(x): 37 | return x.data 38 | 39 | class Solver(object): 40 | def __init__( 41 | self, 42 | config, 43 | dataloader, 44 | gen, 45 | dis, 46 | dom): 47 | 48 | self.config = config 49 | self.dataloader = dataloader 50 | self.epoch = 0 51 | self.globalIter = 0 52 | self.prefix = os.path.join('repo', config.expr) 53 | 54 | # model 55 | self.gen = __cuda__(gen) 56 | self.dis = __cuda__(dis) 57 | self.dom = __cuda__(dom) 58 | 59 | # criterion 60 | self.l1loss = torch.nn.L1Loss() 61 | self.mse = torch.nn.MSELoss() 62 | 63 | # optimizer (support adam optimizer only at the moment.) 64 | if config.optimizer.lower() in ['adam']: 65 | betas = (0.5, 0.97) # GAN is sensitive to the beta value. May be this could be the reason of the training failure. 66 | self.opt_gen = Adam( 67 | filter(lambda p: p.requires_grad, self.gen.parameters()), 68 | lr=config.lr, 69 | betas=betas, 70 | weight_decay=0.0) 71 | self.opt_dis = Adam( 72 | filter(lambda p: p.requires_grad, self.dis.parameters()), 73 | lr=config.lr, 74 | betas=betas, 75 | weight_decay=0.0) 76 | self.opt_dom = Adam( 77 | filter(lambda p: p.requires_grad, self.dom.parameters()), 78 | lr=config.lr, 79 | betas=betas, 80 | weight_decay=0.0) 81 | elif config.optimizer.lower() in ['sgd']: 82 | self.opt_gen = optim.RMSprop( 83 | filter(lambda p: p.requires_grad, self.gen.parameters()), 84 | lr=config.lr) 85 | self.opt_dis = optim.RMSprop( 86 | filter(lambda p: p.requires_grad, self.dis.parameters()), 87 | lr=config.lr) 88 | self.opt_dom = optim.RMSprop( 89 | filter(lambda p: p.requires_grad, self.dom.parameters()), 90 | lr=config.lr) 91 | 92 | # tensorboard for visualization 93 | if config.use_tensorboard: 94 | self.tb = TensorBoard(os.path.join(self.prefix, 'tb')) 95 | self.tb.initialize() 96 | 97 | def solve(self, epoch): 98 | ''' 99 | solve for 1 epoch. 100 | Args: 101 | xr: raw, target model image. 102 | xc: clean, relavant product image. 103 | xi: clean, irrelavant product image. 104 | ''' 105 | batch_timer = AverageMeter() 106 | data_timer = AverageMeter() 107 | since = time.time() 108 | bar = Bar('[PixelDtGan] Training ...', max=len(self.dataloader)) 109 | 110 | for batch_index, x in enumerate(self.dataloader): 111 | self.globalIter = self.globalIter + 1 112 | # measure data loading time 113 | data_timer.update(time.time() - since) 114 | 115 | # convert to cuda, variable 116 | xr = x['raw'] 117 | xc = x['clean'] 118 | xi = x['irre'] 119 | xr = __to_var__(__cuda__(xr)) 120 | xc = __to_var__(__cuda__(xc)) 121 | xi = __to_var__(__cuda__(xi)) 122 | 123 | # xr_test for test with fixed input. 124 | if self.globalIter == 1: 125 | xr_test = xr.clone() 126 | xc_test = xc.clone() 127 | 128 | # zero gradients. 129 | self.gen.zero_grad() 130 | self.dis.zero_grad() 131 | self.dom.zero_grad() 132 | 133 | '''update discriminator. (dis, dom) 134 | ''' 135 | since = time.time() 136 | # train dis (real/fake) 137 | dl_xc = self.dis(xc) # real, relavant 138 | dl_xi = self.dis(xi) # real, irrelavant 139 | xc_tilde = self.gen(xr) 140 | dl_xc_tilde = self.dis(xc_tilde.detach()) # fake (detach) 141 | real_label = dl_xc.clone().fill_(1).detach() 142 | fake_label = dl_xc.clone().fill_(0).detach() 143 | loss_dis = self.mse(dl_xc, real_label) + self.mse(dl_xi, real_label) + self.mse(dl_xc_tilde, fake_label) 144 | 145 | # train dom (associated-pair/non-associated-pair) 146 | xp_ass = torch.cat((xr, xc), dim=1) 147 | xp_noass = torch.cat((xr, xi), dim=1) 148 | xp_tilde = torch.cat((xr, xc_tilde.detach()), dim=1) 149 | dl_xp_ass = self.dom(xp_ass) 150 | dl_xp_noass = self.dom(xp_noass) 151 | dl_xp_tilde = self.dom(xp_tilde) 152 | loss_dom = self.mse(dl_xp_ass, real_label) + self.mse(dl_xp_noass, fake_label) + self.mse(dl_xp_tilde, fake_label) 153 | loss_D_total = 0.5 * (loss_dis + loss_dom) 154 | loss_D_total.backward() 155 | self.opt_dis.step() 156 | self.opt_dom.step() 157 | 158 | '''update generator. (gen) 159 | ''' 160 | # train gen (real/fake) 161 | gl_xc_tilde = self.dis(xc_tilde) 162 | gl_xp_tilde = self.dom(xp_tilde) 163 | loss_gen = self.mse(gl_xc_tilde, real_label) + self.mse(gl_xp_tilde, real_label) 164 | loss_gen.backward() 165 | self.opt_gen.step() 166 | 167 | # measure batch process time 168 | batch_timer.update(time.time() - since) 169 | 170 | # print log 171 | log_msg = '\n[Epoch:{EPOCH:}][Iter:{ITER:}][lr:{LR:}] Loss_dis:{LOSS_DIS:.3f} | Loss_dom:{LOSS_DOM:.3f} | Loss_gen:{LOSS_GEN:.3f} | eta:(data:{DATA_TIME:.3f}),(batch:{BATCH_TIME:.3f}),(total:{TOTAL_TIME:})' \ 172 | .format( 173 | EPOCH=epoch+1, 174 | ITER=batch_index+1, 175 | LR=self.config.lr, 176 | LOSS_DIS=loss_dis.data.sum(), 177 | LOSS_DOM=loss_dom.data.sum(), 178 | LOSS_GEN=loss_gen.data.sum(), 179 | DATA_TIME=data_timer.val, 180 | BATCH_TIME=batch_timer.val, 181 | TOTAL_TIME=bar.elapsed_td) 182 | print(log_msg) 183 | bar.next() 184 | 185 | # visualization 186 | if self.config.use_tensorboard: 187 | self.tb.add_scalar('data/loss_dis', float(loss_dis.data.cpu()), self.globalIter) 188 | self.tb.add_scalar('data/loss_dom', float(loss_dom.data.cpu()), self.globalIter) 189 | self.tb.add_scalar('data/loss_gen', float(loss_gen.data.cpu()), self.globalIter) 190 | 191 | if self.globalIter % self.config.save_image_every == 0: 192 | xall = torch.cat((xc_tilde, xc, xr), dim=0) 193 | xall = adjust_pixel_range(xall, range_from=[-1,1], range_to=[0,1]) 194 | self.tb.add_image_grid('grid/output', 8, xall.cpu().data, self.globalIter) 195 | 196 | xc_tilde_test = self.gen(xr_test) 197 | xall_test = torch.cat((xc_tilde_test, xc_test, xr_test), dim=0) 198 | xall_test = adjust_pixel_range(xall_test, range_from=[-1,1], range_to=[0,1]) 199 | self.tb.add_image_grid('grid/output_fixed', 8, xall_test.cpu().data, self.globalIter) 200 | 201 | # save image as png. 202 | mkdir_p(os.path.join(self.prefix, 'image')) 203 | image = make_image_grid(xc_tilde_test.cpu().data, 5) 204 | image = F.upsample(image.unsqueeze(0), size=(800, 800), mode='bilinear').squeeze() 205 | filename = 'Epoch_{}_Iter{}.png'.format(self.epoch, self.globalIter) 206 | vutils.save_image(image, os.path.join(self.prefix, 'image', filename), nrow=1) 207 | 208 | bar.finish() 209 | 210 | 211 | 212 | 213 | def save_checkpoint(self): 214 | print('save checkpoint') 215 | 216 | -------------------------------------------------------------------------------- /tf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.utils as vutils 3 | import numpy as np 4 | import torchvision.models as models 5 | from torchvision import datasets 6 | from tensorboardX import SummaryWriter 7 | import os, sys 8 | from utils import ( 9 | mkdir_p, 10 | make_image_grid) 11 | 12 | 13 | class TensorBoard: 14 | def __init__(self, path): 15 | self.path = path 16 | 17 | def initialize(self): 18 | mkdir_p(self.path) 19 | for i in range(1000): 20 | save_path = os.path.join(self.path, 'try{}'.format(i)) 21 | if not os.path.exists(save_path): 22 | self.writer = SummaryWriter(save_path) 23 | break 24 | 25 | def add_scalar(self, index, val, niter): 26 | self.writer.add_scalar(index, val, niter) 27 | 28 | def add_scalars(self, index, group_dict, niter): 29 | self.writer.add_scalar(index, group_dict, niter) 30 | 31 | def add_image_grid(self, index, nrow, x, niter): 32 | grid = make_image_grid(x, nrow) 33 | self.writer.add_image(index, grid, niter) 34 | 35 | def add_image_single(self, index, x, niter): 36 | self.writer.add_image(index, x, niter) 37 | 38 | def add_graph(self, index, x_input, model): 39 | torch.onnx.export(model, x_input, os.path.join(self.targ, "{}.proto".format(index)), verbose=True) 40 | self.writer.add_graph_onnx(os.path.join(self.targ, "{}.proto".format(index))) 41 | 42 | def export_json(self, out_file): 43 | self.writer.export_scalars_to_json(out_file) 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | ''' 53 | resnet18 = models.resnet18(False) 54 | writer = SummaryWriter() 55 | for n_iter in range(100): 56 | s1 = torch.rand(1) # value to keep 57 | s2 = torch.rand(1) 58 | writer.add_scalar('data/scalar1', s1[0], n_iter) #data grouping by `slash` 59 | writer.add_scalar('data/scalar2', s2[0], n_iter) 60 | writer.add_scalars('data/scalar_group', {"xsinx":n_iter*np.sin(n_iter), 61 | "xcosx":n_iter*np.cos(n_iter), 62 | "arctanx": np.arctan(n_iter)}, n_iter) 63 | dataset = datasets.MNIST('mnist', train=False, download=True) 64 | images = dataset.test_data[:100].float() 65 | label = dataset.test_labels[:100] 66 | features = images.view(100, 784) 67 | writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1)) 68 | 69 | # export scalar data to JSON for external processing 70 | writer.export_scalars_to_json("./all_scalars.json") 71 | writer.close() 72 | ''' 73 | 74 | 75 | 76 | ''' 77 | resnet18 = models.resnet18(False) 78 | writer = SummaryWriter() 79 | sample_rate = 44100 80 | freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440] 81 | 82 | for n_iter in range(100): 83 | s1 = torch.rand(1) # value to keep 84 | s2 = torch.rand(1) 85 | writer.add_scalar('data/scalar1', s1[0], n_iter) #data grouping by `slash` 86 | writer.add_scalar('data/scalar2', s2[0], n_iter) 87 | writer.add_scalars('data/scalar_group', {"xsinx":n_iter*np.sin(n_iter), 88 | "xcosx":n_iter*np.cos(n_iter), 89 | "arctanx": np.arctan(n_iter)}, n_iter) 90 | x = torch.rand(32, 3, 64, 64) # output from network 91 | if n_iter%10==0: 92 | x = vutils.make_grid(x, normalize=True, scale_each=True) 93 | writer.add_image('Image', x, n_iter) 94 | x = torch.zeros(sample_rate*2) 95 | for i in range(x.size(0)): 96 | x[i] = np.cos(freqs[n_iter//10]*np.pi*float(i)/float(sample_rate)) # sound amplitude should in [-1, 1] 97 | writer.add_text('Text', 'text logged at step:'+str(n_iter), n_iter) 98 | for name, param in resnet18.named_parameters(): 99 | writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter) 100 | writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand(100), n_iter) #needs tensorboard 0.4RC or later 101 | dataset = datasets.MNIST('mnist', train=False, download=True) 102 | images = dataset.test_data[:100].float() 103 | label = dataset.test_labels[:100] 104 | features = images.view(100, 784) 105 | writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1)) 106 | 107 | # export scalar data to JSON for external processing 108 | writer.export_scalars_to_json("./all_scalars.json") 109 | 110 | writer.close() 111 | ''' 112 | 113 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | '''utils.py 3 | ''' 4 | 5 | from __future__ import print_function, absolute_import 6 | 7 | import errno 8 | import os 9 | import sys 10 | import time 11 | import math 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | import torchvision.utils as vutils 17 | from torch.autograd import Variable 18 | 19 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter', 'compute_precision_top_k'] 20 | 21 | def compute_precision_top_k(output, target, top_k=(1,)): 22 | """Computes the precision@k for the specified values of k""" 23 | maxk = max(top_k) 24 | batch_size = target.size(0) 25 | 26 | _, pred = output.topk(maxk, 1, True, True) 27 | pred = pred.t() 28 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 29 | 30 | res = [] 31 | for k in top_k: 32 | correct_k = correct[:k].view(-1).float().sum(0) 33 | res.append(correct_k.mul_(100.0 / batch_size)) 34 | return res 35 | 36 | def get_mean_and_std(dataset): 37 | '''Compute the mean and std value of dataset.''' 38 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 39 | 40 | mean = torch.zeros(3) 41 | std = torch.zeros(3) 42 | print('==> Computing mean and std..') 43 | for inputs, targets in dataloader: 44 | for i in range(3): 45 | mean[i] += inputs[:,i,:,:].mean() 46 | std[i] += inputs[:,i,:,:].std() 47 | mean.div_(len(dataset)) 48 | std.div_(len(dataset)) 49 | return mean, std 50 | 51 | def init_params(net): 52 | '''Init layer parameters.''' 53 | for m in net.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | init.kaiming_normal(m.weight, mode='fan_out') 56 | if m.bias: 57 | init.constant(m.bias, 0) 58 | elif isinstance(m, nn.BatchNorm2d): 59 | init.constant(m.weight, 1) 60 | init.constant(m.bias, 0) 61 | elif isinstance(m, nn.Linear): 62 | init.normal(m.weight, std=1e-3) 63 | if m.bias: 64 | init.constant(m.bias, 0) 65 | 66 | def mkdir_p(path): 67 | '''make dir if not exist''' 68 | try: 69 | os.makedirs(path) 70 | except OSError as exc: # Python >2.5 71 | if exc.errno == errno.EEXIST and os.path.isdir(path): 72 | pass 73 | else: 74 | raise 75 | 76 | def make_image_grid(x, nrow): 77 | if pow(nrow,2) < x.size(0): 78 | grid = vutils.make_grid( 79 | x[:nrow*nrow], 80 | nrow=nrow, 81 | padding=0, 82 | normalize=False, 83 | scale_each=True) 84 | else: 85 | grid = torch.FloatTensor(nrow*nrow, x.size(1), x.size(2), x.size(3)).uniform_() 86 | grid[:x.size(0)] = x 87 | grid = vutils.make_grid( 88 | grid, 89 | nrow=nrow, 90 | padding=0, 91 | normalize=False, 92 | scale_each=True) 93 | return grid 94 | 95 | 96 | def adjust_pixel_range( 97 | x, 98 | range_from=[0,1], 99 | range_to=[-1,1]): 100 | ''' 101 | adjust pixel range from to . 102 | ''' 103 | if (range_from[0] == range_to[0]) and (range_from[1] == range_to==[1]): 104 | return x 105 | else: 106 | scale = float(range_to[1]-range_to[0])/float(range_from[1]-range_from[0]) 107 | bias = range_to[0]-range_from[0]*scale 108 | x = x.mul(scale).add(bias) 109 | return x.clamp(range_to[0], range_to[1]) 110 | 111 | 112 | class AverageMeter(object): 113 | """Computes and stores the average and current value 114 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 115 | """ 116 | def __init__(self): 117 | self.reset() 118 | 119 | def reset(self): 120 | self.val = 0 121 | self.avg = 0 122 | self.sum = 0 123 | self.count = 0 124 | 125 | def update(self, val, n=1): 126 | self.val = val 127 | self.sum += val * n 128 | self.count += n 129 | self.avg = self.sum / self.count 130 | 131 | 132 | 133 | 134 | --------------------------------------------------------------------------------