├── __init__.py
├── README.md
├── dataset.py
├── model.py
└── layoutgan.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LayoutGAN-Alpha

Implementation of LayoutGAN (https://arxiv.org/abs/1901.06767).

Uses the SUNCG dataset for testing. The data-processing code is adapted from deep-synth (https://github.com/brownvc/deep-synth).

The code here is just a running version; it still needs some improvements and modifications.

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from torch.utils import data
# zc: changed `from ./data` to `from data`
import sys, os
p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(p)
sys.path.insert(0, p)
print(sys.path)
from data import ObjectCategories, RenderedScene, RenderedComposite

import random
import math
import torch
import cv2
import _pickle as pickle
import color_index


class Dataset(data.Dataset):
    """
    Dataset for training/testing the LayoutGAN network
    """

    def __init__(self, data_root_dir, data_dir, scene_indices=(0, 6400), num_per_epoch=1, seed=None):
        """
        Parameters
        ----------
        data_root_dir (String): root dir where all data lives
        data_dir (String): directory where this dataset lives (relative to data_root_dir)
        scene_indices (tuple[int, int]): range of scene indices (in data_dir) that are considered part of this set
        num_per_epoch (int): number of random variants of each scene that will be used per training epoch
        """
        self.data_root_dir = data_root_dir
        # self.data_dir = data_root_dir + '/' + data_dir
        self.data_dir = data_dir
        self.scene_indices = scene_indices
        self.num_per_epoch = num_per_epoch

        # Load up the map between SUNCG model IDs and category names
        # self.category_map = ObjectCategories(data_root_dir + '/suncg_data/ModelCategoryMapping.csv')
        # Also load up the list of coarse categories used in this particular dataset
        # self.categories = self.get_coarse_categories()
        # Build a reverse map from category to index
        # self.cat_to_index = {self.categories[i]: i for i in range(len(self.categories))}
        self.seed = seed
    def __len__(self):
        return (self.scene_indices[1] - self.scene_indices[0]) * self.num_per_epoch

    def __getitem__(self, index):
        if self.seed:
            random.seed(self.seed)

        # Map the flat epoch index to a scene index (each scene yields
        # num_per_epoch random variants).
        i = int(index / self.num_per_epoch) + self.scene_indices[0]
        scene = RenderedScene(i, self.data_dir, self.data_root_dir)
        composite = scene.create_composite()

        num_categories = len(scene.categories)
        # Flip a coin for whether we're going to remove objects or treat this as a complete scene

        num_objects = len(scene.object_nodes)
        object_nodes = scene.object_nodes

        # Understanding: p_existing is the input p.
        # num_categories is fixed for a given scene; each label has at least one
        # object, but possibly more than one. Hence the number of columns is
        # num_categories, and for now the number of rows is padded with 5 extras.
        # Open question: num_categories differs between scenes, so the resulting
        # one-hot vectors have inconsistent lengths; should this be changed to a
        # fixed-length num_categories?
        p_existing = torch.zeros(num_objects, num_categories)

        for i in range(num_objects):
            existing_categories = torch.zeros(num_categories)
            node = scene.object_nodes[i]
            composite.add_node(node)

            existing_categories[node["category"]] = 1
            p_existing[i] = existing_categories

        coordinates_existing = torch.zeros(num_objects, 4)

        wall = scene.wall
        wall_mask = wall.clone()
        index_nonzero = torch.nonzero(wall_mask)
        xmin_scene, ymin_scene = index_nonzero[0][0], index_nonzero[0][1]
        xmax_scene, ymax_scene = index_nonzero[index_nonzero.shape[0] - 1][0], \
                                 index_nonzero[index_nonzero.shape[0] - 1][1]

        for i in range(num_objects):
            # existing_coordinates = torch.zeros(4)
            node = object_nodes[i]
            xmin, _, ymin, _ = node["bbox_min"]
            xmax, _, ymax, _ = node["bbox_max"]

            # TODO
            # 1. Scale coordinates (need to pre-define the height and width of the map).
            # Take xmin_scene, xmax_scene, ymin_scene, ymax_scene from the room's
            # top-down view and normalize the coordinates to [0, 1]
            # (the room edges map to 0 and 1).
            xmin = (xmin - xmin_scene) / (xmax_scene - xmin_scene).double()
            xmax = (xmax - xmin_scene) / (xmax_scene - xmin_scene).double()
            ymin = (ymin - ymin_scene) / (ymax_scene - ymin_scene).double()
            ymax = (ymax - ymin_scene) / (ymax_scene - ymin_scene).double()
            existing_coordinates = torch.Tensor((xmin, ymin, xmax, ymax))

            coordinates_existing[i] = existing_coordinates
        existing_object = torch.cat((p_existing, coordinates_existing), 1)
        non_existing = torch.zeros(num_categories + 5 - num_objects, num_categories + 4)
        output = torch.cat((existing_object, non_existing), 0)

        # print("output shape=", output.shape)
        return output
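
if __name__ == '__main__':
    # Illustrative sketch (toy numbers, not part of the pipeline): how __getitem__
    # assembles a layout tensor. Assume 3 categories and 2 objects; each row is
    # [one-hot class | xmin, ymin, xmax, ymax], zero-padded to num_categories + 5 rows.
    num_categories, num_objects = 3, 2
    cls_onehot = torch.zeros(num_objects, num_categories)
    cls_onehot[0, 1] = 1                             # object 0 -> category 1
    cls_onehot[1, 2] = 1                             # object 1 -> category 2
    boxes = torch.tensor([[0.1, 0.2, 0.4, 0.5],      # coordinates normalized to [0, 1]
                          [0.5, 0.5, 0.9, 0.8]])
    existing = torch.cat((cls_onehot, boxes), 1)     # [2, num_categories + 4]
    padding = torch.zeros(num_categories + 5 - num_objects, num_categories + 4)
    layout = torch.cat((existing, padding), 0)       # fixed-size element matrix
    print(layout.shape)                              # torch.Size([8, 7])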
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import sys, os
p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, p)

import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from dataset import Dataset
from layoutgan import Generator
from layoutgan import Discriminator
import numpy as np
import utils

parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', type=str, default="bedroom", metavar='S')
parser.add_argument('--save-resdir', type=str, default="train/bedroom", metavar='S')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=1)
parser.add_argument('--train-size', type=int, default=6400, metavar='N')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--ablation', type=str, default=None, metavar='S')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--niter', type=int, default=60, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.00002, help='learning rate, default=0.00002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', type=str, default='LayoutGAN/result',
                    help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

opt = parser.parse_args()
print(opt)
with open(f"{p}/data/{opt.data_dir}/final_categories_frequency", "r") as f:
    lines = f.readlines()
num_categories = len(lines) - 2

# Make sure the output folder exists before opening the log file inside it.
try:
    os.makedirs(opt.outf)
except OSError:
    pass

# --- logging setup --- #
logfile = open('{}/log.txt'.format(opt.outf), 'w')
writer = SummaryWriter()


def Log(msg):
    print(msg)
    logfile.write(msg + '\n')
    logfile.flush()


if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

dataset = Dataset(data_root_dir=utils.get_data_root_dir(),
                  data_dir=opt.data_dir,
                  scene_indices=(0, opt.train_size),
                  num_per_epoch=1)
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))


def real_loss(D_out, smooth=False):
    # Binary cross-entropy against all-ones labels (0.9 with label smoothing).
    batch_size = D_out.size(0)
    if smooth:
        labels = torch.ones(batch_size) * 0.9
    else:
        labels = torch.ones(batch_size)

    crit = nn.BCEWithLogitsLoss()
    loss = crit(D_out.squeeze(), labels)
    return loss


def fake_loss(D_out):
    # Binary cross-entropy against all-zeros labels.
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)
    crit = nn.BCEWithLogitsLoss()
    loss = crit(D_out.squeeze(), labels)
    return loss
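
# A quick, hedged sanity check of the loss helpers (illustrative values only,
# left commented out so it never runs during training): strongly "real" logits
# give a near-zero real_loss, a large fake_loss, and the 0.9 smoothed targets
# pull real_loss up.
# logits = torch.full((4, 1), 5.0)               # hypothetical confident logits
# print(real_loss(logits).item())                # ~0.0067 = log(1 + e^-5)
# print(real_loss(logits, smooth=True).item())   # ~0.51
# print(fake_loss(logits).item())                # ~5.0067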
element_num = num_categories + 5

netG = Generator(num_categories + 4, element_num, num_categories)
netD = Discriminator(num_categories, element_num, opt.imageSize, opt.imageSize, num_categories)
print(netG)
print(netD)

# netG = torch.nn.DataParallel(netG).cuda()
# netD = torch.nn.DataParallel(netD).cuda()


d_optimizer = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
g_optimizer = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

cls_num = num_categories
geo_num = 4
print("cls_num=", cls_num, "\ngeo_num=", geo_num)

for epoch in range(opt.niter):
    # Put both networks into training mode.
    netD.train()
    netG.train()

    for batch_i, real_images in enumerate(dataloader):
        batch_size = real_images.size(0)

        # Train Discriminator.
        d_optimizer.zero_grad()

        D_real = netD(real_images)
        d_real_loss = real_loss(D_real)

        # !Random layout input generation has a logic error and should be fixed.
        zlist = []
        for i in range(batch_size):
            # zc: changed
            # cls_z = np.ones((element_num, num_categories))
            cls_z = np.random.uniform(size=(element_num, num_categories))
            geo_z = np.random.normal(0, 1, size=(element_num, geo_num))

            z = torch.FloatTensor(np.concatenate((cls_z, geo_z), axis=1))
            zlist.append(z)

        fake_images = netG(torch.stack(zlist))

        D_fake = netD(fake_images)
        d_fake_loss = fake_loss(D_fake)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train Generator.
        g_optimizer.zero_grad()

        # !Random layout input generation has a logic error and should be fixed.
        zlist2 = []
        for i in range(batch_size):
            # zc: changed
            # cls_z = np.ones((element_num, cls_num))
            cls_z = np.random.uniform(size=(element_num, cls_num))
            geo_z = np.random.normal(0, 1, size=(element_num, geo_num))

            z = torch.FloatTensor(np.concatenate((cls_z, geo_z), axis=1))
            zlist2.append(z)

        fake_images2 = netG(torch.stack(zlist2))
        D_fake = netD(fake_images2)
        g_loss = real_loss(D_fake)
        # Backpropagate and update the generator.
        g_loss.backward()
        g_optimizer.step()

        writer.add_scalar('data/D_LOSS', d_loss, epoch)
        writer.add_scalar('data/G_LOSS', g_loss, epoch)
        writer.add_scalars('data/D_G_LOSS', {'D_LOSS': d_loss,
                                             'G_LOSS': g_loss}, epoch)
        print_msg = '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % \
                    (epoch + 1, opt.niter, batch_i, len(dataloader), d_loss.item(), g_loss.item())
        Log(print_msg)

    if epoch % 5 == 0:
        torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
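
# A possible vectorized replacement for the two noise loops above, kept as an
# unused sketch. It assumes the same scheme the loops use (uniform class scores,
# N(0, 1) geometry); the name sample_layout_noise is hypothetical.
def sample_layout_noise(batch_size, element_num, cls_num, geo_num):
    """Draw a [batch, element_num, cls_num + geo_num] batch of layout noise."""
    cls_z = torch.rand(batch_size, element_num, cls_num)    # uniform class scores
    geo_z = torch.randn(batch_size, element_num, geo_num)   # standard-normal geometry
    return torch.cat((cls_z, geo_z), dim=2)
# e.g. fake_images = netG(sample_layout_noise(batch_size, element_num, cls_num, geo_num))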
--------------------------------------------------------------------------------
/layoutgan.py:
--------------------------------------------------------------------------------
import torch
import os
import random
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter


class Attention(nn.Module):
    def __init__(self, in_channels, dimension=2, sub_sample=False, bn=True, generate=True):
        super(Attention, self).__init__()
        self.inter_channels = in_channels // 2 if in_channels > 1 else 1
        self.generate = generate
        if dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        if dimension == 1:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
        if bn:
            self.W = nn.Sequential(conv_nd(self.inter_channels, in_channels, kernel_size=1, stride=1, padding=0),
                                   bn(in_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(self.inter_channels, in_channels, kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_c = f / N
        y = torch.matmul(f_div_c, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        if self.generate:
            output = W_y + x
        else:
            output = W_y
        return output


# input [batch, n, feature_dim]; linear layers act on the last dim [N, *, H_in]
# cls_num here is a questionable param. Should its value be 1?
class Generator(nn.Module):
    def __init__(self, in_channels, num_fea, cls_num=1):
        super(Generator, self).__init__()
        self.cls_num = cls_num

        self.fc1 = nn.Linear(in_channels, in_channels * 2)
        self.bn1 = nn.BatchNorm1d(num_fea)  # is bn needed?
        self.fc2 = nn.Linear(in_channels * 2, in_channels * 2 * 2)
        self.bn2 = nn.BatchNorm1d(num_fea)
        self.fc3 = nn.Linear(in_channels * 2 * 2, in_channels * 2 * 2)

        # self.attention_1 = Attention(1)
        # self.attention_2 = Attention(1)
        # self.attention_3 = Attention(1)
        # self.attention_4 = Attention(1)

        self.attention_1 = Attention(in_channels * 2 * 2, 1)
        self.attention_2 = Attention(in_channels * 2 * 2, 1)
        self.attention_3 = Attention(in_channels * 2 * 2, 1)
        self.attention_4 = Attention(in_channels * 2 * 2, 1)

        self.fc4 = nn.Linear(in_channels * 2 * 2, in_channels * 2)
        self.bn4 = nn.BatchNorm1d(num_fea)
        self.fc5 = nn.Linear(in_channels * 2, in_channels)

        self.fc6 = nn.Linear(in_channels, cls_num)
        self.fc7 = nn.Linear(in_channels, in_channels - cls_num)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = torch.sigmoid(self.fc3(x))

        # for 2d relation
        # x = x.unsqueeze(1)
        # x = self.attention_1(x)
        # x = self.attention_2(x)
        # x = self.attention_3(x)
        # x = self.attention_4(x)
        # x = x.squeeze(1)

        # for 1d relation: attention runs over [batch, channels, num_elements]
        x = x.permute(0, 2, 1).contiguous()
        x = self.attention_1(x)
        x = self.attention_2(x)
        x = self.attention_3(x)
        x = self.attention_4(x)
        x = x.permute(0, 2, 1).contiguous()

        out = F.relu(self.bn4(self.fc4(x)))
        out = F.relu(self.fc5(out))

        cls = torch.sigmoid(self.fc6(out))
        geo = torch.sigmoid(self.fc7(out))
        output = torch.cat((cls, geo), 2)
        return output
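
# Hedged shape check for the 1-D Attention block above (non-local style:
# theta/phi/g projections, dot-product affinity, residual output when
# generate=True). Input and output are both [batch, channels, num_elements];
# left commented out so nothing runs on import:
# attn = Attention(in_channels=8, dimension=1)
# feats = torch.randn(2, 8, 10)
# print(attn(feats).shape)  # torch.Size([2, 8, 10]) -- same shape, residual added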
class Discriminator(nn.Module):
    def __init__(self, output_channels, num_fea, height, width, cls_num=1):
        super(Discriminator, self).__init__()
        self.cls_num = cls_num
        self.num_elements = num_fea
        self.height = height
        self.width = width
        self.conv1 = nn.Conv2d(cls_num, output_channels, kernel_size=3, stride=2)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.conv2 = nn.Conv2d(output_channels, output_channels * 2, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(output_channels * 2)
        self.conv3 = nn.Conv2d(output_channels * 2, output_channels * 2, kernel_size=3, stride=2)
        self.bn3 = nn.BatchNorm2d(output_channels * 2)
        # (output_channels * 2) * 7 * 7 features after three stride-2 convs on a 64x64 render
        self.fc1 = nn.Linear(self.cls_num * 2 * 49, output_channels)
        self.fc2 = nn.Linear(output_channels, 1)

    def rectangle_render(self, x):
        # x's size: [b, num_ele + 5, cls_num + 4]; the rendered I's size: [b, c, h, w]
        batch_size = x.size(0)
        # wrong (earlier attempt)
        # I = torch.zeros((batch_size, self.num_elements, self.height, self.width))

        # Enumerate every (col, row) pixel coordinate of the target canvas.
        h_index = torch.arange(0, self.height)
        w_index = torch.arange(0, self.width)
        hh = h_index.repeat(len(w_index))
        ww = w_index.view(-1, 1).repeat(1, len(h_index)).view(-1)
        index = torch.stack([ww, hh], dim=-1)  # [[0,0],[0,1]...[ww-1,hh-1]]
        index_ = index.unsqueeze(0).repeat(batch_size, 1, 1)
        index_col = index_[:, :, 0]
        index_row = index_[:, :, 1]
        x_trans = x.permute(0, 2, 1)
        index_col = index_col.unsqueeze(2)
        index_row = index_row.unsqueeze(2)
        # Signed pixel distances to the four box edges (xL, yT, xR, yB).
        # Note: the generator emits geometry in [0, 1], so the .long() truncation
        # assumes pixel units; the coordinates likely need scaling first.
        sub_xL = index_col - x_trans[:, self.cls_num, :].unsqueeze(1).long()
        sub_yT = index_row - x_trans[:, self.cls_num + 1, :].unsqueeze(1).long()
        sub_xR = index_col - x_trans[:, self.cls_num + 2, :].unsqueeze(1).long()
        sub_yB = index_row - x_trans[:, self.cls_num + 3, :].unsqueeze(1).long()
        sub_y = x_trans[:, self.cls_num + 3, :].unsqueeze(1).long() - index_row
        sub_x = x_trans[:, self.cls_num + 2, :].unsqueeze(1).long() - index_col
        # F_0/F_1 respond along the left/right edges, F_2/F_3 along top/bottom.
        tmp1 = F.relu(sub_yT)
        tmp1[tmp1 > 1] = 1
        tmp2 = F.relu(sub_y)
        tmp2[tmp2 > 1] = 1
        F_0 = F.relu(1 - torch.abs(sub_xL)) * tmp1 * tmp2
        F_1 = F.relu(1 - torch.abs(sub_xR)) * tmp1 * tmp2
        tmp1 = F.relu(sub_xL)
        tmp1[tmp1 > 1] = 1
        tmp2 = F.relu(sub_x)
        tmp2[tmp2 > 1] = 1
        F_2 = F.relu(1 - torch.abs(sub_yT)) * tmp1 * tmp2
        F_3 = F.relu(1 - torch.abs(sub_yB)) * tmp1 * tmp2
        # val's shape: [batch_size, hei*wid, num_elem]
        val, index_ftheta = torch.max(torch.stack((F_0, F_1, F_2, F_3), dim=2), dim=2)

        x_prob = x[:, :, :self.cls_num]
        x_prob = x_prob.unsqueeze(1)        # [batch_size, 1, num_elem, cls_num]
        F_theta = val.unsqueeze(3).float()  # [batch_size, hei*wid, num_elem, 1]
        prod = x_prob * F_theta             # [batch_size, hei*wid, num_elem, cls_num]
        # Max over elements: each pixel keeps the strongest element per class.
        res, index_res = torch.max(prod, 2)
        I = res.contiguous().view(batch_size, self.height, self.width, -1).permute(0, 1, 3, 2)
        I = I.permute(0, 2, 1, 3)
        out = F.relu(self.bn1(self.conv1(I)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.relu(self.bn3(self.conv3(out)))
        out = out.view(out.shape[0], -1)
        out = self.fc1(out)
        out = self.fc2(out)
        # Return raw logits; model.py trains with BCEWithLogitsLoss, which applies
        # the sigmoid itself, so no extra sigmoid here.
        return out

    def forward(self, x):
        output = self.rectangle_render(x)
        return output


if __name__ == '__main__':
    x = torch.randn((3, 4, 5))
    print(x.size())
    model = Generator(5, 4)
    output = model(x)
    print(output)
    print(output.size())
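
    # Hedged end-to-end smoke test with toy sizes of our own choosing, mirroring
    # model.py's conventions (element_num = cls_num + 5, geo_num = 4, 64x64 render):
    cls_num, geo_num = 3, 4
    element_num = cls_num + 5
    netG = Generator(cls_num + geo_num, element_num, cls_num)
    netD = Discriminator(cls_num, element_num, 64, 64, cls_num)
    z = torch.randn(2, element_num, cls_num + geo_num)  # random layout noise
    layout = netG(z)                                    # [2, 8, 7]: class probs + geometry
    logits = netD(layout)                               # [2, 1]: real/fake logits
    print(layout.size(), logits.size())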