├── README.md
├── backward.py
├── mnist.py
├── network.py
└── test_complex.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CNN-based-on-Complex-Number
A convolutional neural network based on complex numbers.

Partial reimplementation of the paper "Deep Complex Networks", available at https://arxiv.org/abs/1705.09792

Architecture: three blocks of (3×3 complex convolution + complex pooling + complex activation), followed by two blocks of (complex fully-connected layer + complex dropout).
The real-valued and complex-valued networks are compared on the MNIST dataset.
--------------------------------------------------------------------------------
/backward.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import network
import mnist
import matplotlib.pyplot as plt
import random
import numpy as np


BATCH_SIZE = 32
EPOCH = 500
LEARNING_RATE = 0.005
KEEP_PROB = 1
REGULARIZER = 0.005


def backward(data):
    x = tf.placeholder(tf.float32, (BATCH_SIZE, network.IMAGE_SIZE, network.IMAGE_SIZE, network.IMAGE_CHANNEL))
    y = network.forward(x, True, KEEP_PROB, REGULARIZER)
    y_ = tf.placeholder(tf.float32, (None, network.OUTPUT_NODE))

    loss1 = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
    loss2 = tf.reduce_mean(loss1)
    loss3 = loss2 + tf.add_n(tf.get_collection('losses'))

    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))

    optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss3)

    saver = tf.train.Saver(max_to_keep=1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        accu1 = 0
        accu2 = 0

        # Quantities recorded for plotting
        x_p = [i for i in range(1, EPOCH + 1)]
        y_loss = [i for i in range(1, EPOCH + 1)]
        y_accu = [i for i in range(1, EPOCH + 1)]

        for i in range(1, EPOCH + 1):
            xf, yf = data.next_batch(BATCH_SIZE)

            xf = mnist.mnist_fft(xf, BATCH_SIZE)

            _, accu, los = sess.run([optimizer, accuracy, loss3], feed_dict={x: xf, y_: yf})

            y_loss[i - 1] = los
            y_accu[i - 1] = accu

            # Save a checkpoint once the batch accuracy has exceeded 0.75 for three consecutive steps
            if accu1 > 0.75 and accu2 > 0.75 and accu > 0.75:
                saver.save(sess, 'mnist_Complex/mnist_Complex.ckpt', write_meta_graph=False)

            accu1 = accu2
            accu2 = accu

            print('Epoch: ', i)
            print('loss on batch: ', los)
            print('accuracy on batch: ', accu)
            print('.......................................')

        print(y_loss)
        print(y_accu)
        plt.figure()
        plt.plot(x_p[0:len(x_p):4], y_loss[0:len(y_loss):4])
        plt.figure()
        plt.plot(x_p[0:len(x_p):4], y_accu[0:len(y_accu):4])
        plt.show()


backward(mnist.train)
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=True, reshape=True)
# Note: the images come back flattened (28*28=784), so remember to reshape them before training.

train = mnist.train
test = mnist.test


def mnist_fft(xf, batch):
    # xf: [batch, 784]
    # returns: y [batch, 28, 28, 2]
    xf = np.reshape(xf, (batch, 28, 28))
    x = np.ones_like(xf, dtype=np.complex128)
    for i in range(0, batch):
        x[i] = np.fft.fftshift(np.fft.fft2(xf[i]))
    # x = np.reshape(x, (batch, 28, 28, 1))
    # y = np.concatenate([np.real(x), np.imag(x)], 3)
    y = np.stack([np.real(x), np.imag(x)], 3)
    # Do not take the absolute value here -- it would destroy the phase information.
    # y = np.log(np.abs(y) + 1)
    return y
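
# A minimal sanity check (an added, hypothetical usage sketch -- not part of the original
# training pipeline). It assumes the MNIST download above succeeded and just verifies the
# output layout [batch, 28, 28, 2], where channel 0 is the real part and channel 1 the
# imaginary part of the centred 2-D spectrum.
if __name__ == '__main__':
    xs, _ = train.next_batch(4)
    spec = mnist_fft(xs, 4)
    print(spec.shape)  # expected: (4, 28, 28, 2)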
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

# Representation of a complex tensor: a tensor with n complex channels is stored as 2n real
# channels -- the first n channels hold the real part and the last n hold the imaginary part.
IMAGE_SIZE = 28
IMAGE_CHANNEL = 2   # real part + imaginary part
FILTER1_SIZE = 3
FILTER1_NUM = 64
FILTER2_SIZE = 3
FILTER2_NUM = 64
FILTER3_SIZE = 3
FILTER3_NUM = 64
FC1_SIZE = 128      # i.e. 128 complex units
FC2_SIZE = 128
OUTPUT_NODE = 10


def get_weight(shape, regularizer):
    # shape: same layout as a convolution kernel [rows, cols, channels, num_filters]
    w = tf.Variable(tf.truncated_normal(shape, stddev=1))
    if regularizer is not None:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w


def get_bias(shape):
    # shape: number of filters [FILTER_NUM]
    return tf.Variable(tf.zeros(shape))


def complex_conv(input, filter, strides):
    # input:  [batch, rows, cols, channels] -- the first half of the channels is the real part,
    #         the second half is the imaginary part
    # filter: [rows, cols, channels, num_filters]
    # strides: kernel stride [1, row_stride, col_stride, 1]

    # input  = A + Bi
    # filter = x + yi
    [A, B] = tf.split(input, 2, 3)   # split along axis 3 (channels, counting from 0)
    [x, y] = tf.split(filter, 2, 2)  # split along axis 2 (input channels)
    output1 = tf.nn.conv2d(A, x, strides, 'SAME') - tf.nn.conv2d(B, y, strides, 'SAME')  # real part
    output2 = tf.nn.conv2d(A, y, strides, 'SAME') + tf.nn.conv2d(B, x, strides, 'SAME')  # imaginary part
    output = tf.concat([output1, output2], 3)  # concatenate along the channel axis

    return output
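
# Illustration (an added sketch, not called by the network itself): complex_conv realizes the
# complex product (A + Bi) * (x + yi) = (Ax - By) + (Ay + Bx)i with four real-valued
# convolutions. A hypothetical call, with made-up tensor names, would look like:
#     inp = tf.concat([A, B], 3)                    # A, B: [batch, 28, 28, 1] real tensors
#     ker = get_weight([3, 3, 2, 64], None)         # 3x3 kernels, 1 complex in-channel, 64 complex filters
#     out = complex_conv(inp, ker, [1, 1, 1, 1])    # -> [batch, 28, 28, 128]: 64 real + 64 imaginary channels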
def modRelu(input):
    # input: [batch, rows, cols, channels]
    # Implements modReLU(z) = ReLU(|z| - b) * z / |z| with a learnable bias b.
    A, B = tf.split(input, 2, 3)  # split along axis 3 (channels)
    C = tf.abs(tf.complex(A, B))

    # learnable bias b
    b = 50 * np.ones([C.get_shape()[-1]])
    b = tf.Variable(b, dtype=tf.float32)
    C1 = C - b
    # attenuation factor fac
    fac = C1 / (C + 0.0001)

    relu = tf.nn.relu(C1)
    flag = tf.cast(tf.cast(relu, tf.bool), tf.float32)
    real = A * flag * fac
    imag = B * flag * fac
    res = tf.concat([real, imag], 3)
    return res


def zRelu(input):
    # input: [batch, rows, cols, channels]
    # Keeps a complex value only if both its real and imaginary parts are positive.
    relu = tf.nn.relu(input)
    [real, imag] = tf.split(relu, 2, 3)
    flag_real = tf.cast(tf.cast(real, tf.bool), tf.float32)
    flag_imag = tf.cast(tf.cast(imag, tf.bool), tf.float32)
    flag = flag_imag * flag_real
    real = real * flag
    imag = imag * flag
    return tf.concat([real, imag], 3)


def cRelu(input):
    # Equivalent to applying ReLU to the real and imaginary parts independently.
    # input: [batch, rows, cols, channels]
    return tf.nn.relu(input)


def complex_avg_pool(input, ksize, strides):
    # Average pooling, applied to the real and imaginary channels independently.
    # input:   [batch, rows, cols, channels]
    # ksize:   pooling window [1, rows, cols, 1]
    # strides: stride [1, rows, cols, 1]
    return tf.nn.avg_pool(input, ksize, strides, 'SAME')


def complex_max_pool(input, ksize, strides):
    # Max pooling by complex magnitude: the position of the largest |z| in each window is
    # found, and the corresponding real and imaginary values are gathered.
    # input:   [batch, rows, cols, channels]
    # ksize:   pooling window [1, rows, cols, 1]
    # strides: stride [1, rows, cols, 1]
    # input = A + Bi
    A, B = tf.split(input, 2, 3)  # split along axis 3 (channels)
    flatten_A = tf.reshape(A, shape=[-1])
    flatten_B = tf.reshape(B, [-1])

    C = tf.abs(tf.complex(A, B))
    _, mask = tf.nn.max_pool_with_argmax(C, ksize, strides, padding='SAME')
    output_shape = mask.get_shape()

    flatten_mask = tf.reshape(mask, shape=[-1])
    flatten_real = tf.gather(flatten_A, flatten_mask)
    flatten_imag = tf.gather(flatten_B, flatten_mask)

    real = tf.reshape(flatten_real, output_shape)
    imag = tf.reshape(flatten_imag, output_shape)

    return tf.concat([real, imag], 3)


def fully_connect(fc, cur_size, regularizer):
    # Inputs: a (real part + imaginary part) tensor, the output size, and the regularizer.
    # Both input and output are 2-D: [batch, units], first half real, second half imaginary.

    pre_size = fc.get_shape().as_list()[1] // 2
    wr = get_weight([pre_size, cur_size], regularizer)
    wi = get_weight([pre_size, cur_size], regularizer)
    br = get_bias([cur_size])
    bi = get_bias([cur_size])

    [real, imag] = tf.split(fc, 2, 1)  # split along axis 1 into two halves
    R = tf.matmul(real, wr) - tf.matmul(imag, wi) + br  # real part
    I = tf.matmul(real, wi) + tf.matmul(imag, wr) + bi  # imaginary part
    return tf.concat([R, I], 1)


def fully_zRelu(fc):
    # zReLU for fully-connected layers
    fc = tf.nn.relu(fc)
    real, imag = tf.split(fc, 2, 1)
    real_flag = tf.cast(tf.cast(real, tf.bool), tf.float32)
    imag_flag = tf.cast(tf.cast(imag, tf.bool), tf.float32)
    flag = real_flag * imag_flag
    return tf.concat([real * flag, imag * flag], 1)


def fully_modRelu(fc):
    # modReLU for fully-connected layers
    # fc: [batch, units]
    # the first half of the units is the real part, the second half the imaginary part
    # returns a tensor of the same shape

    A, B = tf.split(fc, 2, 1)  # split along axis 1: A real part, B imaginary part
    C = tf.complex(A, B)
    C = tf.abs(C)

    # learnable bias b
    b = 50 * np.ones([C.get_shape()[-1]])
    b = tf.Variable(b, dtype=tf.float32)

    C1 = C - b
    # attenuation factor
    fac = C1 / (C + 0.0001)

    flag = tf.nn.relu(C1)
    flag = tf.cast(tf.cast(flag, tf.bool), tf.float32)

    real = A * flag * fac
    imag = B * flag * fac
    return tf.concat([real, imag], 1)


def fully_cRelu(fc):
    # Equivalent to applying ReLU directly, without distinguishing real and imaginary parts.
    return tf.nn.relu(fc)


def dropout(x, keep_prob):
    # x: [batch, size]
    # Keeps each complex unit (a real/imaginary pair) with probability keep_prob.
    # Note: the numpy mask is sampled once when the graph is built, so the same mask is
    # reused at every training step.
    shape = x.get_shape().as_list()
    flag = np.ones(shape=(shape[0], shape[1]))
    for i in range(0, shape[0]):
        for j in range(0, int(shape[1] / 2)):
            if np.random.rand() > keep_prob:
                flag[i][j] = 0
                flag[i][j + int(shape[1] / 2)] = 0
    rate = np.mean(flag, 1)
    res = np.reshape(rate, (shape[0], 1))
    rate = res
    for i in range(shape[1] - 1):
        res = np.concatenate([res, rate], 1)
    w = tf.convert_to_tensor(flag, dtype=tf.float32)
    res = tf.convert_to_tensor(res, dtype=tf.float32)
    return w * x / res
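
# Illustration (an added sketch with made-up numbers, not called by the network): the mask
# drops the real and imaginary parts of a unit together and rescales by the fraction of units
# actually kept in that row. For a [1, 4] input holding two complex units (real parts first,
# then imaginary parts):
#     x    = [[1., 2., 3., 4.]]   # the units 1+3i and 2+4i
#     flag = [[1., 0., 1., 0.]]   # the second unit was dropped as a pair
#     out  = flag * x / 0.5       # -> [[2., 0., 6., 0.]]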
def forward(x, train, keep_prob, regularizer):

    # x: the first half of the channels is the real part, the second half the imaginary part
    filter1 = get_weight([FILTER1_SIZE, FILTER1_SIZE, IMAGE_CHANNEL, FILTER1_NUM], regularizer)
    filter1_bias = get_bias([FILTER1_NUM * 2])
    conv1 = complex_conv(x, filter1, [1, 1, 1, 1])
    relu1 = modRelu(tf.nn.bias_add(conv1, filter1_bias))

    pool1 = complex_max_pool(relu1, [1, 2, 2, 1], [1, 2, 2, 1])

    filter2 = get_weight([FILTER2_SIZE, FILTER2_SIZE, FILTER1_NUM * 2, FILTER2_NUM], regularizer)
    filter2_bias = get_bias([FILTER2_NUM * 2])
    conv2 = complex_conv(pool1, filter2, [1, 1, 1, 1])
    relu2 = modRelu(tf.nn.bias_add(conv2, filter2_bias))

    pool2 = complex_max_pool(relu2, [1, 2, 2, 1], [1, 2, 2, 1])

    filter3 = get_weight([FILTER3_SIZE, FILTER3_SIZE, FILTER2_NUM * 2, FILTER3_NUM], regularizer)
    filter3_bias = get_bias([FILTER3_NUM * 2])
    conv3 = complex_conv(pool2, filter3, [1, 1, 1, 1])
    relu3 = modRelu(tf.nn.bias_add(conv3, filter3_bias))

    pool3 = complex_max_pool(relu3, [1, 2, 2, 1], [1, 2, 2, 1])

    pool3_real, pool3_imag = tf.split(pool3, 2, 3)  # split into real and imaginary parts
    fc0_ri_shape = pool3_real.get_shape().as_list()
    fc0_size = fc0_ri_shape[1] * fc0_ri_shape[2] * fc0_ri_shape[3]
    fc0_real = tf.reshape(pool3_real, (-1, fc0_size))
    fc0_imag = tf.reshape(pool3_imag, (-1, fc0_size))
    fc0 = tf.concat([fc0_real, fc0_imag], 1)

    fc1 = fully_connect(fc0, FC1_SIZE, regularizer)
    fc1 = fully_modRelu(fc1)
    if train:
        fc1 = dropout(fc1, keep_prob)

    fc2 = fully_connect(fc1, FC2_SIZE, regularizer)
    fc2 = fully_modRelu(fc2)
    if train:
        fc2 = dropout(fc2, keep_prob)

    output = fully_connect(fc2, OUTPUT_NODE, regularizer)
    output_real, output_imag = tf.split(output, 2, 1)
    return tf.abs(tf.complex(output_real, output_imag))
    # return output_real


# x = tf.Variable(np.random.random((32, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNEL)), dtype=tf.float32)
# forward(x, True, 0.5, 0.001)
--------------------------------------------------------------------------------
/test_complex.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import network
import mnist
import numpy as np

BATCH_SIZE = 32
EPOCH = 50
LEARNING_RATE = 0.005
KEEP_PROB = 1
REGULARIZER = None


def backward(data):
    x = tf.placeholder(tf.float32, (BATCH_SIZE, network.IMAGE_SIZE, network.IMAGE_SIZE, network.IMAGE_CHANNEL))
    y = network.forward(x, False, KEEP_PROB, REGULARIZER)
    y_ = tf.placeholder(tf.float32, (None, network.OUTPUT_NODE))

    loss1 = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
    loss2 = tf.reduce_mean(loss1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))

    saver = tf.train.Saver(max_to_keep=1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        model = tf.train.latest_checkpoint('mnist_Complex/')
        saver.restore(sess, model)

        accu_sum = 0
        loss_sum = 0

        for i in range(0, EPOCH):
            xf, yf = data.next_batch(BATCH_SIZE)
            xf = mnist.mnist_fft(xf, BATCH_SIZE)
            accu, los = sess.run([accuracy, loss2], feed_dict={x: xf, y_: yf})
            accu_sum += accu
            loss_sum += los

        accu_mean = accu_sum / EPOCH
        loss_mean = loss_sum / EPOCH
        print('loss on test: ', loss_mean)
        print('accuracy on test: ', accu_mean)


backward(mnist.test)
--------------------------------------------------------------------------------