├── README.md ├── infomax_cifar10.py └── infomax_tiny_imagenet.py /README.md: -------------------------------------------------------------------------------- 1 | # infomax 2 | extract features by maximizing mutual information 3 | 4 | 5 | a fine-tuned version of (https://arxiv.org/abs/1808.06670). 6 | 7 | see more: https://kexue.fm/archives/6024 8 | 9 | 10 | ![knn1](https://kexue.fm/usr/uploads/2018/10/1623425049.png) ![knn2](https://kexue.fm/usr/uploads/2018/10/1899771582.png) 11 | 12 | ## 交流 13 | QQ交流群:67729435,微信群请加机器人微信号spaces_ac_cn 14 | -------------------------------------------------------------------------------- /infomax_cifar10.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import glob 5 | import imageio 6 | from keras.models import Model 7 | from keras.layers import * 8 | from keras import backend as K 9 | from keras.optimizers import Adam 10 | from keras.datasets import cifar10 11 | import tensorflow as tf 12 | 13 | 14 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 15 | x_train = x_train.astype('float32') / 255 - 0.5 16 | x_test = x_test.astype('float32') / 255 - 0.5 17 | y_train = y_train.reshape(-1) 18 | y_test = y_test.reshape(-1) 19 | img_dim = x_train.shape[1] 20 | 21 | 22 | z_dim = 256 # 隐变量维度 23 | alpha = 0.5 # 全局互信息的loss比重 24 | beta = 1.5 # 局部互信息的loss比重 25 | gamma = 0.01 # 先验分布的loss比重 26 | 27 | 28 | # 编码器(卷积与最大池化) 29 | x_in = Input(shape=(img_dim, img_dim, 3)) 30 | x = x_in 31 | 32 | for i in range(3): 33 | x = Conv2D(z_dim // 2**(2-i), 34 | kernel_size=(3,3), 35 | padding='SAME')(x) 36 | x = BatchNormalization()(x) 37 | x = LeakyReLU(0.2)(x) 38 | x = MaxPooling2D((2, 2))(x) 39 | 40 | feature_map = x # 截断到这里,认为到这里是feature_map(局部特征) 41 | feature_map_encoder = Model(x_in, x) 42 | 43 | 44 | for i in range(2): 45 | x = Conv2D(z_dim, 46 | kernel_size=(3,3), 47 | padding='SAME')(x) 48 | x = BatchNormalization()(x) 49 | x = LeakyReLU(0.2)(x) 

x = GlobalMaxPooling2D()(x)  # global feature vector

z_mean = Dense(z_dim)(x)     # mean -- this is the final output encoding
z_log_var = Dense(z_dim)(x)  # log-variance, mimicking a VAE


encoder = Model(x_in, z_mean)  # the full encoder outputs z_mean


def sampling(args):
    """Reparameterization trick: draw z ~ N(mean, exp(log_var))."""
    mean, log_var = args
    noise = K.random_normal(shape=K.shape(mean))
    return mean + K.exp(log_var / 2) * noise


# Reparameterization layer: effectively injects noise into the code.
z_samples = Lambda(sampling)([z_mean, z_log_var])
prior_kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var))


def shuffling(x):
    """Randomly permute a tensor along its first (batch) axis."""
    perm = tf.random_shuffle(K.arange(0, K.shape(x)[0]))
    return K.gather(x, perm)


# Pair each code with itself (positive) and a batch-shuffled code (negative).
z_shuffle = Lambda(shuffling)(z_samples)
z_z_1 = Concatenate()([z_samples, z_samples])
z_z_2 = Concatenate()([z_samples, z_shuffle])

# Same pairing against the 4x4 local feature map (code tiled spatially).
feature_map_shuffle = Lambda(shuffling)(feature_map)
z_samples_repeat = RepeatVector(4 * 4)(z_samples)
z_samples_map = Reshape((4, 4, z_dim))(z_samples_repeat)
z_f_1 = Concatenate()([z_samples_map, feature_map])
z_f_2 = Concatenate()([z_samples_map, feature_map_shuffle])


# Global discriminator: a small MLP on concatenated code pairs.
z_in = Input(shape=(z_dim * 2,))
z = z_in
for _ in range(3):
    z = Dense(z_dim, activation='relu')(z)
z = Dense(1, activation='sigmoid')(z)

GlobalDiscriminator = Model(z_in, z)

z_z_1_scores = GlobalDiscriminator(z_z_1)
z_z_2_scores = GlobalDiscriminator(z_z_2)
global_info_loss = - K.mean(K.log(z_z_1_scores + 1e-6) + K.log(1 - z_z_2_scores + 1e-6))


# Local discriminator: the same MLP shape, applied at every spatial position.
z_in = Input(shape=(None, None, z_dim * 2))
z = z_in
for _ in range(3):
    z = Dense(z_dim, activation='relu')(z)
z = Dense(1, activation='sigmoid')(z)

