#! -*- coding: utf-8 -*-
"""Variational Information Bottleneck (VIB) demo on the IMDB sentiment task.

A small 1D-CNN text classifier whose pooled features are passed through a
variational bottleneck layer before the final sigmoid classifier.
"""

from keras.preprocessing import sequence
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.datasets import imdb


class VIB(Layer):
    """Variational Information Bottleneck layer.

    Takes ``[z_mean, z_log_var]`` and returns the reparameterized sample
    ``z_mean + exp(z_log_var / 2) * eps`` with ``eps ~ N(0, I)`` during
    training, and plain ``z_mean`` at inference time.  Adds
    ``lamb * KL(q(z|x) || N(0, I))`` to the model loss as a regularizer.
    """
    def __init__(self, lamb, **kwargs):
        # `lamb` weights the KL regularization term in the total loss.
        self.lamb = lamb
        super(VIB, self).__init__(**kwargs)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        u = K.random_normal(shape=K.shape(z_mean))
        # KL divergence to the standard-normal prior: mean over the batch
        # axis (axis 0), summed over the latent dimensions.
        kl_loss = - 0.5 * K.sum(K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), 0))
        self.add_loss(self.lamb * kl_loss)
        # Inject Gaussian noise only in the training phase; at test time
        # the noise term is zero and the layer returns z_mean.
        u = K.in_train_phase(u, 0.)
        return z_mean + K.exp(z_log_var / 2) * u

    def compute_output_shape(self, input_shape):
        # Output has the same shape as z_mean (first input).
        return input_shape[0]

    def get_config(self):
        # Fix: serialize `lamb` so a model containing this layer can be
        # saved with model.save() and reloaded via custom_objects.
        config = super(VIB, self).get_config()
        config['lamb'] = self.lamb
        return config


# Hyperparameters
max_features = 5000    # vocabulary size
maxlen = 400           # pad/truncate reviews to this many tokens
batch_size = 32
embedding_dims = 50
filters = 250
kernel_size = 3
epochs = 5


# Load the IMDB sentiment data and pad every review to a fixed length.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)


# CNN encoder: embedding -> 1D convolution -> global max pooling.
seq_in = Input(shape=(maxlen,))
seq = seq_in

seq = Embedding(max_features, embedding_dims)(seq)
seq = Conv1D(filters,
             kernel_size,
             padding='valid',
             activation='relu',
             strides=1)(seq)
seq = GlobalMaxPooling1D()(seq)

# Variational bottleneck: two heads predict the posterior mean and
# log-variance; VIB samples z and adds the weighted KL penalty.
z_mean = Dense(128)(seq)
z_log_var = Dense(128)(seq)
seq = VIB(0.1)([z_mean, z_log_var])
seq = Dense(1, activation='sigmoid')(seq)

model = Model(seq_in, seq)
# Compile and train; validation on the held-out test split.
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test))


"""
# Baseline model with the VIB layer removed (for comparison)

seq_in = Input(shape=(maxlen,))
seq = seq_in

seq = Embedding(max_features, embedding_dims)(seq)
seq = Conv1D(filters,
             kernel_size,
             padding='valid',
             activation='relu',
             strides=1)(seq)
seq = GlobalMaxPooling1D()(seq)
seq = Dense(1, activation='sigmoid')(seq)

model = Model(seq_in, seq)
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test))
"""