├── Create_data ├── __pycache__ │ └── clear_and_create.cpython-35.pyc └── clear_and_create.py ├── Data.py ├── utlis.py ├── README.md ├── train.py ├── ops.py └── model.py /Create_data/__pycache__/clear_and_create.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongpeng-Lin/A-DA-GAN-architecture/HEAD/Create_data/__pycache__/clear_and_create.cpython-35.pyc -------------------------------------------------------------------------------- /Data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 as cv 4 | import math,os,h5py,argparse,sys 5 | from Create_data import clear_and_create 6 | 7 | def main(args): 8 | if args.op_type=='clear': 9 | clear_and_create.clear_data(args.im_dir) 10 | return True 11 | else: 12 | clear_and_create.create_data(args.raw_dir,args.if_clip) 13 | return True 14 | 15 | def parse_arguments(argv): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--op_type', type=str, help='Choose to clear data or create positive and negative samples,clear or create.', default="clear") 18 | parser.add_argument('--im_dir', type=str, help='Path to the image folder.', default="D:/SVHN_dataset/train/") 19 | parser.add_argument('--raw_dir', type=str, help='The path that has been cleared.', default="D:/SVHN_dataset/train/") 20 | parser.add_argument('--if_clip', type=bool,help='Whether to divide the picture into two parts (the training set will be reduced after segmentation).', default=False) 21 | return parser.parse_args(argv) 22 | 23 | if __name__ == '__main__': 24 | main(parse_arguments(sys.argv[1:])) 25 | -------------------------------------------------------------------------------- /utlis.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 as cv 4 | import math,os,h5py,argparse,sys 5 | 6 | def get_im_label(idx,batch,Dir,svhnMat): 7 | im_zeros = np.zeros([batch,64,64,3],np.float32) 8 | label_zeros = np.zeros([batch,10],np.float32) 9 | start = int(idx*batch) 10 | end = start+batch 11 | dirs = Dir[start:end] 12 | print('dirs: ',dirs) 13 | for i,Dir in enumerate(dirs): 14 | im_bgr = cv.resize(cv.imread(Dir),(64,64),interpolation=cv.INTER_CUBIC) 15 | im_rgb_unpro = im_bgr[:,:,::-1] 16 | im_rgb = ((im_rgb_unpro/255)-0.5)*2 17 | im_zeros[i,:,:,:] = im_rgb 18 | label_zeros[i,:] = im_dir2label(Dir,svhnMat) 19 | return im_zeros,label_zeros 20 | 21 | def im_dir2label(a_dir,svhnMat): 22 | label = np.zeros([10,],np.float32) 23 | im_name = a_dir.split('/')[-1] 24 | im_num = int(im_name.split('.')[0]) 25 | 26 | item = svhnMat['digitStruct']['bbox'][im_num-1].item() 27 | attr = svhnMat[item]['label'] 28 | values = [svhnMat[attr.value[i].item()].value[0][0] for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]] 29 | for value in values: 30 | label[int(value)] = 1.0 31 | return label 32 | 33 | -------------------------------------------------------------------------------- /Create_data/clear_and_create.py: -------------------------------------------------------------------------------- 1 | import os,h5py 2 | import numpy as np 3 | import cv2 as cv 4 | 5 | def clear_data(im_dir): 6 | name = im_dir+'digitStruct.mat' 7 | svhnMat = h5py.File(name=name, mode='r') 8 | im_names = [Im_name for Im_name in os.listdir(im_dir) if Im_name.split('.')[-1]=='png'] 9 | for im_name in im_names: 10 | im_num = 
int(im_name.split('.')[0]) 11 | item = svhnMat['digitStruct']['bbox'][im_num-1].item() 12 | attr = svhnMat[item]['label'] 13 | values = [svhnMat[attr.value[i].item()].value[0][0] for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]] 14 | for value in values: 15 | if value>=10.0: 16 | os.remove(im_dir+im_name) 17 | print('im is: ',im_name) 18 | print('value is: ',value) 19 | break 20 | return True 21 | 22 | def create_data(raw_dir,if_clip): 23 | positive = raw_dir+'positive' 24 | negtive = raw_dir+'negtive' 25 | for new_dir in [positive,negtive]: 26 | os.makedirs(new_dir) 27 | for im_name in os.listdir(raw_dir): 28 | if im_name.split('.')[-1]=='png': 29 | im = cv.imread(raw_dir+im_name) 30 | if if_clip: 31 | if_flip = np.random.uniform() 32 | if if_flip>0.5: 33 | im_flip = cv.flip(im,1,dst=None) 34 | cv.imwrite(negtive+'/'+im_name,im_flip) 35 | os.remove(raw_dir+im_name) 36 | else: 37 | cv.imwrite(positive+'/'+im_name,im) 38 | os.remove(raw_dir+im_name) 39 | else: 40 | cv.imwrite(positive+'/'+im_name,im) 41 | im_flip = cv.flip(im,1,dst=None) 42 | cv.imwrite(negtive+'/'+im_name,im_flip) 43 | os.remove(raw_dir+im_name) 44 | return True 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A-DA-GAN-architecture 2 | ## A basic architecture of "DA-GAN: Instance-level Image Translation by Deep Attention Generative Adversarial Networks"
3 |  This is a basic implementation of the architecture; the structure described in the paper is outlined below:
  1. The image is encoded with a convolutional encoder.
4 |   2. A second stack of convolutions followed by a fully connected layer predicts an attention region (similar to a bounding box in object detection), which is then used to mask the original image.
5 |   3. The masking step does not use a hard 0/1 mask; instead it is built from steep sigmoids, so the mask stays 'soft' and gradients can still be back-propagated through it (see the sketch below).
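 The sketch below is a minimal NumPy illustration of this soft mask; it mirrors `sigmoid_mask`/`get_mask` in `ops.py` but is not the repo's TensorFlow code. The helper name `soft_box_mask` and the demo values are only for illustration; in the repo the sigmoid gain is set by `--k` and the box size by `--hw_size`.

```python
# Minimal sketch of the "soft" box-mask idea (assumes NumPy only).
import numpy as np

def soft_box_mask(x0, y0, hw_size, im_size, k=200.0):
    """Approximate a 0/1 box mask with steep sigmoids so it stays differentiable."""
    # 1-based coordinate grids, as built by get_constant() in ops.py.
    xs = np.tile(np.arange(1, im_size + 1, dtype=np.float32), (im_size, 1))
    ys = xs.T
    # Clipping the argument only avoids overflow warnings for large k.
    sig = lambda v: 1.0 / (1.0 + np.exp(-np.clip(k * v, -60.0, 60.0)))
    # ~1 inside [x0, x0+hw_size) x [y0, y0+hw_size), ~0 outside.
    return (sig(xs - x0) - sig(xs - (x0 + hw_size))) * \
           (sig(ys - y0) - sig(ys - (y0 + hw_size)))

mask = soft_box_mask(x0=10, y0=20, hw_size=30, im_size=64)
print(mask.shape, float(mask.min()), float(mask.max()))  # (64, 64) ~0.0 ~1.0
```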
6 | ## How to use
7 |  1. Some samples in the SVHN dataset are labelled incorrectly, so clean the data first:
8 |   python D:\SVHN_dataset\train\DAE_GAN\Data.py --op_type="clear" --im_dir="D:/SVHN_dataset/train/forcmd/" --raw_dir="D:/SVHN_dataset/train/forcmd/" --if_clip=False
9 |  2. Create positive and negative sample data:
10 |   python D:\SVHN_dataset\train\DAE_GAN\Data.py --op_type="create" --im_dir="D:/SVHN_dataset/train/forcmd/" --raw_dir="D:/SVHN_dataset/train/forcmd/" --if_clip=False
11 |  3. Train the model, or load a saved checkpoint for testing (the expected folder layout is sketched after these commands):
12 |   Train:
13 |    python D:\SVHN_dataset\train\DAE_GAN\train.py --is_train="train" --im_size=64 --batch=2 --epoch=100 --hw_size=30 --k=2e5 --alpa=0.9 --beta=0.5 --im_dir="D:/SVHN_dataset/train/forcmd/" --save_dir="D:/SVHN_dataset/train/forcmd/ckpt/" --saveS_dir="D:/SVHN_dataset/train/forcmd/SampleS/" --saveT_dir="D:/SVHN_dataset/train/forcmd/SampleT/"
14 |   Test:
15 |    python D:\SVHN_dataset\train\DAE_GAN\train.py --is_train="test" --load_dir="D:/SVHN_dataset/train/ckpt/" --raw_im_dir="D:/SVHN_dataset/test" --save_im_dir="D:/SVHN_dataset/test_save/"
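 For reference, the commands above assume a folder layout roughly like the one sketched below. This is inferred from `Data.py`, `train.py` and `model.py` rather than documented anywhere: `positive/` and `negtive/` (spelling as in the code) are created by the "create" step, `digitStruct.mat` is the SVHN annotation file, and the checkpoint/sample directories do not appear to be created automatically, so it is safest to create them beforehand.

```
D:/SVHN_dataset/train/forcmd/
├── digitStruct.mat   # SVHN annotations, read by clear_and_create.py and model.py
├── *.png             # raw SVHN images, consumed by the clear/create steps
├── positive/         # written by --op_type="create"
├── negtive/          # written by --op_type="create"
├── ckpt/             # --save_dir: checkpoints and summary graph
├── SampleS/          # --saveS_dir: generated samples from the S branch
└── SampleT/          # --saveT_dir: generated samples from the T branch
```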
16 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 as cv 4 | import math,os,h5py,argparse,sys 5 | from model import * 6 | 7 | def main(args): 8 | dae_gan = DAE_GAN(args.batch, args.epoch, args.im_size, args.hw_size, args.k, args.alpa, args.beta, args.im_dir, args.save_dir, args.saveS_dir, args.saveT_dir) 9 | if args.is_train=='train': 10 | dae_gan.train() 11 | else: 12 | dae_gan.load(args.load_dir, args.raw_im_dir, args.save_im_dir) 13 | return True 14 | 15 | def parse_arguments(argv): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--is_train', type=str, help='Training or loading.', default="train") 18 | parser.add_argument('--load_dir', type=str, help='Load model checkpoint.', default="D:/SVHN_dataset/train/ckpt/") 19 | parser.add_argument('--raw_im_dir', type=str, help='Image to test.', default="D:/SVHN_dataset/test") 20 | parser.add_argument('--save_im_dir', type=str, help='Save sample images dir.', default="D:/SVHN_dataset/test_save/") 21 | 22 | parser.add_argument('--im_size', type=int, help='Image size (height, width) in pixels.', default=64) 23 | parser.add_argument('--batch', type=int, help='batch size.', default=64) 24 | parser.add_argument('--epoch', type=int, help='Number of training cyclese.', default=100) 25 | parser.add_argument('--hw_size', type=int, help='The size of the attention area removed.', default=30) 26 | parser.add_argument('--k', type=float, help='Gain coefficient of sigmoid when generating mask.', default=int(2e2)) 27 | parser.add_argument('--alpa', type=float, help='Error weight_1.', default=0.4) 28 | parser.add_argument('--beta', type=float, help='Error weight_2.', default=0.6) 29 | parser.add_argument('--im_dir', type=str, help='Path to the image folder.', default="D:/SVHN_dataset/train/") 30 | parser.add_argument('--save_dir', type=str, help='Model save path.', default="D:/SVHN_dataset/train/ckpt/") 31 | parser.add_argument('--saveS_dir', type=str, help='The path that has been cleared.', default="D:/SVHN_dataset/train/SampleS/") 32 | parser.add_argument('--saveT_dir', type=str, help='Source image save path.', default="D:/SVHN_dataset/train/SampleT/") 33 | return parser.parse_args(argv) 34 | 35 | if __name__ == '__main__': 36 | main(parse_arguments(sys.argv[1:])) 37 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 as cv 4 | import math,os,h5py,argparse,sys 5 | 6 | def conv(name,x,kers,s,outs,pad): 7 | with tf.variable_scope(name): 8 | ker = int(math.sqrt(kers)) 9 | shape = [i.value for i in x.get_shape()] 10 | w = tf.get_variable('w', 11 | [ker,ker,shape[-1],outs], 12 | tf.float32, 13 | tf.contrib.layers.xavier_initializer_conv2d()) 14 | b = tf.get_variable('b',[outs],tf.float32,tf.constant_initializer(0.)) 15 | padd = "SAME" if pad else "VALID" 16 | x_conv = tf.nn.conv2d(x,w,[1,s,s,1],padd) + b 17 | return x_conv 18 | 19 | def res_block(name,x): 20 | with tf.variable_scope(name): 21 | shape = [i.value for i in x.get_shape()] 22 | conv1 = conv(name+'_conv1',x,3*3,shape[-1],1,True) 23 | bn1 = tf.nn.relu(tf.contrib.layers.batch_norm(conv1,scale=True,updates_collections=None)) 24 | conv2 = conv(name+'_conv2',bn1,3*3,shape[-1],1,True) 25 | bn2 = 
tf.contrib.layers.batch_norm(conv2,scale=True,updates_collections=None) 26 | return tf.nn.relu(bn2+x) 27 | 28 | def conv_trans(name,x,sizes,s,outs,ifpad): 29 | with tf.variable_scope(name): 30 | ker = int(math.sqrt(sizes)) 31 | shape = [i.value for i in x.get_shape()] 32 | ins = shape[-1]//4 33 | w = tf.get_variable('w',[ker,ker,ins,outs],tf.float32,tf.contrib.layers.xavier_initializer_conv2d()) 34 | b = tf.get_variable('b',[outs],tf.float32,tf.constant_initializer(0.)) 35 | pad = "SAME" if ifpad else "VALID" 36 | x_conv = tf.nn.conv2d(tf.depth_to_space(x,2),w,[1,s,s,1],pad)+b 37 | return x_conv 38 | 39 | def lrelu(name,x): 40 | with tf.variable_scope(name): 41 | return tf.nn.relu(x) 42 | 43 | def tanh(name,x): 44 | with tf.variable_scope(name): 45 | return tf.nn.tanh(x) 46 | 47 | def BN(name,x): 48 | with tf.variable_scope(name): 49 | return tf.contrib.layers.batch_norm(x,scale=True,updates_collections=None) 50 | 51 | def FC_location(name,x,outs,im_size,hw_size): 52 | with tf.variable_scope(name): 53 | raw_shape = [i.value for i in x.get_shape()] 54 | new_shape = int(raw_shape[1]*raw_shape[2]*raw_shape[3]) 55 | x_resh = tf.reshape(x,[-1,new_shape]) 56 | w = tf.get_variable('w',[new_shape,outs],tf.float32,tf.contrib.layers.xavier_initializer()) 57 | b = tf.get_variable('b',[1,outs],tf.float32,tf.constant_initializer(0.)) 58 | fc_ = tf.matmul(x_resh,w)+b 59 | # Add a range limit to the position coordinates while making the gradient softer. 60 | xy_constraint = tf.nn.sigmoid(fc_)*(im_size-hw_size) 61 | return tf.cast(tf.round(xy_constraint),tf.int32) 62 | 63 | def get_constant(im_size): 64 | zero_x = np.zeros([im_size,im_size],np.float32) 65 | zero_y = np.zeros([im_size,im_size],np.float32) 66 | for i in range(im_size): 67 | zero_x[:,i] = i+1 68 | zero_y[i,:] = i+1 69 | return zero_x,zero_y 70 | 71 | def sigmoid_mask(x,k): 72 | k = int(k) 73 | X = tf.cast(x,tf.float32) 74 | return 1/(1+tf.exp(-1*k*X)) 75 | 76 | def get_mask(xy,reigon_w,reigon_h,im_size,k,B): 77 | # xy: [batch, 2]: coordinates 78 | # reigon_w, reigon_h: size of the area 79 | # im_size: Image size for generating the original X 80 | # k: The growth factor of the sigmoid function 81 | with tf.variable_scope('Mask'): 82 | x_left = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,0],1),2),3) 83 | x_right = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,0]+reigon_w,1),2),3) 84 | y_top = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,1],1),2),3) 85 | y_bottom = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,1]+reigon_h,1),2),3) 86 | x_value,y_value = get_constant(im_size) 87 | x_constant = np.tile(np.expand_dims(np.expand_dims(x_value,0),3),[B,1,1,1]) 88 | y_constant = np.tile(np.expand_dims(np.expand_dims(y_value,0),3),[B,1,1,1]) 89 | A = sigmoid_mask(x_constant-x_left,k) 90 | C = sigmoid_mask(x_constant-x_right,k) 91 | D = sigmoid_mask(y_constant-y_top,k) 92 | E = sigmoid_mask(y_constant-y_bottom,k) 93 | return (A-C)*(D-E) 94 | 95 | def FC(name,x,outs): 96 | with tf.variable_scope(name): 97 | shape = [i.value for i in x.get_shape()] 98 | size = int(shape[1]*shape[2]*shape[3]) 99 | x_reshape = tf.reshape(x,[-1,size]) 100 | w = tf.get_variable('w',[size,outs],tf.float32,tf.contrib.layers.xavier_initializer()) 101 | b = tf.get_variable('b',[1,outs],tf.float32,tf.constant_initializer(0.)) 102 | return tf.nn.sigmoid(tf.matmul(x_reshape,w)+b) 103 | 104 | def resnet_classifer(name,x): # x: batch,4,4,512 105 | with tf.variable_scope(name): 106 | res1 = res_block('res1',x) 107 | res2 = res_block('res2',res1) 108 | res3 = 
res_block('res3',res2) 109 | res4 = res_block('res4',res3) 110 | res5 = res_block('res5',res4) 111 | res6 = res_block('res6',res5) 112 | res7 = res_block('res7',res6) 113 | probs = FC('probs',res7,10) 114 | return probs 115 | 116 | def One_ep_Iter(im_dir,batch): 117 | num_ims = min(len(os.listdir(im_dir+'positive')),len(os.listdir(im_dir+'negtive'))) 118 | return num_ims//batch 119 | 120 | def get_shape(x): 121 | L = [i.value for i in x.get_shape()] 122 | return L 123 | 124 | def save_im(sample_im,save_dir,cur_ep,cur_batch,batch): 125 | for i in range(batch): 126 | im_S = sample_im[i,:,:,:] 127 | im_s = (im_S[:,:,::-1]+1)*127.5 128 | s_name = 'ep'+str(cur_ep)+'_sample'+str(cur_batch)+'_batch'+str(i)+'.png' 129 | cv.imwrite(save_dir+s_name,im_s) 130 | return True 131 | 132 | def Save_load(ims,num,save_dir): 133 | b = np.shape(ims)[0] 134 | for i in range(b): 135 | im_S = ims[i,:,:,:] 136 | im_s = (im_S[:,:,::-1]+1)*127.5 137 | name = 'num_'+str(num)+'.png' 138 | cv.imwrite(save_dir+name,im_s) 139 | return True 140 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from utlis import * 2 | from ops import * 3 | import tensorflow as tf 4 | import numpy as np 5 | import cv2 as cv 6 | import math,os,h5py 7 | 8 | class DAE_GAN: 9 | def __init__(self,batch,epoch,im_size,hw_size,k,alpa,beta,im_dir,save_dir,saveS_dir,saveT_dir): 10 | self.batch = batch 11 | self.epoch = epoch 12 | self.im_size = im_size 13 | self.hw_size = hw_size 14 | self.positive_dir = [im_dir+'positive/'+name for name in os.listdir(im_dir+'positive')] 15 | self.negtive_dir = [im_dir+'negtive/'+name for name in os.listdir(im_dir+'negtive')] 16 | self.svhnMat = h5py.File(im_dir+'digitStruct.mat', mode='r') 17 | self.first = 64 18 | self.k = k 19 | self.alpa = alpa 20 | self.beta = beta 21 | self.one_ep_iter = One_ep_Iter(im_dir,batch) 22 | self.S = tf.placeholder(tf.float32,[None,im_size,im_size,3],'S') 23 | self.T = tf.placeholder(tf.float32,[None,im_size,im_size,3],'T') 24 | self.S_label = tf.placeholder(tf.float32,[None,10],'S_label') 25 | self.T_label = tf.placeholder(tf.float32,[None,10],'T_label') 26 | self.save_dir = save_dir 27 | self.saveS = saveS_dir 28 | self.saveT = saveT_dir 29 | 30 | def Encoder(self,x,reuse): 31 | with tf.variable_scope('encoder',reuse=reuse): 32 | conv1 = conv('conv1',x,3*3,2,self.first,True) 33 | bn1 = BN('bn1',conv1) 34 | relu1 = lrelu('relu1',bn1) 35 | 36 | conv2 = conv('conv2',relu1,3*3,2,int(2*self.first),True) 37 | bn2 = BN('bn2',conv2) 38 | relu2 = lrelu('relu2',bn2) 39 | 40 | conv3 = conv('conv3',relu2,3*3,2,int(4*self.first),True) 41 | bn3 = BN('bn3',conv3) 42 | relu3 = lrelu('relu3',bn3) 43 | 44 | conv4 = conv('conv4',relu3,3*3,2,int(8*self.first),True) 45 | bn4 = BN('bn4',conv4) 46 | relu4 = lrelu('relu4',bn4) 47 | return relu4 48 | 49 | def F(self,features,k): 50 | with tf.variable_scope('f_location'): 51 | cha = [i.value for i in features.get_shape()][-1] 52 | 53 | conv1 = conv('conv1',features,3*3,2,int(2*cha),True) 54 | bn1 = BN('bn1',conv1) 55 | relu1 = lrelu('relu1',bn1) 56 | 57 | conv2 = conv('conv2',relu1,3*3,2,int(4*cha),True) 58 | bn2 = BN('bn2',conv2) 59 | relu2 = lrelu('relu2',bn2) 60 | 61 | out_xy = FC_location('fc1',relu2,2,self.im_size,self.hw_size) 62 | mask = get_mask(out_xy,self.hw_size,self.hw_size,self.im_size,k,self.batch) 63 | return mask 64 | 65 | def GAN_G(self,x,reuse): 66 | with tf.variable_scope('G_net',reuse=reuse): 67 | 68 | 
conv_trans1 = conv_trans('conv_trans1',x,5*5,1,256,True) 69 | Lrelu1 = lrelu('Lrelu1',conv_trans1) 70 | 71 | conv_trans2 = conv_trans('conv_trans2',Lrelu1,5*5,1,128,True) 72 | bn2 = BN('bn2',conv_trans2) 73 | Lrelu2 = lrelu('Lrelu2',bn2) 74 | 75 | conv_trans3 = conv_trans('conv_trans3',Lrelu2,5*5,1,64,True) 76 | bn3 = BN('bn3',conv_trans3) 77 | Lrelu3 = lrelu('Lrelu3',bn3) 78 | 79 | conv_trans4 = conv_trans('conv_trans4',Lrelu3,5*5,1,3,True) 80 | Lrelu4 = tanh('Lrelu4',conv_trans4) 81 | return Lrelu4 82 | 83 | def GAN_D(self,name,x,reuse): 84 | with tf.variable_scope(name,reuse=reuse): 85 | conv1 = conv('conv1',x,3*3,2,64,True) 86 | bn1 = BN('bn1',conv1) 87 | relu1 = lrelu('relu1',bn1) 88 | 89 | conv2 = conv('conv2',relu1,3*3,2,128,True) 90 | bn2 = BN('bn2',conv2) 91 | relu2 = lrelu('relu2',bn2) 92 | 93 | conv3 = conv('conv3',relu2,3*3,2,256,True) 94 | bn3 = BN('bn3',conv3) 95 | relu3 = lrelu('relu3',bn3) 96 | 97 | conv4 = conv('conv4',relu3,3*3,2,512,True) 98 | bn4 = BN('bn4',conv4) 99 | relu4 = lrelu('relu4',bn4) 100 | 101 | D_out = FC('fc1',relu4,1) 102 | return D_out 103 | 104 | def DAE(self,x,reuse): 105 | with tf.variable_scope('DAE',reuse=reuse): 106 | encode_x = self.Encoder(x,reuse) 107 | mask = self.F(encode_x,self.k) 108 | x_mask = x*mask 109 | encode_mask_x = self.Encoder(x_mask,True) 110 | 111 | probs = resnet_classifer('classsifer',encode_mask_x) 112 | return encode_mask_x,probs 113 | 114 | def forward(self): 115 | self.s_DAE,self.s_out_label = self.DAE(self.S,False) 116 | self.t_DAE,self.t_out_label = self.DAE(self.T,True) 117 | self.s_pie = self.GAN_G(self.s_DAE,False) 118 | self.t_pie = self.GAN_G(self.t_DAE,True) 119 | 120 | self.t_D1 = self.GAN_D('D2',self.T,False) 121 | self.t_pie_D = self.GAN_D('D2',self.t_pie,True) 122 | 123 | self.t_D2 = self.GAN_D('D1',self.T,False) 124 | self.s_pie_D = self.GAN_D('D1',self.s_pie,True) 125 | 126 | self.s_pie_DAE,_ = self.DAE(self.s_pie,True) 127 | self.t_pie_DAE,_ = self.DAE(self.t_pie,True) 128 | 129 | def train(self): 130 | self.forward() 131 | # Loss for G and DAE: 132 | # 1、 Loss of s_DAE and s_pie_DAE: 133 | Lcst = tf.reduce_mean(tf.abs(self.s_DAE-self.s_pie_DAE)) # L1范数 134 | # 2、 Loss of t_DAE and t_pie_DAE: 135 | Lsym = tf.reduce_mean(tf.abs(self.t_DAE-self.t_pie_DAE)) # L1范数 136 | # 3、Make D2 judge t_pie as true: 137 | loss_G_DAE1 = tf.reduce_mean(-1*tf.log(self.t_pie_D)) 138 | # 4、Make D1 judge s_pie as true: 139 | loss_G_DAE2 = tf.reduce_mean(-1*tf.log(self.s_pie_D)) 140 | # 5、Loss caused by classification: cross entropy, this alone corresponds to DAE: 141 | cross_entroy = tf.reduce_mean(tf.reduce_mean((-1)*tf.log(self.s_out_label)*self.S_label + (-1)*tf.log(1-self.s_out_label)*(1-self.S_label),1)) 142 | 143 | self.DAE_G_loss = self.alpa*Lcst + self.beta*Lsym + loss_G_DAE1 + loss_G_DAE2 + cross_entroy 144 | 145 | self.D1_loss = tf.reduce_mean((-1)*tf.log(self.t_D2) + (-1)*tf.log(1-self.s_pie_D)) 146 | self.D2_loss = tf.reduce_mean((-1)*tf.log(self.t_D2) + (-1)*tf.log(1-self.t_pie_D)) 147 | DAE_G_vars = [var for var in tf.trainable_variables() if 'DAE' in var.name or 'G_net' in var.name] 148 | D1_vars = [var for var in tf.trainable_variables() if 'D1' in var.name] 149 | D2_vars = [var for var in tf.trainable_variables() if 'D2' in var.name] 150 | optim_DAE_G = tf.train.AdamOptimizer().minimize(self.DAE_G_loss,var_list=DAE_G_vars) 151 | optim_D1 = tf.train.AdamOptimizer().minimize(self.D1_loss,var_list=D1_vars) 152 | optim_D2 = tf.train.AdamOptimizer().minimize(self.D2_loss,var_list=D2_vars) 153 | with tf.Session() as sess: 
154 | sess.run(tf.global_variables_initializer()) 155 | graph = tf.summary.FileWriter(self.save_dir,graph=sess.graph) 156 | Saver = tf.train.Saver(max_to_keep=20) 157 | 158 | savedir = self.save_dir+'model.ckpt' 159 | for i in range(self.epoch): 160 | for j in range(self.one_ep_iter): 161 | ims_po,labels_po = get_im_label(j, self.batch, self.positive_dir, self.svhnMat) 162 | ims_neg,labels_neg = get_im_label(j, self.batch, self.negtive_dir, self.svhnMat) 163 | 164 | fed_dict={self.S:ims_po,self.S_label:labels_po,self.T:ims_neg,self.T_label:labels_neg} 165 | 166 | _,LossD1 = sess.run([optim_D1,self.D1_loss],feed_dict=fed_dict) 167 | print('LossD1: ',LossD1) 168 | 169 | _,LossD2 = sess.run([optim_D2,self.D2_loss],feed_dict=fed_dict) 170 | print('LossD2: ',LossD2) 171 | 172 | _,S_sample,T_sample,LossDAE_G = sess.run([optim_DAE_G,self.s_pie,self.t_pie,self.DAE_G_loss],feed_dict=fed_dict) 173 | print('LossDAE_G: ',LossDAE_G) 174 | 175 | save_im(S_sample, self.saveS, i, j, self.batch) 176 | save_im(T_sample, self.saveT, i, j, self.batch) 177 | 178 | step = int(i*self.epoch + j) 179 | Saver.save(sess,savedir,global_step=step) 180 | print('save_success at: ',step) 181 | 182 | def load(self,load_dir,raw_im_dir,save_im_dir): 183 | self.forward() 184 | with tf.Session() as sess: 185 | sess.run(tf.global_variables_initializer()) 186 | graph = tf.summary.FileWriter(self.save_dir,graph=sess.graph) 187 | Saver = tf.train.Saver() 188 | Saver.restore(sess,load_dir) 189 | for i,im_name in enumerate(os.listdir(raw_im_dir)): 190 | im_bgr = cv.resize(cv.imread(raw_im_dir+im_name),(64,64),interpolation=cv.INTER_CUBIC) 191 | im_rgb_unpro = im_bgr[:,:,::-1] 192 | im_rgb = np.expand_dims(((im_rgb_unpro/255)-0.5)*2,0) 193 | # fed_dict={self.S:ims_po,self.S_label:labels_po} 194 | fed_dict={self.S:im_rgb} 195 | s_tar = sess.run(self.s_pie,feed_dict=fed_dict) 196 | Save_load(s_tar,i,save_im_dir) 197 | print('save at: ',i) 198 | --------------------------------------------------------------------------------