├── discriminator.py ├── README.md ├── losses.py ├── generator.py ├── batch_data.py └── main.py /discriminator.py: -------------------------------------------------------------------------------- 1 | ''' 2 | the author is leilei 3 | you have so many choices: deeplab_v3 、or based vgg16 -> u-net 4 | ''' 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ssgan 2 | Semi Supervised Semantic Segmentation Using Generative Adversarial Network ; Pytorch 3 | 4 | ### Environment 5 | ``` 6 | python:3.5 7 | Pytorch:0.40 8 | ``` 9 | 10 | ### Note 11 | ``` 12 | 由于论文未给出代码,并且此论文为“分割”SEMI-GAN,与分类有相似之处,但仍有巨大区别, 13 | 在参考一些分类SEMI-GAN后,复现此半监督分割GAN论文代码。 14 | 若有问题,请及时指出,谢谢。 15 | 16 | 注意:测试时请添加model.eval() and with torch.no_grad(): 17 | ``` 18 | 19 | ### Refer 20 | + [Semi Supervised Semantic Segmentation Using Generative Adversarial Network](https://arxiv.org/abs/1703.09695) 21 | 22 | ### Other 23 | + [segmentation_pytorch](https://github.com/gengyanlei/segmentation_pytorch) 24 | + [building-segmentation-dataset](https://github.com/gengyanlei/build_segmentation_dataset) 25 | + [fire-smoke-detect-dataset](https://github.com/gengyanlei/fire-detect-yolov4) 26 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | ''' 2 | the author is leilei; 3 | Loss functions are in here. 
# (continuation of the losses.py module docstring, translated)
# Computes three losses separately: the loss for labeled real data, the loss
# for generated (fake) data, and the loss for unlabeled real data.

import torch
import torch.nn.functional as F


def _flatten_logits(pred):
    '''
    Turn per-pixel logits [n,c,h,w] into a matrix [n*h*w,c] (one row per pixel).

    Equivalent to pred.transpose(1,2).transpose(2,3).reshape([-1,c]).
    '''
    return pred.permute(0, 2, 3, 1).reshape(-1, pred.shape[1])


def log_sum_exp(x, axis=1):
    '''
    Numerically stable log(sum(exp(x))) along `axis`.

    Args:
        x    : tensor of logits; the loss functions below pass [n*h*w,c]
               (a segmentation output [n,c,h,w] flattened to one row per pixel).
        axis : dimension to reduce over (default 1 = the class dimension).
    Returns:
        tensor with `axis` removed.
    '''
    # Subtract the per-row maximum before exponentiating so exp() cannot overflow.
    m = torch.max(x, dim=axis, keepdim=True)[0]
    return m.squeeze(axis) + torch.log(torch.sum(torch.exp(x - m), dim=axis))


def Loss_label(pred, label):
    '''
    Supervised loss on labeled real data (mean softmax cross-entropy per pixel).

    Args:
        pred : [n,c,h,w] logits (before softmax).
        label: [n,h,w] integer class indices (not one-hot), a torch tensor.
    Returns:
        scalar tensor: -mean(logit of the true class) + mean(log_sum_exp(logits)).
    '''
    logits = _flatten_logits(pred)  # [n*h*w, c]
    # Keep the labels on the logits' device and pick the true-class logit with
    # gather; this avoids the GPU->CPU->NumPy round trip of fancy indexing.
    target = label.reshape(-1).long().to(logits.device)
    l_lab = logits.gather(1, target.unsqueeze(1)).squeeze(1)
    return -torch.mean(l_lab) + torch.mean(log_sum_exp(logits))


def Loss_fake(pred):
    '''
    Discriminator loss on generated data: mean(softplus(log_sum_exp(logits))).

    Args:
        pred: [n,c,h,w] logits (before softmax) for generated inputs.
    Returns:
        scalar tensor.
    '''
    l_gen = log_sum_exp(_flatten_logits(pred))
    return torch.mean(F.softplus(l_gen))


def Loss_unlabel(pred):
    '''
    Discriminator loss on unlabeled real data:
    -mean(log_sum_exp(logits)) + mean(softplus(log_sum_exp(logits))).

    Args:
        pred: [n,c,h,w] logits (before softmax) for unlabeled real inputs.
    Returns:
        scalar tensor.
    '''
    l_unl = log_sum_exp(_flatten_logits(pred))
    return -torch.mean(l_unl) + torch.mean(F.softplus(l_unl))


# --------------------------------------------------------------------------------
# /generator.py:
# --------------------------------------------------------------------------------
'''
the author is leilei
'''