LocalDiscriminator = Model(z_in, z) 116 | 117 | z_f_1_scores = LocalDiscriminator(z_f_1) 118 | z_f_2_scores = LocalDiscriminator(z_f_2) 119 | local_info_loss = - K.mean(K.log(z_f_1_scores + 1e-6) + K.log(1 - z_f_2_scores + 1e-6)) 120 | 121 | # 用来训练的模型 122 | model_train = Model(x_in, [z_z_1_scores, z_z_2_scores, z_f_1_scores, z_f_2_scores]) 123 | model_train.add_loss(alpha * global_info_loss + beta * local_info_loss + gamma * prior_kl_loss) 124 | model_train.compile(optimizer=Adam(1e-3)) 125 | 126 | model_train.fit(x_train, epochs=50, batch_size=64) 127 | model_train.save_weights('total_model.cifar10.weights') 128 | 129 | 130 | # 输出编码器的特征 131 | zs = encoder.predict(x_train, verbose=True) 132 | zs.mean() # 查看均值(简单观察先验分布有没有达到效果) 133 | zs.std() # 查看方差(简单观察先验分布有没有达到效果) 134 | 135 | 136 | # 随机选一张图片,输出最相近的图片 137 | # 可以选用欧氏距离或者cos值 138 | def sample_knn(path): 139 | n = 10 140 | topn = 10 141 | figure1 = np.zeros((img_dim*n, img_dim*topn, 3)) 142 | figure2 = np.zeros((img_dim*n, img_dim*topn, 3)) 143 | zs_ = zs / (zs**2).sum(1, keepdims=True)**0.5 144 | for i in range(n): 145 | one = np.random.choice(len(x_train)) 146 | idxs = ((zs**2).sum(1) + (zs[one]**2).sum() - 2 * np.dot(zs, zs[one])).argsort()[:topn] 147 | for j,k in enumerate(idxs): 148 | digit = x_train[k] 149 | figure1[i*img_dim: (i+1)*img_dim, 150 | j*img_dim: (j+1)*img_dim] = digit 151 | idxs = np.dot(zs_, zs_[one]).argsort()[-n:][::-1] 152 | for j,k in enumerate(idxs): 153 | digit = x_train[k] 154 | figure2[i*img_dim: (i+1)*img_dim, 155 | j*img_dim: (j+1)*img_dim] = digit 156 | figure1 = (figure1 + 1) / 2 * 255 157 | figure1 = np.clip(figure1, 0, 255) 158 | figure2 = (figure2 + 1) / 2 * 255 159 | figure2 = np.clip(figure2, 0, 255) 160 | imageio.imwrite(path+'_l2.png', figure1) 161 | imageio.imwrite(path+'_cos.png', figure2) 162 | 163 | 164 | sample_knn('test') 165 | -------------------------------------------------------------------------------- /infomax_tiny_imagenet.py: 
-------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import glob 5 | import imageio 6 | from scipy import misc 7 | from keras.models import Model 8 | from keras.layers import * 9 | from keras import backend as K 10 | from keras.optimizers import Adam 11 | from tqdm import tqdm 12 | import tensorflow as tf 13 | 14 | 15 | imgs = glob.glob('tiny-imagenet-200/train/*/images/*') 16 | np.random.shuffle(imgs) 17 | 18 | 19 | def imread(f): 20 | x = misc.imread(f, mode='RGB') 21 | return x.astype(np.float32) / 255 * 2 - 1 22 | 23 | 24 | x_train = np.array([imread(f) for f in tqdm(iter(imgs))]) 25 | img_dim = x_train.shape[1] 26 | 27 | 28 | z_dim = 256 # 隐变量维度 29 | alpha = 0.5 # 全局互信息的loss比重 30 | beta = 1.5 # 局部互信息的loss比重 31 | gamma = 0.01 # 先验分布的loss比重 32 | 33 | 34 | # 编码器(卷积与最大池化) 35 | x_in = Input(shape=(img_dim, img_dim, 3)) 36 | x = x_in 37 | 38 | for i in range(4): 39 | x = Conv2D(z_dim // 2**(3-i), 40 | kernel_size=(3,3), 41 | padding='SAME')(x) 42 | x = BatchNormalization()(x) 43 | x = LeakyReLU(0.2)(x) 44 | x = MaxPooling2D((2, 2))(x) 45 | 46 | feature_map = x # 截断到这里,认为到这里是feature_map(局部特征) 47 | feature_map_encoder = Model(x_in, x) 48 | 49 | 50 | for i in range(2): 51 | x = Conv2D(z_dim, 52 | kernel_size=(3,3), 53 | padding='SAME')(x) 54 | x = BatchNormalization()(x) 55 | x = LeakyReLU(0.2)(x) 56 | 57 | x = GlobalMaxPooling2D()(x) # 全局特征 58 | 59 | z_mean = Dense(z_dim)(x) # 均值,也就是最终输出的编码 60 | z_log_var = Dense(z_dim)(x) # 方差,这里都是模仿VAE的 61 | 62 | 63 | encoder = Model(x_in, z_mean) # 总的编码器就是输出z_mean 64 | 65 | 66 | # 重参数技巧 67 | def sampling(args): 68 | z_mean, z_log_var = args 69 | u = K.random_normal(shape=K.shape(z_mean)) 70 | return z_mean + K.exp(z_log_var / 2) * u 71 | 72 | 73 | # 重参数层,相当于给输入加入噪声 74 | z_samples = Lambda(sampling)([z_mean, z_log_var]) 75 | prior_kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)) 76 | 77 | 78 | # shuffle层,打乱第一个轴 79 | 
def shuffling(x):
    """Randomly permute a tensor along its first (batch) axis."""
    perm = tf.random_shuffle(K.arange(0, K.shape(x)[0]))
    return K.gather(x, perm)


