├── Create_data
│   ├── __pycache__
│   │   └── clear_and_create.cpython-35.pyc
│   └── clear_and_create.py
├── Data.py
├── utlis.py
├── README.md
├── train.py
├── ops.py
└── model.py
/Create_data/__pycache__/clear_and_create.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rongpeng-Lin/A-DA-GAN-architecture/HEAD/Create_data/__pycache__/clear_and_create.cpython-35.pyc
--------------------------------------------------------------------------------
/Data.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2 as cv
4 | import math,os,h5py,argparse,sys
5 | from Create_data import clear_and_create
6 |
7 | def main(args):
8 | if args.op_type=='clear':
9 | clear_and_create.clear_data(args.im_dir)
10 | return True
11 | else:
12 | clear_and_create.create_data(args.raw_dir,args.if_clip)
13 | return True
14 |
15 | def parse_arguments(argv):
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--op_type', type=str, help="Operation to perform: 'clear' removes mislabeled images, 'create' builds the positive and negative sample folders.", default="clear")
18 | parser.add_argument('--im_dir', type=str, help='Path to the image folder (must contain digitStruct.mat).', default="D:/SVHN_dataset/train/")
19 | parser.add_argument('--raw_dir', type=str, help='Path to the already-cleared image folder used for sample creation.', default="D:/SVHN_dataset/train/")
20 | parser.add_argument('--if_clip', type=lambda s: s.lower() in ('true','1','yes'), help='If True, each image is randomly assigned to only one of the positive/negative sets (shrinking the training set); if False, every image goes to both.', default=False)  # type=bool would parse the string "False" as True
21 | return parser.parse_args(argv)
22 |
23 | if __name__ == '__main__':
24 | main(parse_arguments(sys.argv[1:]))
25 |
--------------------------------------------------------------------------------
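Note: Data.py is only a thin CLI wrapper; the same pipeline can be driven directly from Python. A minimal sketch (the SVHN path is an assumption taken from the defaults above):

    from Create_data import clear_and_create

    im_dir = 'D:/SVHN_dataset/train/'                    # must contain digitStruct.mat and the *.png images
    clear_and_create.clear_data(im_dir)                  # drop images whose digitStruct label set contains a value >= 10
    clear_and_create.create_data(im_dir, if_clip=False)  # write positive/ (originals) and negtive/ (horizontal flips)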
/utlis.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2 as cv
4 | import math,os,h5py,argparse,sys
5 |
6 | def get_im_label(idx,batch,Dir,svhnMat):
7 | im_zeros = np.zeros([batch,64,64,3],np.float32)
8 | label_zeros = np.zeros([batch,10],np.float32)
9 | start = int(idx*batch)
10 | end = start+batch
11 | dirs = Dir[start:end]
12 | print('dirs: ',dirs)
13 | for i,im_path in enumerate(dirs):  # renamed from Dir to avoid shadowing the parameter
14 | im_bgr = cv.resize(cv.imread(im_path),(64,64),interpolation=cv.INTER_CUBIC)
15 | im_rgb_unpro = im_bgr[:,:,::-1]
16 | im_rgb = ((im_rgb_unpro/255)-0.5)*2  # scale to [-1, 1] to match the generator's tanh output
17 | im_zeros[i,:,:,:] = im_rgb
18 | label_zeros[i,:] = im_dir2label(im_path,svhnMat)
19 | return im_zeros,label_zeros
20 |
21 | def im_dir2label(a_dir,svhnMat):
22 | label = np.zeros([10,],np.float32)
23 | im_name = a_dir.split('/')[-1]
24 | im_num = int(im_name.split('.')[0])
25 |
26 | item = svhnMat['digitStruct']['bbox'][im_num-1].item()
27 | attr = svhnMat[item]['label']
28 | values = [svhnMat[attr.value[i].item()].value[0][0] for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]]
29 | for value in values:
30 | label[int(value)] = 1.0
31 | return label
32 |
33 |
--------------------------------------------------------------------------------
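For reference, im_dir2label resolves an image's digits through the HDF5 reference chain in digitStruct.mat. A minimal sketch of the same lookup for a single image (the paths are assumptions, and the old h5py .value API matches the code above):

    import h5py
    import numpy as np
    from utlis import im_dir2label

    mat = h5py.File('D:/SVHN_dataset/train/digitStruct.mat', mode='r')
    onehot = im_dir2label('D:/SVHN_dataset/train/1.png', mat)  # 1.png -> row 0 of digitStruct/bbox
    print(np.nonzero(onehot)[0])                               # indices of the digit classes present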
/Create_data/clear_and_create.py:
--------------------------------------------------------------------------------
1 | import os,h5py
2 | import numpy as np
3 | import cv2 as cv
4 |
5 | def clear_data(im_dir):
6 | name = im_dir+'digitStruct.mat'
7 | svhnMat = h5py.File(name=name, mode='r')
8 | im_names = [Im_name for Im_name in os.listdir(im_dir) if Im_name.split('.')[-1]=='png']
9 | for im_name in im_names:
10 | im_num = int(im_name.split('.')[0])
11 | item = svhnMat['digitStruct']['bbox'][im_num-1].item()
12 | attr = svhnMat[item]['label']
13 | values = [svhnMat[attr.value[i].item()].value[0][0] for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]]
14 | for value in values:
15 | if value>=10.0:
16 | os.remove(im_dir+im_name)
17 | print('im is: ',im_name)
18 | print('value is: ',value)
19 | break
20 | return True
21 |
22 | def create_data(raw_dir,if_clip):
23 | positive = raw_dir+'positive'
24 | negtive = raw_dir+'negtive'
25 | for new_dir in [positive,negtive]:
26 | os.makedirs(new_dir,exist_ok=True)  # tolerate re-runs instead of crashing if the folder exists
27 | for im_name in os.listdir(raw_dir):
28 | if im_name.split('.')[-1]=='png':
29 | im = cv.imread(raw_dir+im_name)
30 | if if_clip:
31 | if_flip = np.random.uniform()
32 | if if_flip>0.5:
33 | im_flip = cv.flip(im,1,dst=None)
34 | cv.imwrite(negtive+'/'+im_name,im_flip)
35 | os.remove(raw_dir+im_name)
36 | else:
37 | cv.imwrite(positive+'/'+im_name,im)
38 | os.remove(raw_dir+im_name)
39 | else:
40 | cv.imwrite(positive+'/'+im_name,im)
41 | im_flip = cv.flip(im,1,dst=None)
42 | cv.imwrite(negtive+'/'+im_name,im_flip)
43 | os.remove(raw_dir+im_name)
44 | return True
45 |
--------------------------------------------------------------------------------
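The negative samples are simply horizontal mirrors: cv.flip(im, 1) flips around the vertical axis. A self-contained sanity check of that behavior (the synthetic image is illustrative):

    import cv2 as cv
    import numpy as np

    im = np.zeros((64, 64, 3), np.uint8)
    im[:, :32] = 255                  # white left half
    flipped = cv.flip(im, 1)          # flipCode=1: mirror left/right, as create_data does
    assert (flipped[:, 32:] == 255).all() and (flipped[:, :32] == 0).all()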
/README.md:
--------------------------------------------------------------------------------
1 | # A-DA-GAN-architecture
2 | ## A basic architecture of "DA-GAN: Instance-level Image Translation by Deep Attention Generative Adversarial Networks"
3 | This is a basic implementation of that architecture; the structure described in the paper is outlined below:
4 | 1. The image is encoded with a convolutional encoder.
5 | 2. A second convolutional stack followed by a fully connected layer predicts an attention region (similar to a bounding box in object detection), which is used to mask the original image.
6 | 3. The mask is not a hard 0/1 mask; a steep sigmoid is used in its place so the operation stays differentiable under backpropagation, making the change 'softer'.
7 | ## how to use
8 | 1. Some samples in the SVHN dataset are labeled incorrectly, so clean the samples first:
9 | python D:\SVHN_dataset\train\DAE_GAN\Data.py --op_type="clear" --im_dir="D:/SVHN_dataset/train/forcmd/" --raw_dir="D:/SVHN_dataset/train/forcmd/" --if_clip=False
10 | 2. Create positive and negative sample data:
11 | python D:\SVHN_dataset\train\DAE_GAN\Data.py --op_type="create" --im_dir="D:/SVHN_dataset/train/forcmd/" --raw_dir="D:/SVHN_dataset/train/forcmd/" --if_clip=False
12 | 3. Run training, or load a checkpoint for testing:
13 | Train:
14 | python D:\SVHN_dataset\train\DAE_GAN\train.py --is_train="train" --im_size=64 --batch=2 --epoch=100 --hw_size=30 --k=2e5 --alpa=0.9 --beta=0.5 --im_dir="D:/SVHN_dataset/train/forcmd/" --save_dir="D:/SVHN_dataset/train/forcmd/ckpt/" --saveS_dir="D:/SVHN_dataset/train/forcmd/SampleS/" --saveT_dir="D:/SVHN_dataset/train/forcmd/SampleT/"
15 | Test:
16 | python D:\SVHN_dataset\train\DAE_GAN\train.py --is_train="test" --load_dir="D:/SVHN_dataset/train/ckpt/" --raw_im_dir="D:/SVHN_dataset/test" --save_im_dir="D:/SVHN_dataset/test_save/"
17 |
--------------------------------------------------------------------------------
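For the 'soft' mask in step 3 above, ops.py builds, for pixel coordinates (x, y) and a box with edges x_l, x_r, y_t, y_b, the product M(x, y) = (sigmoid(k(x - x_l)) - sigmoid(k(x - x_r))) * (sigmoid(k(y - y_t)) - sigmoid(k(y - y_b))), which approaches a 0/1 box as the gain k grows while staying differentiable. A minimal NumPy sketch (the box edges and the tanh form of the sigmoid are illustrative choices, not the repo's code):

    import numpy as np

    def steep_sigmoid(z, k):
        # Numerically stable sigmoid(k*z), written via tanh to avoid exp overflow at large k.
        return 0.5 * (1.0 + np.tanh(0.5 * k * z))

    size, k = 64, 200.0
    coords = np.arange(1, size + 1, dtype=np.float32)  # ops.get_constant also counts from 1
    x_profile = steep_sigmoid(coords - 20, k) - steep_sigmoid(coords - 50, k)  # ~1 on [20, 50]
    y_profile = steep_sigmoid(coords - 10, k) - steep_sigmoid(coords - 40, k)  # ~1 on [10, 40]
    mask = np.outer(y_profile, x_profile)              # 2-D soft box, values in (0, 1)
    print(mask.shape, round(float(mask.max()), 3))     # (64, 64) 1.0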
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2 as cv
4 | import math,os,h5py,argparse,sys
5 | from model import *
6 |
7 | def main(args):
8 | dae_gan = DAE_GAN(args.batch, args.epoch, args.im_size, args.hw_size, args.k, args.alpa, args.beta, args.im_dir, args.save_dir, args.saveS_dir, args.saveT_dir)
9 | if args.is_train=='train':
10 | dae_gan.train()
11 | else:
12 | dae_gan.load(args.load_dir, args.raw_im_dir, args.save_im_dir)
13 | return True
14 |
15 | def parse_arguments(argv):
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--is_train', type=str, help="'train' to train; any other value loads a checkpoint and runs inference.", default="train")
18 | parser.add_argument('--load_dir', type=str, help='Load model checkpoint.', default="D:/SVHN_dataset/train/ckpt/")
19 | parser.add_argument('--raw_im_dir', type=str, help='Image to test.', default="D:/SVHN_dataset/test")
20 | parser.add_argument('--save_im_dir', type=str, help='Save sample images dir.', default="D:/SVHN_dataset/test_save/")
21 |
22 | parser.add_argument('--im_size', type=int, help='Image size (height, width) in pixels.', default=64)
23 | parser.add_argument('--batch', type=int, help='batch size.', default=64)
24 | parser.add_argument('--epoch', type=int, help='Number of training epochs.', default=100)
25 | parser.add_argument('--hw_size', type=int, help='Side length of the square attention region.', default=30)
26 | parser.add_argument('--k', type=float, help='Gain coefficient of the sigmoid when generating the mask.', default=2e2)
27 | parser.add_argument('--alpa', type=float, help='Error weight_1.', default=0.4)
28 | parser.add_argument('--beta', type=float, help='Error weight_2.', default=0.6)
29 | parser.add_argument('--im_dir', type=str, help='Path to the image folder.', default="D:/SVHN_dataset/train/")
30 | parser.add_argument('--save_dir', type=str, help='Model save path.', default="D:/SVHN_dataset/train/ckpt/")
31 | parser.add_argument('--saveS_dir', type=str, help='Save path for generated S samples.', default="D:/SVHN_dataset/train/SampleS/")
32 | parser.add_argument('--saveT_dir', type=str, help='Save path for generated T samples.', default="D:/SVHN_dataset/train/SampleT/")
33 | return parser.parse_args(argv)
34 |
35 | if __name__ == '__main__':
36 | main(parse_arguments(sys.argv[1:]))
37 |
--------------------------------------------------------------------------------
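The parser above doubles as an in-process configuration surface. A hypothetical example (flag values follow the README's training command, and it assumes the modules import cleanly):

    from train import parse_arguments

    args = parse_arguments(['--is_train', 'train', '--batch', '2', '--epoch', '100', '--hw_size', '30'])
    print(args.batch, args.im_size, args.k)  # 2 64 200.0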
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2 as cv
4 | import math,os,h5py,argparse,sys
5 |
6 | def conv(name,x,kers,s,outs,pad):
7 | with tf.variable_scope(name):
8 | ker = int(math.sqrt(kers))
9 | shape = [i.value for i in x.get_shape()]
10 | w = tf.get_variable('w',
11 | [ker,ker,shape[-1],outs],
12 | tf.float32,
13 | tf.contrib.layers.xavier_initializer_conv2d())
14 | b = tf.get_variable('b',[outs],tf.float32,tf.constant_initializer(0.))
15 | padd = "SAME" if pad else "VALID"
16 | x_conv = tf.nn.conv2d(x,w,[1,s,s,1],padd) + b
17 | return x_conv
18 |
19 | def res_block(name,x):
20 | with tf.variable_scope(name):
21 | shape = [i.value for i in x.get_shape()]
22 | conv1 = conv(name+'_conv1',x,3*3,1,shape[-1],True)  # stride 1, keep channel count (conv's args are kers, s, outs)
23 | bn1 = tf.nn.relu(tf.contrib.layers.batch_norm(conv1,scale=True,updates_collections=None))
24 | conv2 = conv(name+'_conv2',bn1,3*3,1,shape[-1],True)
25 | bn2 = tf.contrib.layers.batch_norm(conv2,scale=True,updates_collections=None)
26 | return tf.nn.relu(bn2+x)
27 |
28 | def conv_trans(name,x,sizes,s,outs,ifpad):
29 | with tf.variable_scope(name):
30 | ker = int(math.sqrt(sizes))
31 | shape = [i.value for i in x.get_shape()]
32 | ins = shape[-1]//4
33 | w = tf.get_variable('w',[ker,ker,ins,outs],tf.float32,tf.contrib.layers.xavier_initializer_conv2d())
34 | b = tf.get_variable('b',[outs],tf.float32,tf.constant_initializer(0.))
35 | pad = "SAME" if ifpad else "VALID"
36 | x_conv = tf.nn.conv2d(tf.depth_to_space(x,2),w,[1,s,s,1],pad)+b
37 | return x_conv
38 |
39 | def lrelu(name,x):
40 | with tf.variable_scope(name):
41 | return tf.maximum(0.2*x,x)  # an actual leaky ReLU; the original returned tf.nn.relu despite the name
42 |
43 | def tanh(name,x):
44 | with tf.variable_scope(name):
45 | return tf.nn.tanh(x)
46 |
47 | def BN(name,x):
48 | with tf.variable_scope(name):
49 | return tf.contrib.layers.batch_norm(x,scale=True,updates_collections=None)
50 |
51 | def FC_location(name,x,outs,im_size,hw_size):
52 | with tf.variable_scope(name):
53 | raw_shape = [i.value for i in x.get_shape()]
54 | new_shape = int(raw_shape[1]*raw_shape[2]*raw_shape[3])
55 | x_resh = tf.reshape(x,[-1,new_shape])
56 | w = tf.get_variable('w',[new_shape,outs],tf.float32,tf.contrib.layers.xavier_initializer())
57 | b = tf.get_variable('b',[1,outs],tf.float32,tf.constant_initializer(0.))
58 | fc_ = tf.matmul(x_resh,w)+b
59 | # Constrain the coordinates to [0, im_size - hw_size] with a sigmoid, which also keeps the gradient soft.
60 | xy_constraint = tf.nn.sigmoid(fc_)*(im_size-hw_size)
61 | return xy_constraint  # keep float: rounding and casting to int32 would cut the gradient path to f_location
62 |
63 | def get_constant(im_size):
64 | zero_x = np.zeros([im_size,im_size],np.float32)
65 | zero_y = np.zeros([im_size,im_size],np.float32)
66 | for i in range(im_size):
67 | zero_x[:,i] = i+1
68 | zero_y[i,:] = i+1
69 | return zero_x,zero_y
70 |
71 | def sigmoid_mask(x,k):
72 | # Steep sigmoid: the gain k pushes the soft mask toward a 0/1 box while staying differentiable.
73 | X = tf.cast(x,tf.float32)
74 | return tf.nn.sigmoid(k*X)
75 |
76 | def get_mask(xy,region_w,region_h,im_size,k,B):
77 | # xy: [batch, 2] top-left coordinates of the attention box
78 | # region_w, region_h: width and height of the attention area
79 | # im_size: size of the source image X
80 | # k: gain factor of the sigmoid
81 | with tf.variable_scope('Mask'):
82 | x_left = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,0],1),2),3)
83 | x_right = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,0]+region_w,1),2),3)
84 | y_top = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,1],1),2),3)
85 | y_bottom = tf.expand_dims(tf.expand_dims(tf.expand_dims(xy[:,1]+region_h,1),2),3)
86 | x_value,y_value = get_constant(im_size)
87 | x_constant = np.tile(np.expand_dims(np.expand_dims(x_value,0),3),[B,1,1,1])
88 | y_constant = np.tile(np.expand_dims(np.expand_dims(y_value,0),3),[B,1,1,1])
89 | A = sigmoid_mask(x_constant-x_left,k)
90 | C = sigmoid_mask(x_constant-x_right,k)
91 | D = sigmoid_mask(y_constant-y_top,k)
92 | E = sigmoid_mask(y_constant-y_bottom,k)
93 | return (A-C)*(D-E)
94 |
95 | def FC(name,x,outs):
96 | with tf.variable_scope(name):
97 | shape = [i.value for i in x.get_shape()]
98 | size = int(shape[1]*shape[2]*shape[3])
99 | x_reshape = tf.reshape(x,[-1,size])
100 | w = tf.get_variable('w',[size,outs],tf.float32,tf.contrib.layers.xavier_initializer())
101 | b = tf.get_variable('b',[1,outs],tf.float32,tf.constant_initializer(0.))
102 | return tf.nn.sigmoid(tf.matmul(x_reshape,w)+b)
103 |
104 | def resnet_classifer(name,x): # x: batch,4,4,512
105 | with tf.variable_scope(name):
106 | res1 = res_block('res1',x)
107 | res2 = res_block('res2',res1)
108 | res3 = res_block('res3',res2)
109 | res4 = res_block('res4',res3)
110 | res5 = res_block('res5',res4)
111 | res6 = res_block('res6',res5)
112 | res7 = res_block('res7',res6)
113 | probs = FC('probs',res7,10)
114 | return probs
115 |
116 | def One_ep_Iter(im_dir,batch):
117 | num_ims = min(len(os.listdir(im_dir+'positive')),len(os.listdir(im_dir+'negtive')))
118 | return num_ims//batch
119 |
120 | def get_shape(x):
121 | L = [i.value for i in x.get_shape()]
122 | return L
123 |
124 | def save_im(sample_im,save_dir,cur_ep,cur_batch,batch):
125 | for i in range(batch):
126 | im_S = sample_im[i,:,:,:]
127 | im_s = (im_S[:,:,::-1]+1)*127.5
128 | s_name = 'ep'+str(cur_ep)+'_sample'+str(cur_batch)+'_batch'+str(i)+'.png'
129 | cv.imwrite(save_dir+s_name,im_s)
130 | return True
131 |
132 | def Save_load(ims,num,save_dir):
133 | b = np.shape(ims)[0]
134 | for i in range(b):
135 | im_S = ims[i,:,:,:]
136 | im_s = (im_S[:,:,::-1]+1)*127.5
137 | name = 'num_'+str(num)+'_'+str(i)+'.png'  # include i so images in a batch do not overwrite one another
138 | cv.imwrite(save_dir+name,im_s)
139 | return True
140 |
--------------------------------------------------------------------------------
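Note that conv_trans upsamples with tf.depth_to_space (pixel shuffle) followed by an ordinary convolution, rather than a transposed convolution; that is why it computes ins = shape[-1]//4. A minimal shape check (TF1-style session, matching the rest of the repo):

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.zeros((1, 8, 8, 64), np.float32))
    y = tf.depth_to_space(x, 2)        # trades 4x channels for 2x spatial resolution
    with tf.Session() as sess:
        print(sess.run(tf.shape(y)))   # [ 1 16 16 16]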
/model.py:
--------------------------------------------------------------------------------
1 | from utlis import *
2 | from ops import *
3 | import tensorflow as tf
4 | import numpy as np
5 | import cv2 as cv
6 | import math,os,h5py
7 |
8 | class DAE_GAN:
9 | def __init__(self,batch,epoch,im_size,hw_size,k,alpa,beta,im_dir,save_dir,saveS_dir,saveT_dir):
10 | self.batch = batch
11 | self.epoch = epoch
12 | self.im_size = im_size
13 | self.hw_size = hw_size
14 | self.positive_dir = [im_dir+'positive/'+name for name in os.listdir(im_dir+'positive')]
15 | self.negtive_dir = [im_dir+'negtive/'+name for name in os.listdir(im_dir+'negtive')]
16 | self.svhnMat = h5py.File(im_dir+'digitStruct.mat', mode='r')
17 | self.first = 64
18 | self.k = k
19 | self.alpa = alpa
20 | self.beta = beta
21 | self.one_ep_iter = One_ep_Iter(im_dir,batch)
22 | self.S = tf.placeholder(tf.float32,[None,im_size,im_size,3],'S')
23 | self.T = tf.placeholder(tf.float32,[None,im_size,im_size,3],'T')
24 | self.S_label = tf.placeholder(tf.float32,[None,10],'S_label')
25 | self.T_label = tf.placeholder(tf.float32,[None,10],'T_label')
26 | self.save_dir = save_dir
27 | self.saveS = saveS_dir
28 | self.saveT = saveT_dir
29 |
30 | def Encoder(self,x,reuse):
31 | with tf.variable_scope('encoder',reuse=reuse):
32 | conv1 = conv('conv1',x,3*3,2,self.first,True)
33 | bn1 = BN('bn1',conv1)
34 | relu1 = lrelu('relu1',bn1)
35 |
36 | conv2 = conv('conv2',relu1,3*3,2,int(2*self.first),True)
37 | bn2 = BN('bn2',conv2)
38 | relu2 = lrelu('relu2',bn2)
39 |
40 | conv3 = conv('conv3',relu2,3*3,2,int(4*self.first),True)
41 | bn3 = BN('bn3',conv3)
42 | relu3 = lrelu('relu3',bn3)
43 |
44 | conv4 = conv('conv4',relu3,3*3,2,int(8*self.first),True)
45 | bn4 = BN('bn4',conv4)
46 | relu4 = lrelu('relu4',bn4)
47 | return relu4
48 |
49 | def F(self,features,k):
50 | with tf.variable_scope('f_location'):
51 | cha = [i.value for i in features.get_shape()][-1]
52 |
53 | conv1 = conv('conv1',features,3*3,2,int(2*cha),True)
54 | bn1 = BN('bn1',conv1)
55 | relu1 = lrelu('relu1',bn1)
56 |
57 | conv2 = conv('conv2',relu1,3*3,2,int(4*cha),True)
58 | bn2 = BN('bn2',conv2)
59 | relu2 = lrelu('relu2',bn2)
60 |
61 | out_xy = FC_location('fc1',relu2,2,self.im_size,self.hw_size)
62 | mask = get_mask(out_xy,self.hw_size,self.hw_size,self.im_size,k,self.batch)
63 | return mask
64 |
65 | def GAN_G(self,x,reuse):
66 | with tf.variable_scope('G_net',reuse=reuse):
67 |
68 | conv_trans1 = conv_trans('conv_trans1',x,5*5,1,256,True)
69 | Lrelu1 = lrelu('Lrelu1',conv_trans1)
70 |
71 | conv_trans2 = conv_trans('conv_trans2',Lrelu1,5*5,1,128,True)
72 | bn2 = BN('bn2',conv_trans2)
73 | Lrelu2 = lrelu('Lrelu2',bn2)
74 |
75 | conv_trans3 = conv_trans('conv_trans3',Lrelu2,5*5,1,64,True)
76 | bn3 = BN('bn3',conv_trans3)
77 | Lrelu3 = lrelu('Lrelu3',bn3)
78 |
79 | conv_trans4 = conv_trans('conv_trans4',Lrelu3,5*5,1,3,True)
80 | Lrelu4 = tanh('Lrelu4',conv_trans4)
81 | return Lrelu4
82 |
83 | def GAN_D(self,name,x,reuse):
84 | with tf.variable_scope(name,reuse=reuse):
85 | conv1 = conv('conv1',x,3*3,2,64,True)
86 | bn1 = BN('bn1',conv1)
87 | relu1 = lrelu('relu1',bn1)
88 |
89 | conv2 = conv('conv2',relu1,3*3,2,128,True)
90 | bn2 = BN('bn2',conv2)
91 | relu2 = lrelu('relu2',bn2)
92 |
93 | conv3 = conv('conv3',relu2,3*3,2,256,True)
94 | bn3 = BN('bn3',conv3)
95 | relu3 = lrelu('relu3',bn3)
96 |
97 | conv4 = conv('conv4',relu3,3*3,2,512,True)
98 | bn4 = BN('bn4',conv4)
99 | relu4 = lrelu('relu4',bn4)
100 |
101 | D_out = FC('fc1',relu4,1)
102 | return D_out
103 |
104 | def DAE(self,x,reuse):
105 | with tf.variable_scope('DAE',reuse=reuse):
106 | encode_x = self.Encoder(x,reuse)
107 | mask = self.F(encode_x,self.k)
108 | x_mask = x*mask
109 | encode_mask_x = self.Encoder(x_mask,True)
110 |
111 | probs = resnet_classifer('classsifer',encode_mask_x)
112 | return encode_mask_x,probs
113 |
114 | def forward(self):
115 | self.s_DAE,self.s_out_label = self.DAE(self.S,False)
116 | self.t_DAE,self.t_out_label = self.DAE(self.T,True)
117 | self.s_pie = self.GAN_G(self.s_DAE,False)
118 | self.t_pie = self.GAN_G(self.t_DAE,True)
119 |
120 | self.t_D1 = self.GAN_D('D2',self.T,False)        # D2 applied to real T
121 | self.t_pie_D = self.GAN_D('D2',self.t_pie,True)  # D2 applied to generated t_pie
122 |
123 | self.t_D2 = self.GAN_D('D1',self.T,False)        # D1 applied to real T (note the swapped naming)
124 | self.s_pie_D = self.GAN_D('D1',self.s_pie,True)  # D1 applied to generated s_pie
125 |
126 | self.s_pie_DAE,_ = self.DAE(self.s_pie,True)
127 | self.t_pie_DAE,_ = self.DAE(self.t_pie,True)
128 |
129 | def train(self):
130 | self.forward()
131 | # Loss for G and DAE:
132 | # 1. Loss between s_DAE and s_pie_DAE:
133 | Lcst = tf.reduce_mean(tf.abs(self.s_DAE-self.s_pie_DAE)) # L1 norm
134 | # 2. Loss between t_DAE and t_pie_DAE:
135 | Lsym = tf.reduce_mean(tf.abs(self.t_DAE-self.t_pie_DAE)) # L1 norm
136 | # 3. Make D2 judge t_pie as true:
137 | loss_G_DAE1 = tf.reduce_mean(-1*tf.log(self.t_pie_D))
138 | # 4. Make D1 judge s_pie as true:
139 | loss_G_DAE2 = tf.reduce_mean(-1*tf.log(self.s_pie_D))
140 | # 5. Classification loss: per-class binary cross entropy on the source labels; this term trains only the DAE:
141 | cross_entropy = tf.reduce_mean(tf.reduce_mean((-1)*tf.log(self.s_out_label)*self.S_label + (-1)*tf.log(1-self.s_out_label)*(1-self.S_label),1))
142 |
143 | self.DAE_G_loss = self.alpa*Lcst + self.beta*Lsym + loss_G_DAE1 + loss_G_DAE2 + cross_entropy
144 |
145 | self.D1_loss = tf.reduce_mean((-1)*tf.log(self.t_D2) + (-1)*tf.log(1-self.s_pie_D))  # t_D2 = D1(T)
146 | self.D2_loss = tf.reduce_mean((-1)*tf.log(self.t_D1) + (-1)*tf.log(1-self.t_pie_D))  # t_D1 = D2(T); the original mistakenly reused t_D2 here
147 | DAE_G_vars = [var for var in tf.trainable_variables() if 'DAE' in var.name or 'G_net' in var.name]
148 | D1_vars = [var for var in tf.trainable_variables() if 'D1' in var.name]
149 | D2_vars = [var for var in tf.trainable_variables() if 'D2' in var.name]
150 | optim_DAE_G = tf.train.AdamOptimizer().minimize(self.DAE_G_loss,var_list=DAE_G_vars)
151 | optim_D1 = tf.train.AdamOptimizer().minimize(self.D1_loss,var_list=D1_vars)
152 | optim_D2 = tf.train.AdamOptimizer().minimize(self.D2_loss,var_list=D2_vars)
153 | with tf.Session() as sess:
154 | sess.run(tf.global_variables_initializer())
155 | graph = tf.summary.FileWriter(self.save_dir,graph=sess.graph)
156 | Saver = tf.train.Saver(max_to_keep=20)
157 |
158 | savedir = self.save_dir+'model.ckpt'
159 | for i in range(self.epoch):
160 | for j in range(self.one_ep_iter):
161 | ims_po,labels_po = get_im_label(j, self.batch, self.positive_dir, self.svhnMat)
162 | ims_neg,labels_neg = get_im_label(j, self.batch, self.negtive_dir, self.svhnMat)
163 |
164 | fed_dict={self.S:ims_po,self.S_label:labels_po,self.T:ims_neg,self.T_label:labels_neg}
165 |
166 | _,LossD1 = sess.run([optim_D1,self.D1_loss],feed_dict=fed_dict)
167 | print('LossD1: ',LossD1)
168 |
169 | _,LossD2 = sess.run([optim_D2,self.D2_loss],feed_dict=fed_dict)
170 | print('LossD2: ',LossD2)
171 |
172 | _,S_sample,T_sample,LossDAE_G = sess.run([optim_DAE_G,self.s_pie,self.t_pie,self.DAE_G_loss],feed_dict=fed_dict)
173 | print('LossDAE_G: ',LossDAE_G)
174 |
175 | save_im(S_sample, self.saveS, i, j, self.batch)
176 | save_im(T_sample, self.saveT, i, j, self.batch)
177 |
178 | step = int(i*self.one_ep_iter + j)  # global step: completed epochs times iterations per epoch, plus j
179 | Saver.save(sess,savedir,global_step=step)
180 | print('save_success at: ',step)
181 |
182 | def load(self,load_dir,raw_im_dir,save_im_dir):
183 | self.forward()
184 | with tf.Session() as sess:
185 | sess.run(tf.global_variables_initializer())
186 | graph = tf.summary.FileWriter(self.save_dir,graph=sess.graph)
187 | Saver = tf.train.Saver()
188 | Saver.restore(sess,tf.train.latest_checkpoint(load_dir))  # load_dir is a directory; restore its newest checkpoint
189 | for i,im_name in enumerate(os.listdir(raw_im_dir)):
190 | im_bgr = cv.resize(cv.imread(os.path.join(raw_im_dir,im_name)),(64,64),interpolation=cv.INTER_CUBIC)  # join handles a missing trailing slash
191 | im_rgb_unpro = im_bgr[:,:,::-1]
192 | im_rgb = np.expand_dims(((im_rgb_unpro/255)-0.5)*2,0)
193 | # fed_dict={self.S:ims_po,self.S_label:labels_po}
194 | fed_dict={self.S:im_rgb}
195 | s_tar = sess.run(self.s_pie,feed_dict=fed_dict)
196 | Save_load(s_tar,i,save_im_dir)
197 | print('save at: ',i)
198 |
--------------------------------------------------------------------------------
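The combined generator/DAE objective in train() is alpa*Lcst + beta*Lsym + the two adversarial -log terms + the classification cross entropy, with every log applied to a raw sigmoid output. Those logs produce NaN as soon as a discriminator saturates at 0 or 1; a common guard (a sketch, not part of the original code) is to clip before the log:

    import tensorflow as tf

    def safe_log(p, eps=1e-8):
        # Clip probabilities away from 0 and 1 so the -log terms stay finite.
        return tf.log(tf.clip_by_value(p, eps, 1.0 - eps))

    # e.g. in train(): loss_G_DAE1 = tf.reduce_mean(-1 * safe_log(self.t_pie_D))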