from torch import nn


def _deconv_block(in_channels, out_channels):
    '''Transposed conv (k=3, s=2, p=1, out_pad=1) -> BatchNorm -> ReLU; doubles H and W.'''
    return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, 3, 2, 1, 1),
                         nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))


class _DeconvGenerator(nn.Module):
    '''
    Shared generator body: linear projection of the noise vector to a
    [c0,16,16] feature map, then four stride-2 transposed convolutions up to
    256*256. Sub-module names (linear, deconv1..deconv4) match the original
    classes so state_dict keys are unchanged.
    '''
    def __init__(self, class_number, noise_dim, channels):
        super().__init__()
        c0, c1, c2, c3 = channels
        self._c0 = c0  # channel count used when reshaping the linear output
        self.linear = nn.Sequential(nn.Linear(noise_dim, c0*16*16), nn.ReLU(inplace=True))
        # reshape happens in forward()
        self.deconv1 = _deconv_block(c0, c1)  # 32*32
        self.deconv2 = _deconv_block(c1, c2)  # 64*64
        self.deconv3 = _deconv_block(c2, c3)  # 128*128
        # last layer: no BatchNorm/ReLU, Tanh output
        self.deconv4 = nn.Sequential(nn.ConvTranspose2d(c3, class_number, 3, 2, 1, 1), nn.Tanh())  # 256*256

    def forward(self, x):
        '''x: [N, noise_dim] noise -> [N, class_number, 256, 256] in (-1, 1).'''
        x = self.linear(x)
        x = x.reshape([-1, self._c0, 16, 16])
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)

        return x


class Generator1(_DeconvGenerator):
    '''
    Wide generator: input [N,10*10], channels 768 -> 384 -> 256 -> 192.

    This was originally also called `Generator` and was shadowed by the second
    class of the same name, so generator1() silently built the other network.
    '''
    def __init__(self, class_number):
        super().__init__(class_number, 10*10, (768, 384, 256, 192))


def generator1(class_number):
    '''Build the wide (10*10-noise) generator with `class_number` output channels.'''
    model = Generator1(class_number)
    return model


######################################################################
class Generator(_DeconvGenerator):
    '''
    Light generator: input [N,50*50], channels 64 -> 128 -> 256 -> 128.

    A 4096*4096 fully connected layer is already very large, so the linear
    layer here is deliberately kept small.
    '''
    def __init__(self, class_number):
        super().__init__(class_number, 50*50, (64, 128, 256, 128))


def generator(class_number):
    '''Build the light (50*50-noise) generator with `class_number` output channels.'''
    model = Generator(class_number)
    return model


# --------------------------------------------------------------------------------
# /batch_data.py:
# --------------------------------------------------------------------------------
'''
the author is leilei
'''

import cv2
import h5py
import torch
import random
import torchvision
from torchvision import transforms
import numpy as np
from torch import nn
from torch.autograd import Variable
from torch.utils import data