# Pair each code with itself (positive) and a batch-shuffled code (negative).
z_shuffle = Lambda(shuffling)(z_samples)
z_z_1 = Concatenate()([z_samples, z_samples])
z_z_2 = Concatenate()([z_samples, z_shuffle])

# Same pairing against the 4x4 local feature map (code tiled spatially).
feature_map_shuffle = Lambda(shuffling)(feature_map)
z_samples_repeat = RepeatVector(4 * 4)(z_samples)
z_samples_map = Reshape((4, 4, z_dim))(z_samples_repeat)
z_f_1 = Concatenate()([z_samples_map, feature_map])
z_f_2 = Concatenate()([z_samples_map, feature_map_shuffle])


# Global discriminator: a small MLP on concatenated code pairs.
z_in = Input(shape=(z_dim * 2,))
z = z_in
for _ in range(3):
    z = Dense(z_dim, activation='relu')(z)
z = Dense(1, activation='sigmoid')(z)

GlobalDiscriminator = Model(z_in, z)

z_z_1_scores = GlobalDiscriminator(z_z_1)
z_z_2_scores = GlobalDiscriminator(z_z_2)
global_info_loss = - K.mean(K.log(z_z_1_scores + 1e-6) + K.log(1 - z_z_2_scores + 1e-6))


# Local discriminator: the same MLP shape, applied at every spatial position.
z_in = Input(shape=(None, None, z_dim * 2))
z = z_in
for _ in range(3):
    z = Dense(z_dim, activation='relu')(z)
z = Dense(1, activation='sigmoid')(z)

LocalDiscriminator = Model(z_in, z)

z_f_1_scores = LocalDiscriminator(z_f_1)
z_f_2_scores = LocalDiscriminator(z_f_2)
local_info_loss = - K.mean(K.log(z_f_1_scores + 1e-6) + K.log(1 - z_f_2_scores + 1e-6))


# Training model: losses are attached via add_loss, so compile takes no loss.
model_train = Model(x_in, [z_z_1_scores, z_z_2_scores, z_f_1_scores, z_f_2_scores])
model_train.add_loss(alpha * global_info_loss + beta * local_info_loss + gamma * prior_kl_loss)
model_train.compile(optimizer=Adam(1e-3))

model_train.fit(x_train, epochs=100, batch_size=100)
model_train.save_weights('total_model.tiny.imagenet.weights') 135 | 136 | 137 | # 输出编码器的特征 138 | zs = encoder.predict(x_train, verbose=True) 139 | zs.mean() # 查看均值(简单观察先验分布有没有达到效果) 140 | zs.std() # 查看方差(简单观察先验分布有没有达到效果) 141 | 142 | 143 | # 随机选一张图片,输出最相近的图片 144 | # 可以选用欧氏距离或者cos值 145 | def sample_knn(path): 146 | n = 10 147 | topn = 10 148 | figure1 = np.zeros((img_dim*n, img_dim*topn, 3)) 149 | figure2 = np.zeros((img_dim*n, img_dim*topn, 3)) 150 | zs_ = zs / (zs**2).sum(1, keepdims=True)**0.5 151 | for i in range(n): 152 | one = np.random.choice(len(x_train)) 153 | idxs = ((zs**2).sum(1) + (zs[one]**2).sum() - 2 * np.dot(zs, zs[one])).argsort()[:topn] 154 | for j,k in enumerate(idxs): 155 | digit = x_train[k] 156 | figure1[i*img_dim: (i+1)*img_dim, 157 | j*img_dim: (j+1)*img_dim] = digit 158 | idxs = np.dot(zs_, zs_[one]).argsort()[-n:][::-1] 159 | for j,k in enumerate(idxs): 160 | digit = x_train[k] 161 | figure2[i*img_dim: (i+1)*img_dim, 162 | j*img_dim: (j+1)*img_dim] = digit 163 | figure1 = (figure1 + 1) / 2 * 255 164 | figure1 = np.clip(figure1, 0, 255) 165 | figure2 = (figure2 + 1) / 2 * 255 166 | figure2 = np.clip(figure2, 0, 255) 167 | imageio.imwrite(path+'_l2.png', figure1) 168 | imageio.imwrite(path+'_cos.png', figure2) 169 | 170 | 171 | sample_knn('test') 172 | --------------------------------------------------------------------------------