├── example1.png
├── tool
│   └── prepare_data.py
├── README.md
└── src
    ├── data_set.py
    ├── model.py
    └── train.py

/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuehy/pytorch-PixelDTGAN/HEAD/example1.png
--------------------------------------------------------------------------------
/tool/prepare_data.py:
--------------------------------------------------------------------------------
import os
import pickle


dataset_dir = '../data/lookbook/data/'

# LOOKBOOK filenames look like PID000001_CLEAN0_IMG000002.jpg:
# CLEAN0 marks a photo of a person wearing the product, CLEAN1 the
# clean product (clothing) shot.
models = []
clothes = []
for filename in os.listdir(dataset_dir):
    if filename.endswith('.jpg'):
        if filename.split('_')[1].endswith('0'):
            models.append(filename)
        else:
            clothes.append(filename)

print('model images:', len(models))
print('clothing images:', len(clothes))

# Characters 3-8 of a filename hold the product id. Group the model
# images by product id, then build, for every clothing image, the list
# of model images showing the same product.
models_by_pid = {}
for filename in models:
    models_by_pid.setdefault(filename[3:9], []).append(filename)

match = [models_by_pid.get(cloth[3:9], []) for cloth in clothes]

with open('cloth_table.pkl', 'wb') as cloth_table:
    pickle.dump(clothes, cloth_table)
with open('model_table.pkl', 'wb') as model_table:
    pickle.dump(match, model_table)

print('done')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pixel Level Domain Transfer
A PyTorch remake of [PixelDTGAN](https://github.com/fxia22/PixelDTGAN), the implementation of "Pixel-Level Domain Transfer", for convenient use under PyTorch.

# Dependency
- ```pytorch >= 0.4.0```
- [visdom](https://github.com/facebookresearch/visdom)
- [opencv](https://github.com/opencv/opencv)

# Training

To train the model, put the LOOKBOOK dataset under `data/lookbook/data/` in the repository root and resize the images to 64×64. Build the index files with `tool/prepare_data.py`. Then go to the `src` directory and run
```
python3 train.py
```
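`prepare_data.py` writes two index files: `cloth_table.pkl` (a list of clothing image filenames) and `model_table.pkl` (for each clothing image, the list of model images showing it). Before training, a minimal sanity check of these tables, run from the `tool` directory, might look like this:
```
import pickle

with open('cloth_table.pkl', 'rb') as f:
    clothes = pickle.load(f)
with open('model_table.pkl', 'rb') as f:
    match = pickle.load(f)

# getbatch() in src/data_set.py assumes every clothing image has at
# least one matching model image.
assert len(clothes) == len(match)
print(len(clothes), 'clothing images')
print(sum(1 for m in match if not m), 'without any model image')
```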
# Monitor the performance

- Install [visdom](https://github.com/facebookresearch/visdom).
- Start the visdom server with ```python3 -m visdom.server -port 5274```
- Open this URL in your browser: `http://localhost:5274`. You will see the loss curves as well as image examples.

Results after 22k iterations:

![22k](https://github.com/xuehy/pytorch-PixelDTGAN/blob/master/example1.png)


## Acknowledgement
+ [@fxia22's original repo](https://github.com/fxia22/PixelDTGAN)
--------------------------------------------------------------------------------
/src/data_set.py:
--------------------------------------------------------------------------------
import pickle

import cv2
import numpy as np
import torch as th


def loadImage(path):
    # OpenCV loads images as BGR; convert to RGB.
    inImage = cv2.imread(path)
    inImage = cv2.cvtColor(inImage, cv2.COLOR_BGR2RGB)
    info = np.iinfo(inImage.dtype)
    inImage = inImage.astype(np.float32) / info.max

    # Resize so that the shorter side is 64 pixels, then crop to 64x64.
    ih = inImage.shape[0]
    iw = inImage.shape[1]
    if iw < ih:
        inImage = cv2.resize(inImage, (64, int(64 * ih / iw)))
    else:
        inImage = cv2.resize(inImage, (int(64 * iw / ih), 64))
    inImage = inImage[0:64, 0:64]
    # Scale to [-1, 1] and convert HWC -> CHW.
    return th.from_numpy(2 * inImage - 1).transpose(0, 2).transpose(1, 2)


class LookbookDataset:
    def __init__(self, data_dir, index_dir):
        self.data_dir = data_dir
        with open(index_dir + 'cloth_table.pkl', 'rb') as cloth:
            self.cloth_table = pickle.load(cloth)
        with open(index_dir + 'model_table.pkl', 'rb') as model:
            self.model_table = pickle.load(model)

        self.cn = len(self.cloth_table)

    def getbatch(self, batchsize):
        batch1 = []
        batch2 = []
        batch3 = []
        for _ in range(batchsize):
            # r1 indexes the associated clothing image, r2 a random
            # (in general non-associated) clothing image, and r3 one of
            # the model images wearing cloth r1.
            r1 = th.randint(0, self.cn, (1,)).item()
            r2 = th.randint(0, self.cn, (1,)).item()
            r3 = th.randint(0, len(self.model_table[r1]), (1,)).item()

            img1 = loadImage(self.data_dir + self.cloth_table[r1])
            img2 = loadImage(self.data_dir + self.cloth_table[r2])
            img3 = loadImage(self.data_dir + self.model_table[r1][r3])
            batch1.append(img1)
            batch2.append(img2)
            batch3.append(img3)
        return th.stack(batch1), th.stack(batch2), th.stack(batch3)
--------------------------------------------------------------------------------
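A quick shape check for the loader — a minimal sketch, run from the `src` directory, assuming the dataset and the index files from `tool/prepare_data.py` are in place:
```
from data_set import LookbookDataset

# Each tensor is (batch, 3, 64, 64), float32, values in [-1, 1]:
# associated cloth, random other cloth, model image wearing the cloth.
dataset = LookbookDataset(data_dir='../data/lookbook/data/',
                          index_dir='../tool/')
cloth, other, model_img = dataset.getbatch(4)
print(cloth.shape, other.shape, model_img.shape)
```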
/src/model.py:
--------------------------------------------------------------------------------
import torch as th
import torch.nn as nn


class NetG(nn.Module):
    """Converter: encodes a 64x64 image to a 4x4 bottleneck and decodes
    it back to a 64x64 image in the target domain."""

    def __init__(self, nc=3, ngf=96):
        super(NetG, self).__init__()
        self.converter = nn.Sequential(
            # encoder: 64 -> 32 -> 16 -> 8 -> 4
            nn.Conv2d(nc, ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.LeakyReLU(0.2, True),

            # decoder: 4 -> 8 -> 16 -> 32 -> 64
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2,
                               padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2,
                               padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2,
                               padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.converter(x)


class NetD(nn.Module):
    """Real/fake discriminator on single 64x64 images."""

    def __init__(self, nc=3, ndf=96):
        super(NetD, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.discriminator(x)
        return x.view(-1, 1)


class NetA(nn.Module):
    """Domain discriminator: scores a 6-channel (source image, candidate
    clothing image) pair for whether the pair is associated."""

    def __init__(self, nc=3, ndf=96):
        super(NetA, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(nc * 2, ndf, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.discriminator(x)
        return x.view(-1, 1)


if __name__ == '__main__':
    netd = NetD()
    a = th.zeros(128, 3, 64, 64)
    b = netd(a)
    print(b.shape)
--------------------------------------------------------------------------------
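All three networks operate on 64x64 inputs; a minimal smoke test of the expected shapes, mirroring the `__main__` check in `model.py`:
```
import torch as th
from model import NetG, NetD, NetA

x = th.zeros(2, 3, 64, 64)
print(NetG()(x).shape)                   # torch.Size([2, 3, 64, 64])
print(NetD()(x).shape)                   # torch.Size([2, 1])
print(NetA()(th.cat((x, x), 1)).shape)   # torch.Size([2, 1]), 6-channel input
```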
/src/train.py:
--------------------------------------------------------------------------------
import itertools

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import visdom
from torchvision.utils import make_grid

from model import NetG, NetD, NetA
from data_set import LookbookDataset

vis = visdom.Visdom(port=5274)
win = None
win1 = None
netg = NetG()
netd = NetD()
neta = NetA()
netg.train()
netd.train()
neta.train()
device = th.device("cuda:2")  # pick the GPU to train on

# DCGAN-style weight initialization for all conv and batch-norm layers.
for mod in itertools.chain(netg.modules(), netd.modules(), neta.modules()):
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)

netg = netg.to(device)
netd = netd.to(device)
neta = neta.to(device)

dataset = LookbookDataset(data_dir='../data/lookbook/data/',
                          index_dir='../tool/')

iteration = 0
lr = 0.0002
real_label = 1
fake_label = 0

label = th.zeros((128, 1), requires_grad=False).to(device)
optimG = optim.Adam(netg.parameters(), lr=lr / 2)
optimD = optim.Adam(netd.parameters(), lr=lr / 3)
optimA = optim.Adam(neta.parameters(), lr=lr / 3)
print('Training starts')
while iteration < 1000000:
    # ass_label: clothing image associated with the model image `img`,
    # noass_label: a random non-associated clothing image.
    ass_label, noass_label, img = dataset.getbatch(128)
    ass_label = ass_label.to(device).to(th.float32)
    noass_label = noass_label.to(device).to(th.float32)
    img = img.to(device).to(th.float32)

    # update D: real clothing images (associated or not) count as real,
    # generated images as fake.
    lossD = 0
    optimD.zero_grad()
    label.fill_(real_label)
    output = netd(ass_label)
    lossD_real1 = F.binary_cross_entropy(output, label)
    lossD += lossD_real1.item()
    lossD_real1.backward()

    label.fill_(real_label)
    output1 = netd(noass_label)
    lossD_real2 = F.binary_cross_entropy(output1, label)
    lossD += lossD_real2.item()
    lossD_real2.backward()

    fake = netg(img).detach()
    label.fill_(fake_label)
    output2 = netd(fake)
    lossD_fake = F.binary_cross_entropy(output2, label)
    lossD += lossD_fake.item()
    lossD_fake.backward()

    optimD.step()

    # update A: only the (model image, associated cloth) pair is real;
    # non-associated and generated pairs are fake.
    lossA = 0
    optimA.zero_grad()
    assd = th.cat((img, ass_label), 1)
    noassd = th.cat((img, noass_label), 1)
    fake = netg(img).detach()
    faked = th.cat((img, fake), 1)

    label.fill_(real_label)
    output1 = neta(assd)
    lossA_real1 = F.binary_cross_entropy(output1, label)
    lossA += lossA_real1.item()
    lossA_real1.backward()

    label.fill_(fake_label)
    output = neta(noassd)
    lossA_real2 = F.binary_cross_entropy(output, label)
    lossA += lossA_real2.item()
    lossA_real2.backward()

    label.fill_(fake_label)
    output = neta(faked)
    lossA_fake = F.binary_cross_entropy(output, label)
    lossA += lossA_fake.item()
    lossA_fake.backward()
    optimA.step()

    # update G: fool both discriminators.
    lossG = 0
    optimG.zero_grad()
    fake = netg(img)
    output = netd(fake)

    label.fill_(real_label)
    lossGD = F.binary_cross_entropy(output, label)
    lossG += lossGD.item()
    lossGD.backward(retain_graph=True)

    faked = th.cat((img, fake), 1)
    output = neta(faked)
    label.fill_(real_label)
    lossGA = F.binary_cross_entropy(output, label)
    lossG += lossGA.item()
    lossGA.backward()
    optimG.step()

    iteration += 1
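    # (Sketch) the script as written never saves weights; something like
    # this would let the generator be reused for inference later. The
    # filename pattern is hypothetical.
    if iteration % 10000 == 0:
        th.save(netg.state_dict(), 'netg_{}.pth'.format(iteration))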

    if iteration % 20 == 0:
        # show a grid of generated / associated / input images
        with th.no_grad():
            netg.eval()
            fake = netg(img)
            netg.train()
        fake = (fake + 1) / 2 * 255
        real = (ass_label + 1) / 2 * 255
        ori = (img + 1) / 2 * 255
        al = th.cat((fake, real, ori), 2)
        display = make_grid(al, 10).cpu().numpy()
        if win1 is None:
            win1 = vis.image(display, opts=dict(title="fake", caption='fake'))
        else:
            vis.image(display, win=win1)

        # log and plot the averaged losses
        print('iter = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            iteration, lossG / 2, lossA / 3, lossD / 3
        ))
        if win is None:
            win = vis.line(X=np.array([[iteration, iteration, iteration]]),
                           Y=np.array([[lossG / 2, lossA / 3, lossD / 3]]),
                           opts=dict(
                               ylabel='loss',
                               xlabel='iterations',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[iteration, iteration, iteration]]),
                     Y=np.array([[lossG / 2, lossA / 3, lossD / 3]]),
                     win=win,
                     update='append')
--------------------------------------------------------------------------------
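Once a generator checkpoint exists (see the saving sketch in the training loop above), domain transfer on a single model photo takes only a few lines. A minimal sketch, run from the `src` directory; the checkpoint and image filenames are hypothetical:
```
import torch as th
from model import NetG
from data_set import loadImage

netg = NetG()
netg.load_state_dict(th.load('netg_20000.pth', map_location='cpu'))
netg.eval()

# loadImage returns a (3, 64, 64) tensor with values in [-1, 1].
img = loadImage('../data/lookbook/data/some_model_photo.jpg')
with th.no_grad():
    fake = netg(img.unsqueeze(0).float())
print(fake.shape)  # torch.Size([1, 3, 64, 64]), values in [-1, 1]
```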