def _augment(choice, *arrays):
    '''
    Apply one geometric transform, selected by `choice`, to every array.

    choice 1: left-right flip; 2: up-down flip; 3: rotate 90; 4: rotate 270
    (both about (w//2, h//2) with nearest-neighbour sampling, output size
    unchanged); any other value: arrays are returned untouched.
    The rotation matrix and output size are taken from the first array.
    '''
    if choice == 1:
        return tuple(cv2.flip(a, 1) for a in arrays)
    if choice == 2:
        return tuple(cv2.flip(a, 0) for a in arrays)
    if choice in (3, 4):
        height, width = arrays[0].shape[:2]
        # getRotationMatrix2D takes the centre as (x, y) = (width, height) order
        M = cv2.getRotationMatrix2D((width//2, height//2), 90 if choice == 3 else 270, 1.0)
        return tuple(cv2.warpAffine(a, M, (width, height), flags=cv2.INTER_NEAREST)
                     for a in arrays)
    return arrays


class Data(data.Dataset):
    '''
    Labeled dataset backed by an HDF5 file with 'image' and 'label' datasets.

    Args:
        dataset_path: path of the HDF5 file (opened read-only).
        transform   : optional callable applied to the image only
                      (per the original notes: h*w*c => c*h*w, scaled to [0,1]).
        augmentation: apply a random flip/rotation to image and label together.
    '''
    def __init__(self,dataset_path,transform=None,augmentation=True):
        self.hdf5=h5py.File(dataset_path,mode='r')
        self.image=self.hdf5['image']
        self.label=self.hdf5['label']  # stored one-hot; converted in __getitem__
        self.transform=transform
        self.augmentation=augmentation

    def __len__(self):
        return self.image.shape[0]

    def data_augmentation(self,image,label):
        '''Randomly flip/rotate image and label identically (values 5-8 of the draw: no change).'''
        image,label=_augment(random.randint(1,8),image,label)
        return image,label

    def __getitem__(self,index):
        # A single sample has no batch axis (h*w*c image, h*w label); the
        # DataLoader adds the batch dimension.
        img=self.image[index]  # assumes cv2-style BGR, h*w*c -- TODO confirm against the HDF5 writer
        onehot=self.label[index]
        lab=onehot.argmax(axis=-1)  # one-hot -> class indices
        if self.augmentation:
            # argmax yields int64, which OpenCV (before 5.x) cannot wrap as a
            # Mat, so cv2.flip/warpAffine would raise. Use the narrowest
            # supported unsigned type that holds every class index.
            lab=lab.astype(np.uint8 if onehot.shape[-1]<=256 else np.uint16)
            img,lab=self.data_augmentation(img,lab)
        if self.transform is not None:
            img=self.transform(img)

        return img,np.int64(lab)  # labels must be int64 (torch long)


class data(data.Dataset):  # for unlabeled data; the name rebinds the imported `data` module after this line
    '''
    Unlabeled dataset backed by an HDF5 file with an 'image' dataset.

    Args:
        dataset_path: path of the HDF5 file (opened read-only).
        transform   : optional callable applied to the image.
        augmentation: apply a random flip/rotation to the image.
    '''
    def __init__(self,dataset_path,transform=None,augmentation=False):
        self.hdf5=h5py.File(dataset_path,mode='r')
        self.image=self.hdf5['image']
        self.transform=transform
        self.augmentation=augmentation

    def __len__(self):
        return self.image.shape[0]

    def data_augmentation(self,image):
        '''Randomly flip/rotate the image (values 5-8 of the draw: no change).'''
        (image,)=_augment(random.randint(1,8),image)
        return image

    def __getitem__(self,index):
        img=self.image[index]  # assumes cv2-style BGR -- TODO confirm
        if self.augmentation:
            img=self.data_augmentation(img)
        if self.transform is not None:
            img=self.transform(img)
        return img


# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
'''
the author is leilei
'''
# Main training script starts here

import os
import cv2
import torch
from torchvision import transforms
import numpy as np
from torch import nn
from torch.autograd import Variable  # optional; can be omitted
from torch.utils import data
# imread data
from data_imread import batch_data
# models
from generator import generator
from discriminator import discriminator
# losses
from losses import Loss_label, Loss_fake, Loss_unlabel

################### Hyper parameter ###################
batch_size=16
class_number=5
lr_g=2e-4
lr_d=1e-4
power=0.9
weight_decay=5e-4
max_iter=20000
dataset_path=r'**/Dataset/hdf5/f5.hdf5'
dataset_nl_path=r'**/Dataset/hdf5/f2.hdf5'
save_path=r'**/Pytorch_Code/ALL/ssgan/'

#loss_s_path=os.path.join(save_path,'loss.npy')
model_s_path=os.path.join(save_path,'model.pth')
#loss_s_figure=os.path.join(save_path,'loss.tif')
model_g_spath=os.path.join(save_path,'g/model_g.pth')

################### update lr ###################
def lr_poly(base_lr,iters,max_iter,power):
    '''Polynomial decay: base_lr * (1 - iters/max_iter) ** power.'''
    return base_lr*((1-float(iters)/max_iter)**power)

def adjust_lr(optimizer,base_lr,iters,max_iter,power):
    '''
    Write the decayed learning rate into the optimizer in place.

    Parameter group 0 gets the decayed rate; if a second group exists it gets
    ten times that rate. Further groups are left untouched.
    '''
    lr=lr_poly(base_lr,iters,max_iter,power)
    optimizer.param_groups[0]['lr']=lr
    if len(optimizer.param_groups)>1:
        optimizer.param_groups[1]['lr']=lr*10

################### dataset loader ###################
img_transform=transforms.ToTensor()  # hwc=>chw and 0-255=>0-1
dataset=batch_data.Data(dataset_path,transform=img_transform,augmentation=False)
trainloader=data.DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=3)

dataset_nl=batch_data.data(dataset_nl_path,transform=img_transform)
trainloader_nl=data.DataLoader(dataset_nl,batch_size=batch_size,shuffle=True,num_workers=3)

trainloader_iter=iter(trainloader)
trainloader_nl_iter=iter(trainloader_nl)


def _next_batch(loader,iterator):
    '''
    Return (batch, iterator), restarting the loader when an epoch is exhausted.

    Only StopIteration triggers a restart; genuine data-loading errors and
    KeyboardInterrupt propagate instead of being swallowed.
    '''
    try:
        return next(iterator),iterator
    except StopIteration:
        iterator=iter(loader)
        return next(iterator),iterator


################### build model ###################
model_g = generator(3)
model_d = discriminator(class_number+1)

#### fine-tune ####
#new_params=model.state_dict()
#pretrain_dict=torch.load(r'**/model.pth')
#pretrain_dict={k:v for k,v in pretrain_dict.items() if k in new_params and v.size()==new_params[k].size()}# default k in m m.keys
#new_params.update(pretrain_dict)
#model.load_state_dict(new_params)

model_g.train()
model_g.cuda()

model_d.train()
model_d.cuda()

################### optimizer ###################
optimizer_g=torch.optim.Adam(model_g.parameters(),lr=lr_g,betas=(0.9,0.99),weight_decay=weight_decay)
#optimizer_g.zero_grad()

optimizer_d=torch.optim.Adam(model_d.parameters(),lr=lr_d,betas=(0.9,0.99),weight_decay=weight_decay)
#optimizer_d.zero_grad()

# Create the checkpoint directories up front (model_g_spath lives in a 'g'
# sub-directory) so a missing folder fails now rather than after 1000 iterations.
for _ckpt_path in (model_s_path,model_g_spath):
    _ckpt_dir=os.path.dirname(_ckpt_path)
    if _ckpt_dir:
        os.makedirs(_ckpt_dir,exist_ok=True)

################### iter train ###################
for iters in range(max_iter):
    ####### train D ##################
    optimizer_d.zero_grad()
    adjust_lr(optimizer_d,lr_d,iters,max_iter,power)

    # labeled data
    (images,labels),trainloader_iter=_next_batch(trainloader,trainloader_iter)
    images=images.cuda()
    labels=labels.cuda()

    # unlabeled data
    images_nl,trainloader_nl_iter=_next_batch(trainloader_nl,trainloader_nl_iter)
    images_nl=images_nl.cuda()
    # The two loaders reach their short final batch at different times; skip
    # iterations where the batch sizes disagree.
    if images.shape[0] != images_nl.shape[0]:
        continue
    # noise data: uniform in [0,1), one 50*50 vector per sample
    noise = torch.rand([images.shape[0],50*50]).cuda()
    # predict
    pred_labeled = model_d(images)
    pred_unlabel = model_d(images_nl)
    # detach: the D step must not back-propagate into the generator
    pred_fake = model_d( model_g(noise).detach() )
    # compute loss
    loss_labeled = Loss_label(pred_labeled,labels)
    loss_unlabel = Loss_unlabel(pred_unlabel)
    loss_fake = Loss_fake(pred_fake)

    loss_d = loss_labeled + 0.5*loss_fake + 0.5*loss_unlabel
    loss_d_v = loss_d.item()
    loss_d.backward()
    optimizer_d.step()

    ####### train G ##################
    optimizer_g.zero_grad()
    adjust_lr(optimizer_g,lr_g,iters,max_iter,power)
    # predict again with the updated D; G maximizes what D minimizes on fakes
    pred_fake = model_d( model_g(noise) )
    loss_g = -Loss_fake(pred_fake)
    loss_g_v = loss_g.item()
    loss_g.backward()
    optimizer_g.step()

    # output loss value
    print('iter=%d , loss_g=%.2f , loss_d=%.2f'%(iters,loss_g_v,loss_d_v))
    # save model
    if iters%1000==0 and iters!=0:
        # test image
#        img=Image.open(os.path.join(test_path,names[i]))
#        r,g,b=img.split()
#        img=Image.merge('RGB',(b,g,r))
#        img_=img_transform(img)
#        img_=torch.unsqueeze(img_,dim=0)
#        image=Variable(img_).cuda()
#        predict=model(image)
#        P=torch.max(predict,1)[1].cuda().data.cpu().numpy()[0]
#        P=np.uint8(P)
#        cv2.imwrite(os.path.join(pre_path,names[i]),P)

        torch.save(model_d.state_dict(),model_s_path)
        torch.save(model_g.state_dict(),model_g_spath)
# final checkpoint after the last iteration
torch.save(model_d.state_dict(),model_s_path)
torch.save(model_g.state_dict(),model_g_spath)
--------------------------------------------------------------------------------