├── src ├── __init__.py ├── encoder.py ├── decoder.py ├── utils.py ├── unsupervised_similarity_clustering.py ├── run_cluster_mouse_organs.py └── moe_utils.py ├── .gitattributes ├── data ├── indices_test.csv ├── indices_train.csv ├── labels_merged.csv ├── labels_prior_genes.csv └── data_merged_tpm.csv.bz2 ├── spec-file.yml ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.csv filter=lfs diff=lfs merge=lfs -text 2 | *.csv.bz2 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /data/indices_test.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b152dd5473c017a451399a75f9e7e32a43094594b221f4703412d909df5873bd 3 | size 18440 4 | -------------------------------------------------------------------------------- /data/indices_train.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:cc83b3706e4f012f2b9d98832768ecee15b9b9d5b8c4de3a46bc6b72e8afb61c 3 | size 73758 4 | -------------------------------------------------------------------------------- /data/labels_merged.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:586cfbb2ce21edf77f2f8202ca5443b62966aee0f9fc0533ae10cbc5f8131392 3 | size 147038 4 | -------------------------------------------------------------------------------- /data/labels_prior_genes.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:2fc491adcc50a6a458fe3d58527fbe9a645d3d52f967014946dbbfc4dce89838 3 | size 241106 4 | -------------------------------------------------------------------------------- /data/data_merged_tpm.csv.bz2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a92500d14678594efae5b7f1a05c845250865d3508f2192c9bc19366ebc09d33 3 | size 348200104 4 | -------------------------------------------------------------------------------- /spec-file.yml: -------------------------------------------------------------------------------- 1 | name: moesimvae 2 | channels: 3 | - conda-forge/label/cf201901 4 | - conda-forge 5 | - https://repo.continuum.io/pkgs/free 6 | - defaults 7 | dependencies: 8 | - pip=21.0.1=py36hecd8cb5_0 9 | - pip: 10 | - six==1.12.0 11 | - scipy==1.5.4 12 | - numpy==1.19.5 13 | - pandas==1.1.5 14 | - matplotlib==3.3.4 15 | - tensorflow==1.13.1 16 | - theano==1.0.4 17 | - umap-learn==0.5.1 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MoE-Sim-VAE 2 | Mixture-of-Experts Variational Autoencoder for Clustering and Generating from Similarity-Based Representations on Single Cell Data 3 | 4 | ## Install conda environment 5 | ``` 6 | conda env create -f spec-file.yml 7 | conda activate moesimvae 8 | ``` 9 | 10 | ## Run MoE-Sim-VAE on scRNA-sequencing data clustering mouse organs 11 | ``` 12 | cd src 13 | python run_cluster_mouse_organs.py --dir_output PATH/TO/OUTPUT_DIRECTORY --loss_coef_kl_div_code_standard_gaussian 0.2 --loss_coef_clustering 0.8 14 | ``` -------------------------------------------------------------------------------- /src/encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.distributions as tfd 3 | from src.utils import initializers, 
# NOTE(review): this span comes from a whitespace-collapsed repository dump; line breaks and block
# nesting below are reconstructed from syntax — confirm against the original files.
activation_fcts  # tail of the `from src.utils import initializers, ...` statement that begins on the preceding dump line


class Encoder(object):
    """
    Feed-forward encoder that maps an input to the mean of a diagonal Gaussian code and draws
    samples from it, once for the input as given and once for a dropout-processed copy.

    Args:
        - code_size: int
            Number of units of the final dense layers (dimension of `loc`)

        - encoder_internal_size: int
            Number of units of every hidden dense layer

        - activation_fct: str
            Key into `activation_fcts`

        - dropout_rate: float
            `rate` passed to tf.layers.dropout for the noisy path

        - init_method: str, default 'xavier'
            Key into `initializers`

        - batch_normalization: bool, default True
            Whether tf.layers.batch_normalization is applied after each hidden dense layer

        - depth_encoder: int, default 1
            Number of hidden dense layers
    """

    def __init__(self, code_size, encoder_internal_size, activation_fct, dropout_rate, init_method='xavier', batch_normalization=True, depth_encoder=1):
        self.code_size = code_size
        self.encoder_internal_size = encoder_internal_size
        self.batch_normalization = batch_normalization
        # Both lookups raise KeyError for an unknown name.
        self.init = initializers[init_method]
        self.activation_fct = activation_fcts[activation_fct]
        self.depth_encoder = depth_encoder
        self.dropout_rate = dropout_rate

    def build_encoder(self, x):
        """
        Build the encoder graph for `x`.

        :param x: tensor, flattened with tf.layers.flatten before the first dense layer
        :return: tuple (sample, sample_noisy, loc, loc_noisy, scale) where the samples are drawn from
                 MultivariateNormalDiag(loc, scale) and MultivariateNormalDiag(loc_noisy, scale)
        """
        x = tf.layers.flatten(x)
        # NOTE(review): no `training` argument is passed here; check the TF 1.x default of
        # tf.layers.dropout — if it is training=False this call returns `x` unchanged and the
        # "noisy" path carries no noise.
        noisy_x = tf.layers.dropout(x, rate=self.dropout_rate)

        for i in range(self.depth_encoder):
            # One Dense layer object per depth level, applied to both paths.
            my_dense = tf.layers.Dense(self.encoder_internal_size, kernel_initializer=self.init(), name="denes_" + str(i))
            x = my_dense(x)
            noisy_x = my_dense(noisy_x)
            if self.batch_normalization:
                # NOTE(review): two separate functional calls — presumably each creates its own
                # batch-norm variables, unlike the shared Dense layer above; confirm this is intended.
                x = tf.layers.batch_normalization(x)
                noisy_x = tf.layers.batch_normalization(noisy_x)
            x = self.activation_fct(x)
            noisy_x = self.activation_fct(noisy_x)

        # NOTE(review): the two output layers are also separate tf.layers.dense calls — verify
        # whether the clean and noisy paths are meant to share these weights.
        self.loc = tf.layers.dense(x, self.code_size, kernel_initializer=self.init())
        self.loc_noisy = tf.layers.dense(noisy_x, self.code_size, kernel_initializer=self.init())
        # The scale is a constant tensor of ones with the shape of `loc`; it is not learned.
        self.scale = tf.ones(tf.shape(self.loc), dtype=tf.float32)
        return tfd.MultivariateNormalDiag(loc=self.loc, scale_diag=self.scale).sample(), \
               tfd.MultivariateNormalDiag(loc=self.loc_noisy, scale_diag=self.scale).sample(), self.loc, self.loc_noisy, self.scale

# -------------------------------------------------------------------------------- /src/decoder.py: --------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.contrib.distributions as tfd
from src.utils import initializers, activation_fcts, local_moe, ffn_expert_fn



class Decoder(object):

    """
    Mixture-of-experts decoder: stores the configuration of `num_experts` feed-forward experts
    and builds the graph that applies them to every code.

    Args:
        - num_experts: int
            Number of expert functions created in build_decoder

        - num_markers: int
            Output size of each expert

        - code_size: int
            Stored on the instance; not read by build_decoder

        - init_method: str, default 'xavier'
            Key into `initializers`

        - batch_normalization: bool, default True
            Forwarded to ffn_expert_fn

        - decoder_depth: int, default 1
            Forwarded to ffn_expert_fn as `depth`

        - activation_fct: str, default 'elu'
            Key into `activation_fcts`

        - decoder_internal_size: int, default 20
            Forwarded to ffn_expert_fn

    """

    def __init__(self, num_experts, num_markers, code_size, init_method='xavier', batch_normalization=True, decoder_depth=1, activation_fct='elu', decoder_internal_size=20):

        self.num_experts = num_experts
        self.num_markers = num_markers
        self.code_size = code_size


        self.batch_normalization = batch_normalization
        self.decoder_depth = decoder_depth
        self.activation_fct = activation_fcts[activation_fct]
        self.decoder_internal_size = decoder_internal_size

        self.init = initializers[init_method]


    def build_decoder(self, logits_activated, code, gates):
        """
        Create the expert functions and apply them to every element of `code`.

        :param logits_activated: not used in this method
        :param code: tensor, iterated element-wise by tf.map_fn
        :param gates: tensor, iterated together with `code`; each element is reshaped to (1, num_experts)
        :return: tuple of the two float32 outputs of tf.map_fn over local_moe
        """

        # define moe function
        with tf.variable_scope('fn_gen_outputs'):
            self.moe_func = list()
            for i in range(self.num_experts):
                my_fun = ffn_expert_fn(output_size=self.num_markers, init=self.init, name='expert_' + str(i),
                                       activation_fct=self.activation_fct, depth=self.decoder_depth,
                                       decoder_internal_size=self.decoder_internal_size, batch_normalization=self.batch_normalization)
                self.moe_func.append(my_fun)

        # NOTE(review): whether this call sits inside or outside the variable scope above cannot be
        # recovered from the collapsed dump — confirm, since it affects variable names.
        gen_outputs = tf.map_fn(lambda x: local_moe(x[0], tf.reshape(x[1], (1, self.num_experts)), self.moe_func, name='gen_outputs'), (code, gates), dtype=(tf.float32, tf.float32))

        return gen_outputs[0], gen_outputs[1]






# -------------------------------------------------------------------------------- /src/utils.py: --------------------------------------------------------------------------------
# NOTE(review): this span comes from a whitespace-collapsed repository dump; line breaks and block
# nesting below are reconstructed from syntax — confirm against the original file.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import tensorflow as tf
import tensorflow.contrib.distributions as tfd
from src.moe_utils import SparseDispatcher, flatten_all_but_last, Parallelism
import six
import numpy as np
from sklearn.metrics import f1_score
import argparse

xav_init = tf.contrib.layers.xavier_initializer
normal_init = tf.truncated_normal_initializer
zero_init = tf.zeros_initializer

# Different methods for initializing the weights (values are initializer factories, call them to
# obtain an initializer)
initializers = dict()
initializers['xavier'] = xav_init
initializers['normal'] = normal_init
initializers['zeros'] = zero_init

# Different activation functions
activation_fcts = dict()
activation_fcts['relu'] = tf.keras.activations.relu
activation_fcts['elu'] = tf.keras.activations.elu
activation_fcts['sigmoid'] = tf.keras.activations.sigmoid
activation_fcts['hard_sigmoid'] = tf.keras.activations.hard_sigmoid
activation_fcts['selu'] = tf.keras.activations.selu
activation_fcts['softmax'] = tf.keras.activations.softmax
activation_fcts['softplus'] = tf.keras.activations.softplus
activation_fcts['softsign'] = tf.keras.activations.softsign
activation_fcts['tanh'] = tf.keras.activations.tanh
activation_fcts['LeakyRelu'] = tf.nn.leaky_relu




def save_loss_plot(out_dir, loss, log=False, tag=''):
    """
    Saves a loss curve as a PDF in the output directory
    :param out_dir: str, output directory
    :param loss: sequence of float, one loss value per iteration
    :param log: bool, if True plot np.log(loss) and write 'loss_log_plot<tag>.pdf',
        otherwise plot loss and write 'loss_plot<tag>.pdf'
    :param tag: str, suffix inserted into the file name before '.pdf'
    :return: no returns
    """
    basename = 'loss_log_plot' if log else 'loss_plot'
    filename = os.path.join(out_dir, basename + tag + '.pdf')
    values = np.log(loss) if log else loss

    plt.figure()
    plt.plot(range(len(loss)), values, 'r')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()



def write_readme_file(params, dir_output, filename='ReadMe.txt'):
    """
    Writes every entry of params as 'key = value' followed by an empty line
    :param params: dict, parameters to record; keys and values are converted with str()
    :param dir_output: str, directory the file is written to
    :param filename: str, name of the file inside dir_output
    :return: no returns
    """
    fout = os.path.join(dir_output, filename)
    # The context manager closes the handle even when a write raises.
    with open(fout, 'w') as fo:
        for k, v in params.items():
            fo.write(str(k) + ' = ' + str(v) + '\n\n')


def generate_subset(inputs,
                    num_cells_per_input,
                    batch_size,
                    weights=None,
                    return_indices=False):
    """
    Returns a random subset from input data of shape (batch_size, num_cells_per_input, num_markers)
    :param inputs: numpy array, the input ndarray to sample from (rows are sampled with replacement)
    :param num_cells_per_input: int, number of cells per multi-cell input
    :param batch_size: int, batch size of the subset
    :param weights: list of float, sampling probability per row, or None for uniform sampling
    :param return_indices: bool, whether to return subset indices or not
    :return: the subset, or the tuple (subset, indices) if return_indices is True
    """

    num_cells_total = inputs.shape[0]

    # p=None is numpy's default and means uniform sampling, so one call covers both cases.
    indices = np.random.choice(
        num_cells_total,
        size=batch_size * num_cells_per_input,
        replace=True,
        p=weights)

    subset = inputs[indices, ]
    subset = np.reshape(subset, newshape=(batch_size, num_cells_per_input, -1))

    if return_indices:
        return subset, indices

    return subset


def str2bool(v):
    """
    Converts a command line token to a bool
    :param v: str (case-insensitive 'yes'/'true'/'t'/'y'/'1' or 'no'/'false'/'f'/'n'/'0'),
        or a bool, which is returned unchanged
    :return: bool
    :raises argparse.ArgumentTypeError: if v is a string that is none of the accepted tokens
    """
    # A bool can arrive here when a default value is routed through the converter; it has no
    # .lower(), so pass it through instead of failing with AttributeError.
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


# NOTE(review): this span comes from a whitespace-collapsed repository dump; line breaks and block
# nesting below are reconstructed from syntax — confirm against the original files.
def compute_f_measure(y_true, y_pred):
    """
    Compute f-measure of subpopulation prediction results.

    For every distinct true label, the binary F1 score against every distinct predicted label
    is computed, the best one is kept and weighted by the fraction of samples carrying that
    true label; the weighted scores are summed.
    :param y_true: sequence of true labels
    :param y_pred: sequence of predicted labels, compared position-wise with y_true
    :return: double, f-measure
    """

    y_true_unique = np.unique(y_true)
    y_pred_unique = np.unique(y_pred)

    N = len(y_true)
    f_measure_i = list()

    for i, y_i in enumerate(y_true_unique):
        f_measure_j = list()
        temp_ind_y = np.where(np.asarray(y_true) == y_i)[0]

        # indicator vector of the samples that belong to true class y_i
        binary_y_i = np.zeros((N, ))
        binary_y_i[temp_ind_y] = 1

        n_c_i = len(temp_ind_y)
        for j, y_j in enumerate(y_pred_unique):
            temp_ind_y_j = np.where(np.asarray(y_pred) == y_j)[0]

            # indicator vector of the samples assigned to predicted class y_j
            binary_y_j = np.zeros((N,))
            binary_y_j[temp_ind_y_j] = 1

            f_measure_j.append(f1_score(binary_y_i, binary_y_j))

        # keep the best matching predicted class, weighted by the size of the true class
        ind_max = np.argmax(np.asarray(f_measure_j))
        f_measure_i.append(n_c_i/N*f_measure_j[ind_max])

    return(np.sum(f_measure_i))


def sigmoid(x):
    """
    Logistic function 1 / (1 + exp(-x)), evaluated with numpy.
    """
    return 1 / (1 + np.exp(-x))



def ffn_expert_fn(output_size,
                  init,
                  name,
                  activation_fct,
                  depth=1,
                  decoder_internal_size=200,
                  batch_normalization=True):
    """
    Create a feed-forward expert function.

    :param output_size: int, units of the final dense layer
    :param init: callable returning a kernel initializer; called once per dense layer
    :param name: str, variable scope the expert's layers are created in
    :param activation_fct: callable applied after every hidden layer
    :param depth: int, number of hidden dense layers
    :param decoder_internal_size: int, units of every hidden dense layer
    :param batch_normalization: bool, whether tf.layers.batch_normalization follows each hidden dense layer
    :return: function taking a tensor x and returning the output of the final dense layer (no activation)
    """

    def my_fn(x):
        # reuse=tf.AUTO_REUSE is set on the scope, so the function can be called more than once
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):

            for i in range(depth):
                x = tf.layers.dense(x, decoder_internal_size, kernel_initializer=init())
                if batch_normalization:
                    x = tf.layers.batch_normalization(x)
                x = activation_fct(x)

            output = tf.layers.dense(inputs=x,
                                     units=output_size,
                                     activation=None,
                                     name='gen_output',
                                     kernel_initializer = init())
        return output

    return my_fn


def local_moe(x,
              gates,
              list_moe_func,
              name=None,
              additional_dispatch_params=None):
    """
    Run all expert functions on x and combine their outputs with the given gates.

    :param x: tensor, passed through flatten_all_but_last before being dispatched
    :param gates: tensor, left operand of the matmul with the stacked expert outputs
    :param list_moe_func: list of expert functions, one per expert
    :param name: str, variable scope name
    :param additional_dispatch_params: dict of extra tensors dispatched to the experts as keyword arguments, or None
    :return: tuple (expert_outputs, expert_outputs_combined)
    """

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # The dispatcher is built from a tensor of ones rather than from `gates`; `gates` is only
        # used for the weighted combination below.
        # NOTE(review): presumably this sends the input to every expert — confirm against SparseDispatcher.
        dispatcher = SparseDispatcher(len(list_moe_func), tf.ones((1, len(list_moe_func))))

        expert_kwargs = {}
        expert_kwargs["x"] = dispatcher.dispatch(flatten_all_but_last(x))
        for k, v in six.iteritems(additional_dispatch_params or {}):
            v = flatten_all_but_last(v)
            expert_kwargs[k] = dispatcher.dispatch(v)

        ep = Parallelism(['existing_device']*len(list_moe_func), reuse=tf.AUTO_REUSE)
        expert_outputs = ep(list_moe_func, **expert_kwargs)

        expert_outputs = tf.stack(expert_outputs)
        expert_outputs = tf.squeeze(expert_outputs)

        expert_outputs_combined = tf.reduce_sum(tf.matmul(gates, expert_outputs), axis=0)


    return expert_outputs, expert_outputs_combined

# -------------------------------------------------------------------------------- /src/unsupervised_similarity_clustering.py: --------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer
import tensorflow.contrib.distributions as tfd
import numpy as np


class USC:
    """
    Builds the clustering part of the graph: soft cluster assignments from `cluster_network`,
    the clustering losses and their Adam optimizer step, and update ops for mixture weights and
    means. The whole graph is constructed in the constructor.
    """

    def __init__(self, batch_size, num_markers, k, loss_coef_kernel, loss_coef_depict, trainings_data,
                 learning_rate, code_size, cluster_network, gradient_stop_on_data=False, predict_class_similarity=False):
        """
        :param batch_size: int, used for placeholder/variable shapes and as normalizer in the losses
        :param num_markers: not stored or used in this constructor
        :param k: int, number of clusters
        :param loss_coef_kernel: float, weight of the similarity loss in loss_clustering
        :param loss_coef_depict: float, weight of both depict losses in loss_clustering
        :param trainings_data: tensor, used as both training and prediction input of cluster_network
        :param learning_rate: float, learning rate of the Adam optimizer
        :param code_size: int, second dimension of the 'mean'/'covariance' variables
        :param cluster_network: callable taking (data_train, data_predict) and returning five values
        :param gradient_stop_on_data: bool, wrap the data in tf.stop_gradient
        :param predict_class_similarity: bool, use the class-similarity loss instead of the kernel-similarity loss
        """
        self.model_hparams = dict()
        self.model_hparams['batch_size'] = batch_size
        self.model_hparams['k'] = k
        self.model_hparams['loss_coef_kernel'] = loss_coef_kernel
        self.model_hparams['loss_coef_depict'] = loss_coef_depict
        self.model_hparams['learning_rate'] = learning_rate
        self.model_hparams['code_size'] = code_size
        self.data_train = trainings_data
        self.data_predict = trainings_data
        if gradient_stop_on_data:
            self.data_train = tf.stop_gradient(self.data_train)
            self.data_predict = tf.stop_gradient(self.data_predict)
        self.cluster_network = cluster_network
        self.predict_class_similarity = predict_class_similarity

        self.__define_model()




    def __create_variables(self):
        """
        Create the placeholders and variables of the model.
        """
        self.k_ij = tf.placeholder(tf.float32, shape=[self.model_hparams['batch_size'], self.model_hparams['batch_size']], name='k_ij')
        if self.predict_class_similarity:
            self.class_similarity = tf.placeholder(dtype=tf.float32, shape=[self.model_hparams['batch_size'], self.model_hparams['k']], name='class_similarity')

        # w_ik_predict and w_ik are reassigned to network outputs in __define_model.
        self.w_ik_predict = tf.Variable(1, validate_shape=False, dtype=tf.float32)
        self.w_ik = tf.get_variable(dtype=tf.float32, name='prob_assignment', shape=[self.model_hparams['batch_size'], self.model_hparams['k']], validate_shape=True, initializer=xavier_initializer())
        self.alpha_k = tf.get_variable(dtype=tf.float32, name='mixture_weights', shape=[self.model_hparams['k']], validate_shape=True, initializer=xavier_initializer())
        self.mu_k = tf.get_variable(dtype=tf.float32, name='mean', shape=[self.model_hparams['k'], self.model_hparams['code_size']], validate_shape=True, initializer=xavier_initializer())
        self.sigma_k = tf.get_variable(dtype=tf.float32, name='covariance', shape=[self.model_hparams['k'], self.model_hparams['code_size']], validate_shape=True, initializer=xavier_initializer())

        self.softmax_parameter_represenation = tf.get_variable(dtype=tf.float32, name='depict_represenation', shape=[self.model_hparams['code_size'], self.model_hparams['k']])


    def __define_model(self):
        """
        Build predictions, losses, the optimizer op and the update ops.
        """

        self.__create_variables()

        ##################################################
        # predictions
        preds_train, reps_train, preds, reps, reps_train_noisy = self.cluster_network(self.data_train, self.data_predict)

        self.w_ik = preds_train
        self.representation = reps_train
        self.w_ik = tf.keras.activations.softmax(self.w_ik)
        self.w_ik_new_similarity = tf.matmul(self.w_ik, self.w_ik, transpose_b=True)
        # not referenced again in this method
        w_ik_new_assignments = tf.reduce_sum(self.w_ik, axis=0) / self.model_hparams['batch_size']
        self.cluster_assignments = tf.argmax(self.w_ik, axis=1)

        w_ik_new_predict = preds
        self.representation_predict = reps
        self.w_ik_predict = tf.keras.activations.softmax(w_ik_new_predict)

        ##################################################
        # depict auxiliary function on cluster predictions:
        # divide each column by the square root of its column sum, then renormalize every row
        q_i_k_norm_factor_sum_code = tf.sqrt(tf.reduce_sum(tf.clip_by_value(self.w_ik, 1e-30, 1), axis=0))
        q_ik = tf.squeeze(tf.map_fn(lambda x: x / q_i_k_norm_factor_sum_code, self.w_ik, dtype=tf.float32))
        q_i_k_norm_factor_code_noisy = tf.reduce_sum(q_ik, axis=1)
        q_ik = q_ik / tf.stack([q_i_k_norm_factor_code_noisy] * self.model_hparams['k'], axis=1)
        q_ik = tf.clip_by_value(q_ik, 1e-30, 1)

        ##################################################
        # depict on cluster representation
        # stays at constant 0 when the network returns no noisy representation
        self.loss_depict_represenation = tf.constant(0, dtype=tf.float32)
        if not reps_train_noisy is None:
            # row-normalized exp(reps . W) for the clean representation
            p_i_k_code = tf.exp(tf.matmul(reps_train, self.softmax_parameter_represenation))
            p_i_k_norm_factor_code = tf.reduce_sum(p_i_k_code, axis=1)
            p_i_k_code = p_i_k_code / tf.stack([p_i_k_norm_factor_code]*self.model_hparams['k'], axis=1)
            p_i_k_code = tf.clip_by_value(p_i_k_code, 1e-30, 1)

            # the same for the noisy representation
            p_i_k_code_noisy = tf.exp(tf.matmul(reps_train_noisy, self.softmax_parameter_represenation))
            p_i_k_norm_factor_code = tf.reduce_sum(p_i_k_code_noisy, axis=1)
            p_i_k_code_noisy = p_i_k_code_noisy / tf.stack([p_i_k_norm_factor_code]*self.model_hparams['k'], axis=1)
            p_i_k_code_noisy = tf.clip_by_value(p_i_k_code_noisy, 1e-30, 1)

            # target distribution built from the noisy probabilities
            q_i_k_norm_factor_sum_code = tf.sqrt(tf.reduce_sum(p_i_k_code_noisy, axis=0))
            q_i_k_code_noisy = tf.squeeze(tf.map_fn(lambda x : x / q_i_k_norm_factor_sum_code, p_i_k_code_noisy, dtype=tf.float32))
            q_i_k_norm_factor_code_noisy = tf.reduce_sum(q_i_k_code_noisy, axis=1)
            q_i_k_code_noisy = q_i_k_code_noisy / tf.stack([q_i_k_norm_factor_code_noisy]*self.model_hparams['k'], axis=1)
            q_i_k_code_noisy = tf.clip_by_value(q_i_k_code_noisy, 1e-30, 1)

            self.loss_depict_represenation = - 1.0/self.model_hparams['batch_size'] * tf.reduce_sum(tf.reduce_sum(q_i_k_code_noisy * tf.log(p_i_k_code), axis=1), axis=0)



        ##################################################
        # loss functions
        similarity_cluster_prob = tf.matmul(self.k_ij, self.w_ik)
        similarity_cluster_prob = tf.keras.layers.Softmax(axis=1)(similarity_cluster_prob)
        # stored on the instance but not part of loss_clustering below
        self.loss_entropy_batch_sample = tf.reduce_mean(-1*tf.reduce_sum(similarity_cluster_prob * tf.log(similarity_cluster_prob), axis=1))

        self.loss_depict = - 1/ self.model_hparams['batch_size'] * tf.reduce_sum(tf.reduce_sum(q_ik * tf.log(tf.clip_by_value(self.w_ik, 1e-30, 1)), axis=1), axis=0)

        if self.predict_class_similarity:
            self.loss_class_similarity = tf.reduce_mean(tf.keras.backend.binary_crossentropy(target=self.class_similarity, output=self.w_ik))
        else:
            self.loss_kernel_similarity = tf.reduce_mean(tf.keras.backend.binary_crossentropy(target=self.k_ij, output=self.w_ik_new_similarity))


        # loss_coef_depict weights both depict losses
        if self.predict_class_similarity:
            self.loss_clustering = self.model_hparams['loss_coef_kernel'] * self.loss_class_similarity + \
                                   self.model_hparams['loss_coef_depict'] * self.loss_depict + \
                                   self.model_hparams['loss_coef_depict'] * self.loss_depict_represenation
        else:
            self.loss_clustering = self.model_hparams['loss_coef_kernel'] * self.loss_kernel_similarity + \
                                   self.model_hparams['loss_coef_depict'] * self.loss_depict + \
                                   self.model_hparams['loss_coef_depict'] * self.loss_depict_represenation
        self.optimize_clustering = tf.train.AdamOptimizer(self.model_hparams['learning_rate'], beta1=0.9, beta2=0.99).minimize(self.loss_clustering)

        ##################################################
        # M-step
        # soft counts per cluster, clipped to [1, batch_size] so the divisions below never see 0
        self.N_k = tf.clip_by_value(tf.reduce_sum(self.w_ik, axis=0), 1, self.model_hparams['batch_size'])
        alpha_k_new = self.N_k / self.model_hparams['batch_size']
        self.update_alpha_k = tf.assign(self.alpha_k, alpha_k_new)

        mu_k_new_unnormalized = tf.matmul(self.w_ik, self.data_train, transpose_a=True)
        mu_k_new_norm_factor = tf.stack([self.N_k] * self.model_hparams['code_size'], axis=1)
        mu_k_new = mu_k_new_unnormalized / mu_k_new_norm_factor
        self.updated_mu_k = tf.assign(self.mu_k, mu_k_new)

        self.log_likelihood = self.__gmm_log_likelihood(self.data_train)



    def __pairwise_euclidean_distance(self, A, B):
        """
        Computes pairwise distances between each elements of A and each elements of B.
        Args:
          A,    [m,d] matrix
          B,    [n,d] matrix
        Returns:
          D,    [m,n] matrix of pairwise distances
        """
        with tf.variable_scope('pairwise_dist'):
            # squared norms of each row in A and B
            na = tf.reduce_sum(tf.square(A), 1)
            nb = tf.reduce_sum(tf.square(B), 1)

            # na as a column vector and nb as a row vector
            na = tf.reshape(na, [-1, 1])
            nb = tf.reshape(nb, [1, -1])

            # return pairwise Euclidean distance matrix; the maximum with 0.0 keeps the argument
            # of the square root non-negative
            D = tf.sqrt(tf.maximum(na - 2 * tf.matmul(A, B, False, True) + nb, 0.0))
        return D

    def __gmm_log_likelihood(self, data):
        """
        Log-likelihood of `data` under the mixture described by mu_k, sigma_k and alpha_k.

        :param data: tensor evaluated under every mixture component
        :return: scalar tensor, sum over samples of log(sum_k alpha_k * N(data | mu_k, diag scale sigma_k))
        """
        # evaluate based on log likelihood
        # NOTE(review): sigma_k is passed as scale_diag although its variable is named
        # 'covariance', and no op in this class updates it — confirm the intended meaning.
        mixture_probabilities = tf.transpose(
            tf.map_fn(lambda x: tfd.MultivariateNormalDiag(loc=x[0], scale_diag=x[1]).prob(data) * x[2],
                      (self.mu_k, self.sigma_k, self.alpha_k), dtype=tf.float32))
        log_likelihood = tf.reduce_sum(tf.log(tf.reduce_sum(mixture_probabilities, axis=1)), axis=0)
        return log_likelihood

# -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND
CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/run_cluster_mouse_organs.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import argparse 5 | import os 6 | import sys 7 | from datetime import datetime as dt 8 | import pandas as pd 9 | from shutil import copyfile 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | from sklearn.decomposition import PCA 14 | from sklearn.preprocessing import MinMaxScaler 15 | from sklearn.neighbors import NearestNeighbors 16 | from sklearn.metrics import normalized_mutual_info_score 17 | from scipy.stats import multivariate_normal 18 | from collections import OrderedDict 19 | import umap 20 | 21 | ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../')) 22 | sys.path.insert(0, ROOT_DIR) 23 | 24 | from src.encoder import Encoder 25 | from src.decoder import Decoder 26 | from src.utils import initializers, sigmoid, compute_f_measure, str2bool, generate_subset, write_readme_file, activation_fcts, 
save_loss_plot 27 | from src.unsupervised_similarity_clustering import USC 28 | 29 | def main(): 30 | 31 | parser = argparse.ArgumentParser() 32 | 33 | # IO parameters 34 | parser.add_argument('--dir_output', help='Directory where output will be generated.') 35 | parser.add_argument('--plotting', type=str2bool, default=True, help='Whether to always plot the updated model.') 36 | 37 | ## model hyperparameters 38 | parser.add_argument('--batch_normalization', type=str2bool, default=True, help='Boolean whether to include batch normalization.') 39 | parser.add_argument('--encoder_depth', type=int, default=1, help='Depth of encoder.') 40 | parser.add_argument('--encoder_internal_size', type=int, default=100, help='Internal size of encoder.') 41 | parser.add_argument('--decoder_depth', type=int, default=1, help='Depth of decoder.') 42 | parser.add_argument('--decoder_internal_size', type=int, default=100, help='Internal size of decoder.') 43 | parser.add_argument('--depth_cluster_network', type=int, default=1, help='Depth of clustering network.') 44 | parser.add_argument('--internal_size_cluster_network', type=int, default=100, help='Internal size of clustering network.') 45 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate.') 46 | parser.add_argument('--dropout_rate', type=float, default=.5, help='Dropout rate.') 47 | parser.add_argument('--code_size', type=int, default=20, help='Code size.') 48 | parser.add_argument('--num_experts', default=7, type=int, help='Number of experts in model.') 49 | parser.add_argument('--batch_size', type=int, default=256, help='Number of cells per multi-cell input.') 50 | parser.add_argument('--batch_size_test', type=int, default=512, help='Number of samples to test model.') 51 | parser.add_argument('--num_iterations', type=int, default=20000, help='Number of trainings iterations.') 52 | parser.add_argument('--activation_fct', default='elu', type=str, choices=['relu', 'elu', 'sigmoid', 'hard_sigmoid', 
'selu', 'softmax', 'softplus', 'softsign', 'tanh', 'LeakyRelu'], help='Choice of activation functions to use in decoder.') 53 | 54 | ## loss coefficients 55 | parser.add_argument('--loss_coef_reconst_data', type=float, default=1, help='Loss coefficient of reconstruction loss.') 56 | parser.add_argument('--loss_coef_kl_div_code_standard_gaussian', type=float, default=1, help='Locc coefficient of Kullback Leibler divergence for GMM in latent space.') 57 | parser.add_argument('--loss_coef_clustering', type=float, default=1, help='Loss coefficient of clustering loss.') 58 | parser.add_argument('--loss_coef_cluster_kernel', type=float, default=1, help='Loss coefficient of kernel similarity.') 59 | parser.add_argument('--loss_coef_cluster_depict', type=float, default=1, help='Loss coefficent of DEPICT loss in clustering loss.') 60 | 61 | ## preprocessing 62 | parser.add_argument('--quantile_prob', type=float, default=.2, help='Percentile for data filtering.') 63 | 64 | args = parser.parse_args() 65 | 66 | model_hparams = { 67 | 'path_data' : os.path.join(ROOT_DIR, 'data'), 68 | 'plotting' : args.plotting, 69 | 'batch_normalization' : args.batch_normalization, 70 | 'encoder_depth' : args.encoder_depth, 71 | 'encoder_internal_size' : args.encoder_internal_size, 72 | 'decoder_depth' : args.decoder_depth, 73 | 'decoder_internal_size' : args.decoder_internal_size, 74 | 'depth_cluster_network' : args.depth_cluster_network, 75 | 'internal_size_cluster_network' : args.internal_size_cluster_network, 76 | 'learning_rate' : args.learning_rate, 77 | 'dropout_rate' : args.dropout_rate, 78 | 'code_size' : args.code_size, 79 | 'num_experts' : args.num_experts, 80 | 'batch_size': args.batch_size, 81 | 'batch_size_test' : args.batch_size_test, 82 | 'num_iterations' : args.num_iterations, 83 | 'activation_fct' : args.activation_fct, 84 | 'loss_coef_reconst_data' : args.loss_coef_reconst_data, 85 | 'loss_coef_kl_div_code_standard_gaussian' : args.loss_coef_kl_div_code_standard_gaussian, 86 
| 'loss_coef_clustering' : args.loss_coef_clustering, 87 | 'loss_coef_cluster_kernel' : args.loss_coef_cluster_kernel, 88 | 'loss_coef_cluster_depict' : args.loss_coef_cluster_depict, 89 | 'quantile_prob' : args.quantile_prob 90 | } 91 | 92 | # Setup the output directory 93 | experiment_name = dt.now().strftime('%d-%m_%H-%M-%S-%f') 94 | random_id = np.random.choice(list(range(99999)), 1) 95 | experiment_name += '_' + str(format(random_id[0], '05d')) 96 | dir_output = os.path.join(args.dir_output, experiment_name) 97 | 98 | if not os.path.exists(dir_output): 99 | os.makedirs(dir_output) 100 | 101 | # copy executeable in output dir 102 | copyfile(sys.argv[0], os.path.join(dir_output, os.path.basename(sys.argv[0]))) 103 | 104 | 105 | # load data 106 | data = pd.read_csv(os.path.join(model_hparams['path_data'], 'data_merged_tpm.csv.bz2'), index_col=0, compression='bz2') 107 | labels = pd.read_csv(os.path.join(model_hparams['path_data'], 'labels_merged.csv'), header=None).values.squeeze() 108 | ind_train = pd.read_csv(os.path.join(model_hparams['path_data'], 'indices_train.csv'), header=None).values.squeeze() 109 | ind_test = pd.read_csv(os.path.join(model_hparams['path_data'], 'indices_test.csv'), header=None).values.squeeze() 110 | prior = pd.read_csv(os.path.join(model_hparams['path_data'],'labels_prior_genes.csv'), dtype=np.float32) 111 | 112 | training_data = data.values[ind_train] 113 | training_labels = labels[ind_train].squeeze() 114 | test_data = data.values[ind_test] 115 | test_labels = labels[ind_test].squeeze() 116 | training_prior = prior.values[ind_train] 117 | 118 | # filter data for points without prior knowledge 119 | training_prior_sum = np.sum(training_prior, axis=1) 120 | ind_train_prior = np.where(training_prior_sum == 1)[0] 121 | training_data = training_data[ind_train_prior] 122 | training_labels = training_labels[ind_train_prior] 123 | training_prior = training_prior[ind_train_prior] 124 | 125 | training_labels_unique = np.unique(training_labels) 
126 | test_labels_unique = np.unique(test_labels) 127 | 128 | 129 | # define color mapping and marker mapping 130 | cmap = matplotlib.cm.get_cmap('viridis') 131 | colors_labels = cmap(np.linspace(0, 1, len(training_labels_unique))) 132 | 133 | dict_colors = dict() 134 | dict_marker = dict() 135 | for i, label in enumerate(training_labels_unique): 136 | dict_marker[label] = '$' + str(label) + '$' 137 | dict_colors[label] = colors_labels[i].reshape((1, 4)) 138 | 139 | 140 | cmap = matplotlib.cm.get_cmap('viridis') 141 | colors_experts = cmap(np.linspace(0, 1, model_hparams['num_experts'])) 142 | 143 | 144 | dict_marker_experts = dict() 145 | dict_colors_experts = dict() 146 | for exp_id in range(model_hparams['num_experts']): 147 | dict_marker_experts[exp_id] = '$' + str(exp_id) + '$' 148 | dict_colors_experts[exp_id] = colors_experts[exp_id].reshape((1, 4)) 149 | 150 | # scale data between 0 and 1 151 | scaler = MinMaxScaler(feature_range=(0, 1)) 152 | scaler.fit(training_data) 153 | training_data_scaled = scaler.transform(training_data) 154 | test_data_scaled = scaler.transform(test_data) 155 | 156 | # filter vor genes with low variance 157 | sum_gene_count_total = np.sum(training_data, axis=0) 158 | var_gene = np.var(training_data, axis=0) 159 | 160 | ind_filter_count = np.where(sum_gene_count_total > np.quantile(sum_gene_count_total, model_hparams['quantile_prob']))[0] 161 | ind_filter_var = np.where(var_gene > np.quantile(var_gene, model_hparams['quantile_prob']))[0] 162 | ind_gene_filter = np.unique([ind_filter_count, ind_filter_var]) 163 | 164 | training_data_scaled = training_data_scaled[:, ind_gene_filter] 165 | test_data_scaled = test_data_scaled[:, ind_gene_filter] 166 | 167 | model_hparams['num_markers'] = training_data_scaled.shape[1] 168 | 169 | # fit pca 170 | pca = PCA() 171 | pca = pca.fit(training_data_scaled) 172 | pca_transform = pca.transform(training_data_scaled) 173 | 174 | # fit umap 175 | um = umap.UMAP() 176 | um = 
um.fit(training_data_scaled) 177 | um_transform = um.transform(training_data_scaled) 178 | 179 | plt.figure() 180 | for i, sub in enumerate(training_labels_unique): 181 | temp_ind = np.where(training_labels == sub)[0] 182 | plt.scatter(um_transform[temp_ind, 0], um_transform[temp_ind, 1], s=30, marker=dict_marker[sub], c=dict_colors[sub], label=sub) 183 | plt.xlabel('UM1') 184 | plt.ylabel('UM2') 185 | plt.legend(loc='best', fontsize=14) 186 | plt.tight_layout() 187 | plt.savefig(os.path.join(dir_output, 'umap_data.pdf')) 188 | plt.close() 189 | 190 | plt.figure() 191 | for i, sub in enumerate(training_labels_unique): 192 | temp_ind = np.where(training_labels == sub)[0] 193 | plt.scatter(pca_transform[temp_ind, 0], pca_transform[temp_ind, 1], s=30, marker=dict_marker[sub], c=dict_colors[sub], label=sub) 194 | plt.xlabel('PC1') 195 | plt.ylabel('PC2') 196 | plt.legend(loc='best', fontsize=14) 197 | plt.tight_layout() 198 | plt.savefig(os.path.join(dir_output, 'pca_data.pdf')) 199 | plt.close() 200 | 201 | 202 | ################################ 203 | # define model 204 | data = tf.placeholder(tf.float32, [None, model_hparams['num_markers']]) 205 | 206 | # define encoder 207 | with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE): 208 | encoder = Encoder(code_size=model_hparams['code_size'], encoder_internal_size=model_hparams['encoder_internal_size'], activation_fct=model_hparams['activation_fct'], init_method='xavier', batch_normalization=model_hparams['batch_normalization'], depth_encoder=model_hparams['encoder_depth'], dropout_rate=model_hparams['dropout_rate']) 209 | make_encoder = tf.make_template('encoder', encoder.build_encoder) 210 | code, code_noisy, code_loc, code_loc_noisy, code_scale = make_encoder(data) 211 | 212 | # define clustering network 213 | with tf.variable_scope('clustering', reuse=tf.AUTO_REUSE): 214 | # clustering network 215 | def network(x, y): 216 | for i in range(model_hparams['depth_cluster_network']): 217 | dense = 
tf.keras.layers.Dense(model_hparams['internal_size_cluster_network'], activation=activation_fcts[model_hparams['activation_fct']], kernel_initializer=initializers['xavier']()) 218 | x = dense(x) 219 | y = dense(y) 220 | dropout = tf.keras.layers.Dropout(model_hparams['dropout_rate']) 221 | x = dropout(x) 222 | y = dropout(y) 223 | dense = tf.keras.layers.Dense(model_hparams['num_experts'], kernel_initializer=initializers['xavier']()) 224 | 225 | output_x = dense(x) 226 | output_y = dense(y) 227 | return output_x, x, output_y, y, None 228 | 229 | usc = USC(batch_size=model_hparams['batch_size'], num_markers=model_hparams['num_markers'], k=model_hparams['num_experts'], loss_coef_kernel=model_hparams['loss_coef_cluster_kernel'], loss_coef_depict=model_hparams['loss_coef_cluster_depict'], trainings_data=code, learning_rate=model_hparams['learning_rate'], code_size=model_hparams['code_size'], cluster_network=network, predict_class_similarity=True) 230 | 231 | centroids = usc.mu_k 232 | assignments = usc.cluster_assignments 233 | gates_assignments = tf.one_hot(assignments, model_hparams['num_experts'], axis=1) 234 | loss_clustering_usc = usc.loss_clustering 235 | 236 | ################################ 237 | # decoder 238 | with tf.variable_scope('decoder'): 239 | decoder = Decoder(num_experts=model_hparams['num_experts'], num_markers=model_hparams['num_markers'], code_size=model_hparams['code_size'], batch_normalization=model_hparams['batch_normalization'], decoder_depth=model_hparams['decoder_depth'], activation_fct=model_hparams['activation_fct'], decoder_internal_size=model_hparams['decoder_internal_size']) 240 | make_decoder = tf.make_template('decoder', decoder.build_decoder) 241 | gen_output, gen_output_gated = make_decoder(code, code, gates_assignments) 242 | 243 | ################################ 244 | # KL loss for standard gaussian with cluster mean 245 | centroids_expanded = tf.expand_dims(centroids, axis=0) 246 | code_expanded = tf.expand_dims(code, axis=1) 
247 | dist_samples = code_expanded - centroids_expanded 248 | dist_samples_squared = tf.square(dist_samples) 249 | gates_assignments_expanded = tf.expand_dims(gates_assignments, axis=-1) 250 | dist_samples_mask = dist_samples_squared * gates_assignments_expanded 251 | 252 | std_unnormalized = tf.reduce_sum(dist_samples_mask, axis=0) 253 | normalizing_factors = tf.clip_by_value(tf.reduce_sum(gates_assignments_expanded, axis=0), 1, 1e20) 254 | std_normalized = std_unnormalized / normalizing_factors 255 | std_normalized = tf.where(tf.equal(std_normalized, 0.), tf.ones_like(std_normalized), std_normalized) 256 | 257 | kl_divergence_code_standard_gaussian = 0.5 * (tf.reduce_sum(std_normalized, axis=1) - model_hparams['code_size'] * tf.ones(model_hparams['num_experts']) - tf.reduce_sum(tf.log(std_normalized), axis=1)) 258 | loss_kl_divergence_code_standard_gaussian = tf.reduce_mean(kl_divergence_code_standard_gaussian) 259 | 260 | ################################ 261 | # reconstruction loss for data 262 | loss_reconst_data = tf.reduce_mean(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=gen_output_gated, labels=data), axis=1)) 263 | 264 | ################################ 265 | # model loss function 266 | loss_model = model_hparams['loss_coef_reconst_data'] * loss_reconst_data + \ 267 | model_hparams['loss_coef_clustering'] * loss_clustering_usc + \ 268 | model_hparams['loss_coef_kl_div_code_standard_gaussian'] * loss_kl_divergence_code_standard_gaussian 269 | 270 | ################################ 271 | # variable scopes 272 | variables_encoder = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder') 273 | variables_decoder = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder') 274 | variables_usc_clustering = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='clustering') 275 | variables_depict = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='depict') 276 | 277 | ################################ 278 | # 
define optimzers 279 | optimize_model = tf.train.AdamOptimizer(model_hparams['learning_rate'], beta1=0.9, beta2=0.99).minimize(loss_model, var_list=[variables_encoder, variables_decoder, variables_depict, variables_usc_clustering]) 280 | 281 | # write down current readme file 282 | write_readme_file(model_hparams, dir_output) 283 | 284 | ################################ 285 | # start training 286 | list_loss_model = list() 287 | list_loss_reconst_data = list() 288 | list_loss_kl_div_standard_gaussian = list() 289 | list_loss_clustering_usc = list() 290 | readme_current_training = dict() 291 | with tf.Session() as sess: 292 | 293 | sess.run(tf.global_variables_initializer()) 294 | saver = tf.train.Saver() 295 | 296 | for iteration in range(model_hparams['num_iterations']): 297 | real_batch, indices_batch = generate_subset(inputs=training_data_scaled, num_cells_per_input=model_hparams['batch_size'], weights=None, batch_size=1, return_indices=True) 298 | real_batch = np.squeeze(real_batch) 299 | labels_class_similarity = training_prior[indices_batch] 300 | 301 | ####################### 302 | # update model 303 | _ = sess.run([optimize_model], feed_dict={data: real_batch, usc.class_similarity : labels_class_similarity}) 304 | 305 | temp_assignments, temp_loss_model, temp_loss_reconst_data, temp_loss_kl_divergence_code_standard_gaussian, temp_loss_clustering_usc = sess.run([assignments, loss_model, loss_reconst_data, loss_kl_divergence_code_standard_gaussian, loss_clustering_usc], feed_dict={data: real_batch, usc.class_similarity : labels_class_similarity}) 306 | 307 | if iteration % 10 == 0: 308 | f_measure = compute_f_measure(training_labels[indices_batch], temp_assignments) 309 | nmi = normalized_mutual_info_score(training_labels[indices_batch], temp_assignments) 310 | 311 | if np.isnan(temp_loss_model): 312 | raise Exception('NaN error in loss! 
-.-') 313 | 314 | # save loss plots 315 | list_loss_model.append(temp_loss_model) 316 | save_loss_plot(dir_output, list_loss_model, tag='_model') 317 | list_loss_reconst_data.append(temp_loss_reconst_data) 318 | save_loss_plot(dir_output, list_loss_reconst_data, tag='_reconst_data') 319 | list_loss_kl_div_standard_gaussian.append(temp_loss_kl_divergence_code_standard_gaussian) 320 | save_loss_plot(dir_output, list_loss_kl_div_standard_gaussian, tag='_kl_divergence_code_standard_gaussian') 321 | list_loss_clustering_usc.append(temp_loss_clustering_usc) 322 | save_loss_plot(dir_output, list_loss_clustering_usc, tag='_clustering_usc') 323 | 324 | if iteration == 0: 325 | best_loss = temp_loss_model 326 | saver.save(sess, os.path.join(dir_output, 'model.ckpt')) 327 | else: 328 | if (temp_loss_model < best_loss): 329 | best_loss = temp_loss_model 330 | saver.save(sess, os.path.join(dir_output, 'model.ckpt')) 331 | 332 | if model_hparams['plotting']: 333 | if iteration % 1000 == 0: 334 | temp_ind_test = np.random.choice(range(len(test_data_scaled)), model_hparams['batch_size_test'], replace=False) 335 | 336 | temp_gates, temp_code, temp_gen_output_gated, temp_assignments, temp_centroids = sess.run([gates_assignments, code, gen_output_gated, assignments, centroids],feed_dict={data: test_data_scaled[temp_ind_test]}) 337 | 338 | if not os.path.exists(os.path.join(dir_output, str(iteration))): 339 | os.makedirs(os.path.join(dir_output, str(iteration))) 340 | 341 | saver.restore(sess, os.path.join(dir_output, 'model.ckpt')) 342 | expert_id = np.argmax(temp_gates, axis=1) 343 | 344 | # plot frequency of experts 345 | temp_dir_output = os.path.join(dir_output, str(iteration), 'experts') 346 | if not os.path.exists(temp_dir_output): 347 | os.makedirs(temp_dir_output) 348 | 349 | expert_frequency = [list(expert_id).count(x) for x in range(model_hparams['num_experts'])] 350 | expert_frequency = expert_frequency / np.sum(expert_frequency) 351 | plt.figure() 352 | 
plt.bar(range(model_hparams['num_experts']), expert_frequency, align='center') 353 | plt.xticks(range(model_hparams['num_experts']),[str(x) for x in range(model_hparams['num_experts'])]) 354 | plt.xlabel('Expert ID') 355 | plt.ylabel('Expert weight') 356 | plt.tight_layout() 357 | plt.savefig(os.path.join(temp_dir_output, 'barplot_expert_weights.pdf')) 358 | plt.close() 359 | 360 | for i, label in enumerate(test_labels_unique): 361 | temp_ind = np.where(test_labels[temp_ind_test] == label)[0] 362 | plt.figure() 363 | temp_test_gates_arg_max = list(expert_id[temp_ind]) 364 | plt.bar(range(model_hparams['num_experts']), [list(temp_test_gates_arg_max).count(x) for x in range(model_hparams['num_experts'])], tick_label=range(model_hparams['num_experts'])) 365 | plt.xlabel('Expert ID') 366 | plt.ylabel('Frequency selected') 367 | plt.tight_layout() 368 | plt.savefig(os.path.join(temp_dir_output, 'bar_gates_' + str(label) + '.pdf')) 369 | plt.close() 370 | 371 | # plot code 372 | temp_dir_output = os.path.join(dir_output, str(iteration), 'code') 373 | if not os.path.exists(temp_dir_output): 374 | os.makedirs(temp_dir_output) 375 | 376 | pca_code = PCA(n_components=2) 377 | pca_code.fit(temp_code) 378 | pca_code_transform = pca_code.transform(np.vstack([temp_code])) 379 | pca_centroids_transform = pca_code.transform(temp_centroids) 380 | 381 | plt.figure() 382 | plt.subplot(121) 383 | for i, sub in enumerate(test_labels_unique): 384 | temp_ind = np.where(test_labels[temp_ind_test] == sub)[0] 385 | plt.scatter(pca_code_transform[temp_ind, 0], pca_code_transform[temp_ind, 1], c=dict_colors[sub], marker=dict_marker[sub], s=15) 386 | plt.xlabel('PC1') 387 | plt.ylabel('PC2') 388 | plt.title('Labels') 389 | plt.tight_layout() 390 | plt.subplot(122) 391 | for i, sub in enumerate(np.unique(expert_id)): 392 | temp_ind = np.where(expert_id == sub)[0] 393 | plt.scatter(pca_code_transform[temp_ind, 0], pca_code_transform[temp_ind, 1], c=dict_colors_experts[sub], 
marker=dict_marker_experts[sub], s=15, label=str(sub)) 394 | plt.scatter(pca_centroids_transform[:, 0], pca_centroids_transform[:, 1], c='red', marker='+', s=15) 395 | plt.xlabel('PC1') 396 | plt.ylabel('PC2') 397 | plt.title('Gating') 398 | plt.tight_layout() 399 | plt.savefig(os.path.join(temp_dir_output, 'code.pdf')) 400 | plt.close() 401 | 402 | readme_current_training['iteration'] = iteration 403 | readme_current_training['f-measure'] = f_measure 404 | readme_current_training['NMI'] = nmi 405 | readme_current_training['best_loss'] = best_loss 406 | write_readme_file(readme_current_training, dir_output, 'current_training_state.txt') 407 | 408 | 409 | dict_losses = OrderedDict() 410 | dict_losses['iteration'] = iteration 411 | dict_losses['best_loss'] = best_loss 412 | dict_losses['loss_model'] = temp_loss_model 413 | dict_losses['reconst_data'] = temp_loss_reconst_data 414 | dict_losses['kl_div_code_standard_gaussian'] = temp_loss_kl_divergence_code_standard_gaussian 415 | dict_losses['usc'] = temp_loss_clustering_usc 416 | dict_losses['f-measure'] = f_measure 417 | print(dict_losses) 418 | 419 | pd.DataFrame(list_loss_model).to_csv(os.path.join(dir_output, 'loss_model.csv'), index=False, header=False, sep=',') 420 | pd.DataFrame(list_loss_reconst_data).to_csv(os.path.join(dir_output, 'loss_reconst_data.csv'), index=False,header=False, sep=',') 421 | pd.DataFrame(list_loss_kl_div_standard_gaussian).to_csv(os.path.join(dir_output, 'loss_kl_code_standard_gaussian.csv'), index=False, header=False, sep=',') 422 | pd.DataFrame(list_loss_clustering_usc).to_csv(os.path.join(dir_output, 'loss_usc.csv'), index=False,header=False, sep=',') 423 | 424 | 425 | ################################ 426 | # plot test data 427 | if not os.path.exists(os.path.join(dir_output, str(iteration))): 428 | os.makedirs(os.path.join(dir_output, str(iteration))) 429 | 430 | # restore model 431 | saver.restore(sess, os.path.join(dir_output, 'model.ckpt')) 432 | 433 | # predict on test data 434 
| list_temp_gates = list() 435 | list_temp_code = list() 436 | list_temp_gen_output_gated = list() 437 | list_temp_assignments = list() 438 | for i in range(len(test_data_scaled)): 439 | temp_gates, temp_code, temp_gen_output_gated, temp_assignments, temp_centroids = sess.run([gates_assignments, code, gen_output_gated, assignments, centroids], feed_dict={data: test_data_scaled[i].reshape([1, model_hparams['num_markers']])}) 440 | temp_gen_output_gated = sigmoid(temp_gen_output_gated) 441 | 442 | list_temp_gates.append(temp_gates) 443 | list_temp_code.append(temp_code) 444 | list_temp_gen_output_gated.append(temp_gen_output_gated) 445 | list_temp_assignments.append(temp_assignments) 446 | 447 | temp_gates = np.vstack(list_temp_gates) 448 | temp_code = np.vstack(list_temp_code) 449 | temp_gen_output_gated = np.vstack(list_temp_gen_output_gated) 450 | temp_assignments = np.vstack(list_temp_assignments).squeeze() 451 | temp_ind_test = np.random.choice(range(len(test_data_scaled)), model_hparams['batch_size_test'], replace=False) 452 | temp_ind_test_gen = np.random.choice(range(len(temp_gen_output_gated)), model_hparams['batch_size_test'], replace=False) 453 | 454 | pd.DataFrame(temp_code).to_csv(os.path.join(dir_output, 'code.csv'), index=False, header=False, sep=',') 455 | pd.DataFrame(temp_assignments).to_csv(os.path.join(dir_output, 'assignments.csv'), index=False, header=False,sep=',') 456 | pd.DataFrame(temp_centroids).to_csv(os.path.join(dir_output, 'centroids.csv'), index=False, header=False, sep=',') 457 | 458 | # reconstruct data from moments predicted from gating 459 | expert_id = np.argmax(temp_gates, axis=1) 460 | 461 | temp_dir_output = os.path.join(dir_output, str(iteration)) 462 | if not os.path.exists(temp_dir_output): 463 | os.makedirs(temp_dir_output) 464 | 465 | # plot frequency of experts 466 | temp_dir_output = os.path.join(dir_output, str(iteration), 'experts') 467 | if not os.path.exists(temp_dir_output): 468 | os.makedirs(temp_dir_output) 469 | 
470 | expert_frequency = [list(expert_id).count(x) for x in range(model_hparams['num_experts'])] 471 | expert_frequency = expert_frequency / np.sum(expert_frequency) 472 | plt.figure() 473 | plt.bar(range(model_hparams['num_experts']), expert_frequency, align='center') 474 | plt.xticks(range(model_hparams['num_experts']), [str(x) for x in range(model_hparams['num_experts'])]) 475 | plt.xlabel('Expert ID') 476 | plt.ylabel('Expert weight') 477 | plt.tight_layout() 478 | plt.savefig(os.path.join(temp_dir_output, 'barplot_expert_weights.pdf')) 479 | plt.close() 480 | 481 | for i, label in enumerate(test_labels_unique): 482 | temp_ind = np.where(test_labels == label)[0] 483 | plt.figure() 484 | temp_test_gates_arg_max = list(expert_id[temp_ind]) 485 | plt.bar(range(model_hparams['num_experts']), [list(temp_test_gates_arg_max).count(x) for x in range(model_hparams['num_experts'])], tick_label=range(model_hparams['num_experts'])) 486 | plt.xlabel('Expert ID') 487 | plt.ylabel('Frequency selected') 488 | plt.tight_layout() 489 | plt.savefig(os.path.join(temp_dir_output, 'bar_gates_' + str(label) + '.pdf')) 490 | plt.close() 491 | 492 | temp_dir_output = os.path.join(dir_output, str(iteration)) 493 | if not os.path.exists(temp_dir_output): 494 | os.makedirs(temp_dir_output) 495 | 496 | # generator output 497 | temp_dir_output = os.path.join(dir_output, str(iteration), 'gen_output') 498 | if not os.path.exists(temp_dir_output): 499 | os.makedirs(temp_dir_output) 500 | 501 | pca_transformed = pca.transform(np.vstack([test_data_scaled[temp_ind_test], temp_gen_output_gated[temp_ind_test_gen]])) 502 | plt.figure() 503 | for i, sub in enumerate(test_labels_unique): 504 | temp_ind = np.where(test_labels[temp_ind_test] == sub)[0] 505 | plt.scatter(pca_transformed[temp_ind, 0], pca_transformed[temp_ind, 1], c=dict_colors[sub].reshape((1, 4)), s=1) 506 | for i, exp_id in enumerate(range(model_hparams['batch_size_test'])): 507 | temp_ind = np.where(expert_id[temp_ind_test_gen] == 
exp_id)[0] 508 | temp_ind += model_hparams['batch_size_test'] 509 | for j in temp_ind: 510 | plt.text(pca_transformed[j, 0], pca_transformed[j, 1], str(exp_id), color='r', fontsize=6) 511 | plt.xlabel('PCA1') 512 | plt.ylabel('PCA2') 513 | plt.tight_layout() 514 | plt.savefig(os.path.join(temp_dir_output, 'pca_all-real_vs_expert_gen_output.pdf')) 515 | plt.close() 516 | 517 | for id in np.unique(expert_id): 518 | temp_ind_id = np.where(expert_id == id)[0] 519 | temp_exp_ids = expert_id[temp_ind_id] 520 | 521 | pca_transformed = pca.transform(np.vstack([test_data_scaled[temp_ind_test], temp_gen_output_gated[temp_ind_id]])) 522 | plt.figure() 523 | for i, sub in enumerate(test_labels_unique): 524 | temp_ind = np.where(test_labels[temp_ind_test] == sub)[0] 525 | plt.scatter(pca_transformed[temp_ind, 0], pca_transformed[temp_ind, 1], c=dict_colors[sub].reshape((1, 4)), s=1) 526 | for j_id, j_sample in enumerate(range(model_hparams['batch_size_test'], pca_transformed.shape[0])): 527 | plt.text(pca_transformed[j_sample, 0], pca_transformed[j_sample, 1], temp_exp_ids[j_id], color='r', fontsize=6) 528 | plt.xlabel('PCA1') 529 | plt.ylabel('PCA2') 530 | plt.tight_layout() 531 | plt.savefig(os.path.join(temp_dir_output, str(id) + '_pca_all-real_vs_expert_gen_outputs.pdf')) 532 | plt.close() 533 | 534 | # plot code 535 | temp_dir_output = os.path.join(dir_output, str(iteration), 'code') 536 | if not os.path.exists(temp_dir_output): 537 | os.makedirs(temp_dir_output) 538 | 539 | pca_code = PCA() 540 | pca_code.fit(temp_code) 541 | pca_code_transform = pca_code.transform(np.vstack([temp_code])) 542 | pca_centroids_transform = pca_code.transform(temp_centroids) 543 | 544 | plt.figure() 545 | plt.subplot(121) 546 | for i, sub in enumerate(test_labels_unique): 547 | temp_ind = np.where(test_labels == sub)[0] 548 | plt.scatter(pca_code_transform[temp_ind, 0], pca_code_transform[temp_ind, 1], c=dict_colors[sub], marker=dict_marker[sub], s=15) 549 | plt.xlabel('PC1') 550 | 
plt.ylabel('PC2') 551 | plt.title('Labels') 552 | plt.tight_layout() 553 | plt.subplot(122) 554 | for i, sub in enumerate(np.unique(expert_id)): 555 | temp_ind = np.where(expert_id == sub)[0] 556 | plt.scatter(pca_code_transform[temp_ind, 0], pca_code_transform[temp_ind, 1], c=dict_colors_experts[sub], marker=dict_marker_experts[sub], s=15, label=str(sub)) 557 | plt.scatter(pca_centroids_transform[:, 0], pca_centroids_transform[:, 1], c='red', marker='+', s=15) 558 | plt.xlabel('PC1') 559 | plt.ylabel('PC2') 560 | plt.title('Gating') 561 | plt.tight_layout() 562 | plt.savefig(os.path.join(temp_dir_output, 'code.pdf')) 563 | plt.close() 564 | 565 | 566 | f_measure = compute_f_measure(test_labels, expert_id) 567 | model_hparams['f-measure'] = f_measure 568 | model_hparams['NMI'] = normalized_mutual_info_score(test_labels, expert_id) 569 | model_hparams['best_loss'] = best_loss 570 | write_readme_file(model_hparams, dir_output) 571 | 572 | 573 | if __name__ == '__main__': 574 | main() 575 | 576 | -------------------------------------------------------------------------------- /src/moe_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Utilities for creating Sparsely-Gated Mixture-of-Experts Layers. 
def add_scope(scope=None, scope_fn=None):
  """Build a decorator that runs a function inside a TF name/variable scope.

  The wrapped function accepts one extra keyword argument, `name`. When it is
  given (and truthy) it overrides the scope name chosen at decoration time.
  `name` is consumed by the wrapper and never forwarded to the wrapped
  function.

  Args:
    scope (str): name of the scope. If None (or empty), the name of the
      wrapped function is used.
    scope_fn (fct): callable taking the scope name and returning a context
      manager; either tf.name_scope or tf.variable_scope.

  Returns:
    fct: the add_scope decorator
  """

  def _wrap(func):

    def _call_in_scope(*call_args, **call_kwargs):
      # Emulates a keyword-only `name` argument in a Python 2 compatible way.
      override = call_kwargs.pop("name", None)
      if override:
        scope_name = override
      elif scope:
        scope_name = scope
      else:
        scope_name = func.__name__
      with scope_fn(scope_name):
        return func(*call_args, **call_kwargs)

    # Preserve the wrapped function's metadata (__name__, __doc__, ...).
    return functools.wraps(func)(_call_in_scope)

  return _wrap
146 | caching_devices: Either `None`, or a list of length n containing device 147 | names. 148 | daisy_chain_variables: a boolean - if true, then copies variables in a 149 | daisy chain between devices. 150 | ps_devices: list, list of devices for experts. 151 | 152 | Returns: 153 | a Parallelism. 154 | """ 155 | assert device_names_or_functions 156 | self._devices = device_names_or_functions 157 | self._n = len(device_names_or_functions) 158 | self._reuse = reuse 159 | self._caching_devices = self._maybe_repeat(caching_devices) 160 | self._daisy_chain_variables = daisy_chain_variables 161 | self._ps_devices = ps_devices or [""] 162 | 163 | def __call__(self, fn, *args, **kwargs): 164 | """A parallel set of function calls (using the specified devices). 165 | 166 | Args: 167 | fn: a function or a list of n functions. 168 | *args: additional args. Each arg should either be not a list, or a list 169 | of length n. 170 | **kwargs: additional keyword args. Each arg should either be not a 171 | list, or a list of length n. 172 | 173 | Returns: 174 | either a single list of length n (if fn does not return a tuple), or a 175 | tuple of lists of length n (if fn returns a tuple). 176 | """ 177 | # Construct lists or args and kwargs for each function. 178 | if args: 179 | my_args = transpose_list_of_lists( 180 | [self._maybe_repeat(arg) for arg in args]) 181 | else: 182 | my_args = [[] for _ in xrange(self.n)] 183 | my_kwargs = [{} for _ in xrange(self.n)] 184 | for k, v in six.iteritems(kwargs): 185 | vals = self._maybe_repeat(v) 186 | for i in xrange(self.n): 187 | my_kwargs[i][k] = vals[i] 188 | 189 | # Construct lists of functions. 190 | fns = self._maybe_repeat(fn) 191 | 192 | # Now make the parallel call. 
193 | outputs = [] 194 | cache = {} 195 | tensor_to_var = {} 196 | for i in xrange(self.n): 197 | 198 | def daisy_chain_getter(getter, name, *args, **kwargs): 199 | """Get a variable and cache in a daisy chain.""" 200 | device_var_key = (self._devices[i], name) 201 | if device_var_key in cache: 202 | # if we have the variable on the correct device, return it. 203 | return cache[device_var_key] 204 | if name in cache: 205 | # if we have it on a different device, copy it from the last device 206 | last_device_v = cache[name] 207 | var = tensor_to_var[last_device_v] 208 | v = tf.identity(last_device_v) 209 | else: 210 | var = getter(name, *args, **kwargs) 211 | v = tf.identity(var._ref()) # pylint: disable=protected-access 212 | 213 | # keep track of the original variable 214 | tensor_to_var[v] = var 215 | _add_variable_proxy_methods(tensor_to_var[v], v) 216 | # update the cache 217 | cache[name] = v 218 | cache[device_var_key] = v 219 | return v 220 | 221 | # Variable scope will not reset caching_device on reused variables, 222 | # so we make a custom getter that uses identity to cache the variable. 
223 | # pylint: disable=cell-var-from-loop 224 | def caching_getter(getter, name, *args, **kwargs): 225 | """Cache variables on device.""" 226 | key = (self._caching_devices[i], name) 227 | if key in cache: 228 | return cache[key] 229 | 230 | v = getter(name, *args, **kwargs) 231 | with tf.device(self._caching_devices[i]): 232 | ret = tf.identity(v._ref()) # pylint: disable=protected-access 233 | _add_variable_proxy_methods(v, ret) 234 | cache[key] = ret 235 | return ret 236 | 237 | if self._daisy_chain_variables: 238 | custom_getter = daisy_chain_getter 239 | elif self._caching_devices[i]: 240 | custom_getter = caching_getter 241 | else: 242 | custom_getter = None 243 | # pylint: enable=cell-var-from-loop 244 | with tf.name_scope("parallel_%d" % i): 245 | with tf.variable_scope( 246 | tf.get_variable_scope() 247 | if self._reuse else "parallel_%d" % i, 248 | reuse=True if i > 0 and self._reuse else None, 249 | caching_device=self._caching_devices[i], 250 | custom_getter=custom_getter): 251 | # TODO(noam, epot, avaswani) 252 | # Allows for passing no device in case you want to default to the 253 | # existing device. This is needed when we put all experts on a single 254 | # device, for example in local_moe. 255 | if self._devices[i] != DEFAULT_DEV_STRING: 256 | with tf.device(self._devices[i]): 257 | outputs.append(fns[i](*my_args[i], **my_kwargs[i])) 258 | else: 259 | outputs.append(fns[i](*my_args[i], **my_kwargs[i])) 260 | if isinstance(outputs[0], tuple): 261 | outputs = list(zip(*outputs)) 262 | outputs = tuple([list(o) for o in outputs]) 263 | return outputs 264 | 265 | @property 266 | def n(self): 267 | return self._n 268 | 269 | @property 270 | def devices(self): 271 | return self._devices 272 | 273 | @property 274 | def ps_devices(self): 275 | return self._ps_devices 276 | 277 | def _maybe_repeat(self, x): 278 | """Utility function for processing arguments that are singletons or lists. 
def _prob_in_top_k(clean_values, noisy_values, noise_stddev, noisy_top_values,
                   k):
  """Helper function to NoisyTopKGating.

  For every entry of `clean_values`, computes the probability that the entry
  would be among the top k of its row under a fresh draw of the noise. This
  gives a way of backpropagating from a loss that balances how often each
  expert is in the top k experts per example.

  Args:
    clean_values: a `Tensor` of shape [batch, n].
    noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
      normally distributed noise with standard deviation noise_stddev.
    noise_stddev: a `Tensor` of shape [batch, n], or None. With None, the hard
      0/1 membership indicator is returned and the result is not
      differentiable.
    noisy_top_values: a `Tensor` of shape [batch, m] with m >= k+1; the
      "values" output of a top-m selection over `noisy_values`.
    k: an integer.

  Returns:
    a `Tensor` of shape [batch, n].
  """
  num_rows = tf.shape(clean_values)[0]
  row_width = tf.shape(noisy_top_values)[1]
  flat_top_values = tf.reshape(noisy_top_values, [-1])

  def _pick_per_row(flat_positions):
    """Gathers one top value per row and returns it as a [batch, 1] column."""
    return tf.expand_dims(tf.gather(flat_top_values, flat_positions), 1)

  # The threshold a value has to exceed to make the top k depends on whether
  # the value is already in the top k: an entry that is in must beat the value
  # at index k of its row, an entry that is out must beat the one at index k-1.
  positions_if_in = tf.range(num_rows) * row_width + k
  bar_if_in = _pick_per_row(positions_if_in)
  currently_in = tf.greater(noisy_values, bar_if_in)
  if noise_stddev is None:
    return tf.to_float(currently_in)

  bar_if_out = _pick_per_row(positions_if_in - 1)
  chance_if_in = _normal_distribution_cdf(clean_values - bar_if_in,
                                          noise_stddev)
  chance_if_out = _normal_distribution_cdf(clean_values - bar_if_out,
                                           noise_stddev)
  return tf.where(currently_in, chance_if_in, chance_if_out)
def _gates_to_load(gates):
  """Count, for each expert, the examples whose gate is strictly positive.

  Args:
    gates: a `Tensor` of shape [batch_size, n]
  Returns:
    a float32 `Tensor` of shape [n]
  """
  # 1.0 where the example is routed to the expert, 0.0 elsewhere.
  is_routed = tf.to_float(gates > 0)
  # Summing over the batch axis gives one count per expert.
  return tf.reduce_sum(is_routed, 0)
def noisy_top_k_gating(x,
                       num_experts,
                       codes_size,
                       train,
                       deep_gating_architecture,
                       gating_convolutions,
                       init,
                       activation_fct,
                       batch_normalization,
                       dropout,
                       dropout_rate,
                       depth_gating,
                       gating_internal_size,
                       k=2,
                       initializer=tf.zeros_initializer(),
                       noisy_gating=True,
                       noise_epsilon=1e-2,
                       name=None):
  """Noisy top-k gating.

  See paper: https://arxiv.org/abs/1701.06538.

  The clean logits are produced by one of three heads:
    * `deep_gating_architecture` false: `x` multiplied by a `w_gate` variable.
    * deep with `gating_convolutions`: a single bias-free `conv1d` applied to
      `x` with a trailing channel axis.
    * deep without convolutions: a single bias-free dense layer.

  Args:
    x: input Tensor with shape [batch_size, input_size]
    num_experts: an integer
    codes_size: an integer - kernel size of the convolutional gating head.
      The head's output is reshaped to [batch_size, num_experts], so this
      presumably has to equal input_size (TODO confirm against callers).
    train: a boolean - we only add noise at training time.
    deep_gating_architecture: a boolean - if false, use the plain `w_gate`
      matrix; otherwise use the conv/dense head selected by
      `gating_convolutions`.
    gating_convolutions: a boolean - selects the conv1d head (true) or the
      dense head (false) when `deep_gating_architecture` is set.
    init: a zero-argument callable returning the kernel initializer of the
      conv/dense head.
    activation_fct: currently unused (see note below).
    batch_normalization: currently unused (see note below).
    dropout: currently unused (see note below).
    dropout_rate: currently unused (see note below).
    depth_gating: currently unused (see note below).
    gating_internal_size: currently unused (see note below).
    k: an integer - number of experts per example
    initializer: an initializer for the `w_gate` and `w_noise` variables
    noisy_gating: a boolean
    noise_epsilon: a float
    name: an optional string

  Note: the six "currently unused" arguments only configured a deeper
  multi-layer gating stack that is disabled; they are kept so that existing
  callers keep working.

  Returns:
    gates: a Tensor with shape [batch_size, num_experts]
    load: a Tensor with shape [num_experts]
    logits: a Tensor with shape [batch_size, num_experts]; the noisy logits
      when `noisy_gating` is set, the clean logits otherwise.
  """
  with tf.variable_scope(name, default_name="noisy_top_k_gating"):
    input_size = x.get_shape().as_list()[-1]

    if not deep_gating_architecture:
      w_gate = tf.get_variable("w_gate", [input_size, num_experts],
                               tf.float32, initializer)
      clean_logits = tf.matmul(x, w_gate)
    elif gating_convolutions:
      batch_size = tf.shape(x)[0]
      # Use a separate name for the rank-3 conv input: `x` itself must stay
      # [batch_size, input_size] for the `tf.matmul(x, w_noise)` further down.
      conv_input = tf.expand_dims(x, -1)
      conv_out = tf.layers.conv1d(inputs=conv_input,
                                  filters=num_experts,
                                  kernel_size=codes_size,
                                  kernel_initializer=init(),
                                  name="gen_conv_ouput",
                                  use_bias=False)
      conv_out = tf.transpose(conv_out, [0, 2, 1])
      clean_logits = tf.reshape(conv_out, shape=[batch_size, num_experts])
    else:
      clean_logits = tf.layers.dense(inputs=x,
                                     units=num_experts,
                                     activation=None,
                                     name='gating_output',
                                     kernel_initializer=init(),
                                     use_bias=False)

    if noisy_gating:
      w_noise = tf.get_variable("w_noise", [input_size, num_experts],
                                tf.float32, initializer)

      raw_noise_stddev = tf.matmul(x, w_noise)
      # Multiplying by `train` switches the noise off outside of training.
      noise_stddev = ((tf.nn.softplus(raw_noise_stddev) + noise_epsilon)
                      * (tf.to_float(train)))
      noisy_logits = clean_logits + (
          tf.random_normal(tf.shape(clean_logits)) * noise_stddev)
      logits = noisy_logits
      if should_generate_summaries():
        tf.summary.histogram("noisy_logits", noisy_logits)
        tf.summary.histogram("noise_stddev", noise_stddev)
    else:
      logits = clean_logits

    # One extra value (k + 1) is kept because _prob_in_top_k needs the first
    # logit that did not make the cut as a threshold.
    top_logits, top_indices = _my_top_k(logits, min(k + 1, num_experts))
    top_k_logits = tf.slice(top_logits, [0, 0], [-1, k])
    top_k_indices = tf.slice(top_indices, [0, 0], [-1, k])
    top_k_gates = tf.nn.softmax(top_k_logits)
    # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the
    # positions corresponding to all but the top k experts per example.
    gates = _rowwise_unsorted_segment_sum(top_k_gates, top_k_indices,
                                          num_experts)

    # The smooth load estimate is only used with the plain (non-deep) gate.
    if not deep_gating_architecture and noisy_gating and k < num_experts:
      load = tf.reduce_sum(
          _prob_in_top_k(clean_logits, noisy_logits, noise_stddev,
                         top_logits, k), 0)
    else:
      load = _gates_to_load(gates)

    if should_generate_summaries():
      tf.summary.histogram("importance", tf.reduce_sum(gates, 0))
      tf.summary.histogram("load", load)
    return gates, load, logits
634 | """ 635 | self.nonpad_ids = None 636 | self.dim_origin = None 637 | 638 | with tf.name_scope("pad_reduce/get_ids"): 639 | pad_mask = tf.reshape(pad_mask, [-1]) # Flatten the batch 640 | # nonpad_ids contains coordinates of zeros rows (as pad_mask is 641 | # float32, checking zero equality is done with |x| < epsilon, with 642 | # epsilon=1e-9 as standard, here pad_mask only contains positive values 643 | # so tf.abs would be redundant) 644 | self.nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9)) 645 | self.dim_origin = tf.shape(pad_mask)[:1] 646 | 647 | def remove(self, x): 648 | """Remove padding from the given tensor. 649 | 650 | Args: 651 | x (tf.Tensor): of shape [dim_origin,...] 652 | 653 | Returns: 654 | a tensor of shape [dim_compressed,...] with dim_compressed <= dim_origin 655 | """ 656 | with tf.name_scope("pad_reduce/remove"): 657 | x_shape = x.get_shape().as_list() 658 | x = tf.gather_nd( 659 | x, 660 | indices=self.nonpad_ids, 661 | ) 662 | if not context.in_eager_mode(): 663 | # This is a hack but for some reason, gather_nd return a tensor of 664 | # undefined shape, so the shape is set up manually 665 | x.set_shape([None] + x_shape[1:]) 666 | return x 667 | 668 | def restore(self, x): 669 | """Add padding back to the given tensor. 670 | 671 | Args: 672 | x (tf.Tensor): of shape [dim_compressed,...] 673 | 674 | Returns: 675 | a tensor of shape [dim_origin,...] with dim_compressed >= dim_origin. The 676 | dim is restored from the original reference tensor 677 | """ 678 | with tf.name_scope("pad_reduce/restore"): 679 | x = tf.scatter_nd( 680 | indices=self.nonpad_ids, 681 | updates=x, 682 | shape=tf.concat( 683 | [self.dim_origin, tf.shape(x)[1:]], axis=0), 684 | ) 685 | return x 686 | 687 | 688 | @add_name_scope("map_ids") 689 | def map_ids(x, indices, map_fn): 690 | """Apply a function to each coordinate ids of a multidimentional tensor. 691 | 692 | This allows to process each sequence of a batch independently. 
This is 693 | similar to tf.map_fn but with tensor where the batch dim has been flatten. 694 | 695 | Warning: The indices ids have to be contigous and orderd in memory as the 696 | output vector for each of the ids are simply concatenated after being 697 | processed. 698 | Ex: if your indices are [0,2,2,1,2,0], the output will contains the processed 699 | rows in the following order: [0,0,1,2,2,2] 700 | 701 | Args: 702 | x (Tensor): The tensor to be dispatched of shape [length,...] 703 | indices (Tensor): A int32 tensor of size [length, 1] containing the batch 704 | coordinate of x 705 | map_fn (fct): Function called for every ids of the original tensor. Take 706 | as input a tensor of same rank than x and from shape [length_id,...] with 707 | length_id <= length. Isn't called if length_id == 0 708 | 709 | Returns: 710 | a tensor of same shape as x, where each elements has been processed 711 | """ 712 | indices = tf.reshape(indices, [-1]) 713 | 714 | t_i = tf.constant(0) 715 | # batch_coordinates start at 0 716 | t_batch_size = tf.reduce_max(indices) + 1 717 | 718 | # ta_stack_out will store the intermediate results for each individual id 719 | # As alternative to tf.TensorArray, scatter_update could potentially be used 720 | # but that would require an additional mutable tensor. 721 | ta_stack_out = tf.TensorArray( 722 | x.dtype, 723 | size=t_batch_size, 724 | ) 725 | 726 | # Then we iterate over each sequence individually and compute the 727 | # transformation for each id 728 | while_condition = lambda t_i, *args: tf.less(t_i, t_batch_size) 729 | 730 | def body(t_i, ta_stack_out): 731 | """Loop body.""" 732 | # Gather the ids 733 | current_ids = tf.to_int32(tf.where(tf.equal(indices, t_i))) 734 | t_row = tf.gather_nd(x, indices=current_ids) 735 | 736 | # TODO(epot): Should not call map_fn if t_row size is 0 737 | 738 | # Apply transformation to each id 739 | # Restore batch_dim=1 as most function expect [batch_dim, length, ...] 
class SparseDispatcher(object):
  """Helper for implementing a mixture of experts.

  The purpose of this class is to create input minibatches for the
  experts and to combine the results of the experts to form a unified
  output tensor.

  There are two functions:
    dispatch - take an input Tensor and create input Tensors for each expert.
    combine - take output Tensors from each expert and form a combined output
      Tensor.  Outputs from different experts for the same batch element are
      summed together, weighted by the provided "gates".

  The class is initialized with a "gates" Tensor, which specifies which
  batch elements go to which experts, and the weights to use when combining
  the outputs.  Batch element b is sent to expert e iff gates[b, e] > 0.

  The inputs and outputs are all two-dimensional [batch, depth].
  Caller is responsible for collapsing additional dimensions prior to
  calling this class and reshaping the output to the original shape.
  See reshape_like().

  Example use:

  gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
  inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
  experts: a list of length `num_experts` containing sub-networks.

    dispatcher = SparseDispatcher(num_experts, gates)
    expert_inputs = dispatcher.dispatch(inputs)
    expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
    outputs = dispatcher.combine(expert_outputs)

  The preceding code sets the output for a particular example b to:
  output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))

  This class takes advantage of sparsity in the gate matrix by including in the
  `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
  """

  def __init__(self, num_experts, gates):
    """Create a SparseDispatcher.

    Precomputes, from the positive entries of `gates`, the (expert, batch)
    index pairs, the number of examples per expert and the matching gate
    values.

    Args:
      num_experts: an integer.
      gates: a `Tensor` of shape `[batch_size, num_experts]`.

    Returns:
      a SparseDispatcher
    """
    self._gates = gates
    self._num_experts = num_experts

    # Each row of `where` is an [expert, batch] pair with a positive gate.
    # The gates are transposed first so that, with tf.where listing
    # coordinates in row-major order, the pairs come out grouped by expert;
    # the tf.split calls below rely on that grouping.
    where = tf.cast(tf.where(tf.transpose(gates) > 0), dtype=tf.int32)
    self._expert_index, self._batch_index = tf.unstack(
        where, num=2, axis=1)
    # Number of examples routed to each expert, shape [num_experts].
    self._part_sizes_tensor = tf.reduce_sum(tf.to_int32(gates > 0), [0])
    # gates[b, e] for every pair, looked up in the flattened gates through
    # the index b * num_experts + e.
    self._nonzero_gates = tf.gather(
        tf.reshape(self._gates, [-1]),
        self._batch_index * num_experts + self._expert_index)

  @add_name_scope()
  def dispatch(self, inp):
    """Create one input Tensor for each expert.

    The `Tensor` for a expert `i` contains the slices of `inp` corresponding
    to the batch elements `b` where `gates[b, i] > 0`.

    Args:
      inp: a `Tensor` of shape `[batch_size, ...]`
    Returns:
      a list of `num_experts` `Tensor`s with shapes
        `[expert_batch_size_i, ...]`.
    """
    # One row per (expert, batch) pair, then cut into per-expert pieces.
    inp = tf.gather(inp, self._batch_index)
    return tf.split(inp, self._part_sizes_tensor, 0, num=self._num_experts)

  @add_name_scope()
  def combine(self, expert_out, multiply_by_gates=True):
    """Sum together the expert output, weighted by the gates.

    The slice corresponding to a particular batch element `b` is computed
    as the sum over all experts `i` of the expert output, weighted by the
    corresponding gate values.  If `multiply_by_gates` is set to False, the
    gate values are ignored.

    Args:
      expert_out: a list of `num_experts` `Tensor`s, each with shape
        `[expert_batch_size_i, ...]`.
      multiply_by_gates: a boolean

    Returns:
      a `Tensor` with shape `[batch_size, ...]`.
    """
    # see comments on convert_gradient_to_tensor
    stitched = convert_gradient_to_tensor(tf.concat(expert_out, 0))
    if multiply_by_gates:
      # The gate vector is expanded to a column so it scales whole rows.
      stitched *= tf.expand_dims(self._nonzero_gates, 1)
    # Rows that belong to the same batch element are added together; the
    # leading output dimension is the batch size of `gates`.
    combined = tf.unsorted_segment_sum(stitched, self._batch_index,
                                       tf.shape(self._gates)[0])
    return combined

  def expert_to_gates(self):
    """Gate values corresponding to the examples in the per-expert `Tensor`s.

    Returns:
      a list of `num_experts` one-dimensional `Tensor`s with shapes
          `[expert_batch_size_i]`, holding the entries of `gates`
          (`tf.float32` if `gates` is, as in the class example).
    """
    return tf.split(
        self._nonzero_gates,
        self._part_sizes_tensor,
        0,
        num=self._num_experts)

  def expert_to_batch_indices(self):
    """Batch indices corresponding to the examples in the per-expert `Tensor`s.

    Returns:
      a list of `num_experts` one-dimensional `Tensor`s with type `tf.int32`
          (the indices are cast to int32 in the constructor) and shapes
          `[expert_batch_size_i]`
    """
    return tf.split(
        self._batch_index,
        self._part_sizes_tensor,
        0,
        num=self._num_experts)

  @property
  def part_sizes(self):
    # int32 Tensor of shape [num_experts]: number of examples per expert.
    return self._part_sizes_tensor
The per-expert 896 | `Tensor`s contain a combination of examples from the different datashards. 897 | 898 | Each datashard is associated with a particular device and each expert is 899 | associated with a particular device. All per-datashard and per-expert 900 | `Tensor`s are created on those devices. There is no single-device bottleneck. 901 | """ 902 | 903 | def __init__(self, data_parallelism, expert_parallelism, gates): 904 | """Create a DistributedSparseDispatcher. 905 | 906 | Args: 907 | data_parallelism: a Parallelism object. 908 | expert_parallelism: a Parallelism object. 909 | gates: a list of datashard_parallelism.n `Tensor`s of shapes 910 | `[batch_size[d], num_experts]`. 911 | 912 | Returns: 913 | a DistributedSparseDispatcher 914 | """ 915 | self._gates = gates 916 | self._dp = data_parallelism 917 | self._ep = expert_parallelism 918 | assert len(gates) == self._dp.n 919 | self._dispatchers = self._dp(SparseDispatcher, self._ep.n, gates) 920 | 921 | def dispatch(self, inp): 922 | """Create one input Tensor for each expert. 923 | 924 | Args: 925 | inp: a list of length num_datashards `Tensor`s with shapes 926 | `[batch_size[d], ]`. 927 | Returns: 928 | a list of `num_experts` `Tensor`s with shapes 929 | `[num_examples[i], ]`. 930 | """ 931 | dispatched = self._dp(lambda a, b: a.dispatch(b), self._dispatchers, 932 | inp) 933 | ret = self._ep(tf.concat, transpose_list_of_lists(dispatched), 0) 934 | if ret[0].dtype == tf.float32: 935 | # see comments on convert_gradient_to_tensor 936 | ret = self._ep(convert_gradient_to_tensor, ret) 937 | return ret 938 | 939 | def combine(self, expert_out, multiply_by_gates=True): 940 | """Sum together the expert output, multiplied by the corresponding gates. 941 | 942 | Args: 943 | expert_out: a list of `num_experts` `Tensor`s, each with shape 944 | `[expert_batch_size_i, ]`. 945 | multiply_by_gates: a boolean. 946 | 947 | Returns: 948 | a list of num_datashards `Tensor`s with shapes 949 | `[batch_size[d], ]`. 
950 | """ 951 | expert_part_sizes = tf.unstack( 952 | tf.stack([d.part_sizes for d in self._dispatchers]), 953 | num=self._ep.n, 954 | axis=1) 955 | # list of lists of shape [num_experts][num_datashards] 956 | expert_output_parts = self._ep(tf.split, expert_out, expert_part_sizes) 957 | expert_output_parts_t = transpose_list_of_lists(expert_output_parts) 958 | 959 | def my_combine(dispatcher, parts): 960 | return dispatcher.combine( 961 | convert_gradient_to_tensor(tf.concat(parts, 0)), 962 | multiply_by_gates=multiply_by_gates) 963 | 964 | return self._dp(my_combine, self._dispatchers, expert_output_parts_t) 965 | 966 | def expert_to_gates(self): 967 | """Gate values corresponding to the examples in the per-expert `Tensor`s. 968 | 969 | Returns: 970 | a list of `num_experts` one-dimensional `Tensor`s of type `tf.float32`. 971 | """ 972 | return self._ep( 973 | tf.concat, 974 | transpose_list_of_lists( 975 | self._dp(lambda d: d.expert_to_gates(), self._dispatchers)), 0) 976 | 977 | 978 | def transpose_list_of_lists(lol): 979 | """Transpose a list of equally-sized python lists. 980 | 981 | Args: 982 | lol: a list of lists 983 | Returns: 984 | a list of lists 985 | """ 986 | assert lol, "cannot pass the empty list" 987 | return [list(x) for x in zip(*lol)] 988 | 989 | 990 | def ffn_expert_fn(output_size, 991 | input_size, 992 | dropout, 993 | dropout_rate, 994 | batch_normalization, 995 | init, 996 | depth, 997 | activation_fct, 998 | convolution, 999 | decoder_internal_size): 1000 | """Returns a function that creates a feed-forward network. 1001 | 1002 | Use this function to create the expert_fn argument to distributed_moe. 1003 | 1004 | Args: 1005 | input_size: an integer 1006 | hidden_sizes: a list of integers 1007 | output_size: an integer 1008 | hidden_activation: a unary function. 
def ffn_expert_fn(output_size,
                  input_size,
                  dropout,
                  dropout_rate,
                  batch_normalization,
                  init,
                  depth,
                  activation_fct,
                  convolution,
                  decoder_internal_size):
    """Return a function that builds a single-layer expert network.

    The returned function maps its input to `output_size` features with
    one layer and no activation: a `conv1d` whose kernel spans the whole
    input when `convolution` is true, a dense layer otherwise.

    NOTE: a deeper multi-layer variant (using dropout, batch normalisation,
    `depth` hidden layers of width `decoder_internal_size`) used to live
    here as commented-out code; recover it from version control if needed.

    Args:
      output_size: an integer, number of output features.
      input_size: an integer; used as the `conv1d` kernel size, so it is
        presumably the last dimension of the expert input.
      dropout: not read by the current implementation.
      dropout_rate: not read by the current implementation.
      batch_normalization: not read by the current implementation.
      init: a zero-argument callable returning a kernel initializer.
      depth: not read by the current implementation.
      activation_fct: not read by the current implementation.
      convolution: a boolean selecting the `conv1d` branch.
      decoder_internal_size: not read by the current implementation.

    Returns:
      a unary function taking the expert input tensor and returning the
      expert output tensor.
    """

    def my_fn(x):
        if convolution:
            batch_size = tf.shape(x)[0]
            # Add a channel axis so conv1d can consume the input.
            x = tf.expand_dims(x, -1)

            gen_output = tf.layers.conv1d(inputs=x,
                                          filters=output_size,
                                          kernel_size=input_size,
                                          kernel_initializer=init(),
                                          name="gen_conv_output")
            gen_output = tf.transpose(gen_output, [0, 2, 1])
            gen_output = tf.reshape(gen_output,
                                    shape=[batch_size, output_size])
        else:
            gen_output = tf.layers.dense(inputs=x,
                                         units=output_size,
                                         activation=None,
                                         name='gen_output',
                                         kernel_initializer=init())

        return gen_output

    return my_fn


def ffn_expert_fn_decoder(encoder_internal_size, output_size):
    """Return a function building a small stochastic expert.

    The returned function applies a sigmoid dense layer of width
    `encoder_internal_size`, then two dense heads of width `output_size`
    (`loc` linear, `scale` through softplus).

    Returns:
      a unary function returning the tuple `(sample(loc, scale), loc, scale)`.
    """

    def my_fn(x):
        x = tf.layers.dense(x, encoder_internal_size, tf.nn.sigmoid)
        loc = tf.layers.dense(x, output_size)
        scale = tf.layers.dense(x, output_size, tf.nn.softplus)
        return sample(loc, scale), loc, scale

    return my_fn


tfd = tf.contrib.distributions


def sample(loc, scale):
    """Draw `loc + exp(0.5 * scale) * epsilon` with standard-normal epsilon.

    Note that `scale` is exponentiated, i.e. it is treated like a
    log-variance rather than a standard deviation.
    """
    # Sample epsilon
    epsilon = tf.random_normal(tf.shape(scale), name='epsilon')

    # Sample latent variable
    std_encoder = tf.exp(0.5 * scale)
    z = loc + tf.multiply(std_encoder, epsilon)
    return z


def reshape_like(a, b):
    """Reshapes a to match the shape of b in all but the last dimension."""
    ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0))
    if not context.in_eager_mode():
        # Restore the static shape information lost by the dynamic reshape.
        ret.set_shape(b.get_shape().as_list()[:-1] +
                      a.get_shape().as_list()[-1:])
    return ret


def flatten_all_but_last(a):
    """Flatten all dimensions of a except the last."""
    ret = tf.reshape(a, [-1, tf.shape(a)[-1]])
    if not context.in_eager_mode():
        # Keep the statically known size of the last dimension.
        ret.set_shape([None] + a.get_shape().as_list()[-1:])
    return ret
def distributed_moe(data_parallelism,
                    expert_devices,
                    xs,
                    train,
                    input_size,
                    expert_fn,
                    num_experts,
                    k=2,
                    loss_coef=1e-2,
                    name=None):
    """Call a distributed mixture of experts.

    Args:
      data_parallelism: a expert_utils.Parallelism object.
      expert_devices: a list of strings.  We round-robin the experts across
        these devices.
      xs: a list of input tensors, each with shape [... , input_size]
      train: a boolean scalar.
      input_size: an integer (input size for this layer)
      expert_fn: a unary function for each expert to run
        It should take a Tensor with shape [batch_size, input_size]
        and return a Tensor with shape [batch_size, output_size].
        e.g. ffn_expert_fn(...)
      num_experts: an integer - number of experts
      k: an integer - how many experts to use for each batch element
      loss_coef: a scalar - multiplier on load-balancing losses
      name: a string

    Returns:
      ys: a list of tensors.  Each Tensor has the same shape as the
        corresponding Tensor in xs, except for the last dimension, which is
        output_size.
      extra_training_loss: a scalar.  This should be added into the overall
        training loss of the model.  The backpropagation of this loss
        encourages all experts to be approximately equally used across a
        batch.
    """
    dp = data_parallelism
    # Create a parallelism object for running the experts.  reuse=None is
    # passed so that the experts do not all share the same variables.
    # `range` (not the Python-2-only `xrange`) keeps this working on Python 3.
    ep = Parallelism(
        [expert_devices[i % len(expert_devices)] for i in range(num_experts)],
        reuse=None)
    # Experts expect 2d input tensors, so flatten the batch dimension and all
    # spatial dimensions together.
    xs_flat = dp(tf.reshape, xs, [[-1, input_size]] * dp.n)
    with tf.variable_scope(name, default_name="moe"):
        # The gates indicate which batch elements go to which tensors.
        # load is a measure of approximately how many examples go to each
        # expert.
        # NOTE(review): local_moe in this file calls noisy_top_k_gating with
        # a longer positional argument list -- confirm this call still
        # matches the current signature.
        gates, load, logits = dp(
            noisy_top_k_gating,
            xs_flat,
            num_experts,
            train,
            k,
            initializer=tf.zeros_initializer(),
            noisy_gating=True,
            noise_epsilon=1e-2)
        # This magic object helps us shuffle data between datashards and
        # experts.
        dispatcher = DistributedSparseDispatcher(dp, ep, gates)
        expert_in = dispatcher.dispatch(xs_flat)
        expert_out = ep(expert_fn, expert_in)
        ys_flat = dispatcher.combine(expert_out)
        ys = dp(reshape_like, ys_flat, xs)
        # compute some load-balancing losses.
        load = tf.add_n(load)
        importance = tf.add_n(dp(tf.reduce_sum, gates, 0))
        loss = loss_coef * (cv_squared(importance) + cv_squared(load))
        return ys, loss


def tsne_repel(code_size, gate_logits, batch_size, p):
    """t-SNE style loss between `p` and a Student-t kernel on `gate_logits`.

    Pairwise squared distances between the rows of `gate_logits` are turned
    into similarities `(1 + d / nu) ** (-(nu + 1) / 2)` with
    `nu = code_size - 1`.  `p` is smoothed by `0.1 / batch_size` and
    row-normalised before use.

    Args:
      code_size: a number; `code_size - 1` is the kernel's degrees of freedom.
      gate_logits: a 2-D float tensor, one row per example.
      batch_size: a float scalar used for smoothing and for the final scaling.
      p: a 2-D tensor of target similarities (presumably
        [batch, batch] -- confirm against the caller).

    Returns:
      a scalar tensor `(repellant + attraction) / batch_size`.
    """
    nu = tf.constant(code_size - 1, dtype=tf.float32)

    # Squared Euclidean distances: |a|^2 - 2ab + |b|^2.
    sum_y = tf.reduce_sum(tf.square(gate_logits), axis=1)
    num = -2.0 * tf.matmul(gate_logits,
                           gate_logits,
                           transpose_b=True) + tf.reshape(sum_y, [-1, 1]) + sum_y
    num = num / nu

    p = p + 0.1 / batch_size
    p = p / tf.expand_dims(tf.reduce_sum(p, axis=1), 1)

    num = tf.pow(1.0 + num, -(nu + 1.0) / 2.0)
    attraction = tf.multiply(p, tf.log(num))
    attraction = -tf.reduce_sum(attraction)

    # The diagonal of `num` is 1 (zero self-distance); subtract it out.
    den = tf.reduce_sum(num, axis=1) - 1
    repellant = tf.reduce_sum(tf.log(den))

    return (repellant + attraction) / batch_size


import numpy as np
import sys
MAX_VAL = np.log(sys.float_info.max) / 2.0


def compute_entropy(dist=np.array([]), beta=1.0):
    """Entropy and normalised Gaussian weights for one row of distances.

    The exponent is shifted so that its largest term equals `MAX_VAL`,
    and the shift is subtracted again from the log-sum.

    Args:
      dist: 1-D array of (squared) distances to the other points.
      beta: precision of the Gaussian kernel.

    Returns:
      (h, p): the entropy (natural-log based) and the weights, which sum
      to one.
    """
    scaled = -dist * beta
    shift = MAX_VAL - max(scaled)
    weights = np.exp(scaled + shift)
    total = np.sum(weights)

    weighted_dist = np.sum(np.multiply(dist, weights))
    h = np.log(total) - shift + beta * weighted_dist / total

    return h, weights / total


def compute_transition_probability(x, perplexity=30.0, tol=1e-4, max_iter=50, verbose=False):
    """Row-wise t-SNE transition probabilities for the rows of `x`.

    For every point a precision `beta` is searched (doubling / halving
    until bracketed, then bisection) so that the entropy of its
    conditional distribution matches the target within `tol`, using at
    most `max_iter` refinements.

    x should be properly scaled so the distances are not either too small
    or too large.

    Args:
      x: 2-D array, one point per row.
      perplexity: target perplexity.
      tol: absolute tolerance on the entropy difference.
      max_iter: maximum number of search steps per point.
      verbose: print progress and sigma statistics.

    Returns:
      an [n, n] array with zero diagonal whose rows sum to one.
    """
    if verbose:
        print('tSNE: searching for sigma ...')

    n, _ = x.shape
    sq_norms = np.sum(np.square(x), 1)

    dist = np.add(np.add(-2 * np.dot(x, x.T), sq_norms).T, sq_norms)
    p = np.zeros((n, n))

    # Parameterized by precision
    beta = np.ones((n, 1))
    # NOTE(review): the target is expressed in bits (log2) while
    # compute_entropy uses natural logs -- confirm this is intended.
    entropy = np.log(perplexity) / np.log(2)

    all_idx = range(n)
    for i in all_idx:
        # Every index except i (d_ii is removed).
        others = list(all_idx[:i]) + list(all_idx[i + 1:n])

        lower = -np.inf
        upper = np.inf

        dist_i = dist[i, others]
        h_i, p_i = compute_entropy(dist_i, beta[i])
        h_diff = h_i - entropy

        steps = 0
        while steps < max_iter and np.abs(h_diff) > tol:
            if h_diff > 0:
                # Entropy too high: increase the precision.
                lower = beta[i].copy()
                if np.isfinite(upper):
                    beta[i] = (beta[i] + upper) / 2.0
                else:
                    beta[i] *= 2.0
            else:
                # Entropy too low: decrease the precision.
                upper = beta[i].copy()
                if np.isfinite(lower):
                    beta[i] = (beta[i] + lower) / 2.0
                else:
                    beta[i] /= 2.0

            h_i, p_i = compute_entropy(dist_i, beta[i])
            h_diff = h_i - entropy
            steps += 1

        p[i, others] = p_i

    if verbose:
        sigma_sq = 1 / beta
        print('Min of sigma square: {}'.format(np.min(sigma_sq)))
        print('Max of sigma square: {}'.format(np.max(sigma_sq)))
        print('Mean of sigma square: {}'.format(np.mean(sigma_sq)))

    return p
def local_moe(x,
              code,
              loc_code,
              scale_code,
              train,
              expert_fn,
              expert_fn_loc,
              expert_fn_scale,
              num_experts,
              num_markers,
              code_size,
              deep_gating_architecture,
              gating_convolutions,
              init,
              activation_fct,
              batch_normalization,
              dropout,
              dropout_rate,
              depth_gating,
              gating_internal_size,
              p_transition_prob,
              k=2,
              pass_x=True,
              pass_gates=False,
              additional_dispatch_params=None,
              noisy_gating=True,
              noise_eps=1e-2,
              name=None,
              loss_coef=1e-2,
              loss_coef_reconst_code=1,
              regularize_std_of_experts=False,
              regularize_importance=False,
              regularize_expert_distance=False,
              regularize_experts_dist_to_mean=False,
              regularize_laod_balancing=False,
              reconst_code_from_gates=False,
              regularize_entropy_on_gating_logits=False):
    """Call a local mixture of experts and build its regularisation loss.

    Args:
      x: a tensor with shape [... , input_size]
      code: tensor scored under the per-expert Gaussians when
        `reconst_code_from_gates` is set.
      loc_code, scale_code, p_transition_prob, loss_coef_reconst_code:
        not read by the active code.
      train: a boolean scalar, forwarded to the gating network.
      expert_fn: a function run once per expert on the dispatched inputs.
      expert_fn_loc, expert_fn_scale: functions run once per expert on the
        softmaxed gating logits (only if `reconst_code_from_gates`).
      num_experts: an integer - number of experts
      num_markers: an integer - output width of each expert (used to
        reshape the per-expert means).
      code_size, deep_gating_architecture, gating_convolutions, init,
        activation_fct, batch_normalization, dropout, dropout_rate,
        depth_gating, gating_internal_size: forwarded to
        noisy_top_k_gating.
      k: an integer - how many experts to use for each batch element
      pass_x: a boolean. If true, x will also be dispatched to the experts.
      pass_gates: a boolean. If true, gates will be passed to experts.
      additional_dispatch_params: dict of extra tensors that need to be
        sent to each expert.
      noisy_gating, noise_eps: forwarded to noisy_top_k_gating.
      name: a string
      loss_coef: a scalar; scales the load/importance terms and, at the
        end, the whole accumulated loss.
      regularize_*, reconst_code_from_gates: booleans switching the
        individual loss terms on.

    Returns:
      A 13-tuple `(y, loss, gates, load, expert_outputs, logits,
      distance_matrix, expert_outputs_moe_loss_mean,
      expert_outputs_moe_loss_std, code_reconst, loss_code_reconst,
      expert_outputs_loc, expert_outputs_scale)`.  `y` has the same shape
      as `x` except for the last dimension.  `code_reconst` is always
      None; `expert_outputs_loc` / `expert_outputs_scale` are None unless
      `reconst_code_from_gates` is true.
    """

    with tf.variable_scope(name, default_name="local_moe"):
        x_flat = flatten_all_but_last(x)

        # The gates indicate which batch elements go to which tensors.
        # load is a measure of approximately how many examples go to each
        # expert
        gates, load, logits = noisy_top_k_gating(
            x_flat,
            num_experts,
            code_size,
            train,
            deep_gating_architecture,
            gating_convolutions,
            init,
            activation_fct,
            batch_normalization,
            dropout,
            dropout_rate,
            depth_gating,
            gating_internal_size,
            k,
            initializer=tf.zeros_initializer(),
            noisy_gating=noisy_gating,
            noise_epsilon=noise_eps)

        gate_logits_softmax = tf.nn.softmax(logits)

        # These objects shuffle data between the batch and the experts.
        # Dispatchers built from all-ones gates send every example to every
        # expert.  The *_reconst dispatchers are currently not consumed;
        # they are kept so graph construction is unchanged.
        dispatcher = SparseDispatcher(num_experts, gates)
        dispatcher_moe_loss = SparseDispatcher(num_experts, tf.ones(tf.shape(gates)))
        dispatcher_loc = SparseDispatcher(num_experts, tf.ones(tf.shape(gates)))
        dispatcher_loc_reconst = SparseDispatcher(num_experts, gates)
        dispatcher_scale = SparseDispatcher(num_experts, tf.ones(tf.shape(gates)))
        dispatcher_scale_reconst = SparseDispatcher(num_experts, gates)

        # Set up expert_fn arguments
        expert_kwargs = {}
        expert_kwargs_moe_loss = {}
        expert_kwargs_loc = {}
        expert_kwargs_loc_reconst = {}
        expert_kwargs_scale = {}
        expert_kwargs_scale_reconst = {}
        if pass_x:
            expert_kwargs["x"] = dispatcher.dispatch(x_flat)
            expert_kwargs_moe_loss["x"] = dispatcher_moe_loss.dispatch(x_flat)
            expert_kwargs_loc["x"] = dispatcher_loc.dispatch(gate_logits_softmax)
            expert_kwargs_loc_reconst["x"] = dispatcher_loc_reconst.dispatch(gate_logits_softmax)
            expert_kwargs_scale["x"] = dispatcher_scale.dispatch(gate_logits_softmax)
            expert_kwargs_scale_reconst["x"] = dispatcher_scale_reconst.dispatch(gate_logits_softmax)
        if pass_gates:
            expert_kwargs["gates"] = dispatcher.expert_to_gates()
            expert_kwargs_moe_loss["gates"] = dispatcher_moe_loss.expert_to_gates()
            expert_kwargs_loc["gates"] = dispatcher_loc.expert_to_gates()
            expert_kwargs_loc_reconst["gates"] = dispatcher_loc_reconst.expert_to_gates()
            expert_kwargs_scale["gates"] = dispatcher_scale.expert_to_gates()
            expert_kwargs_scale_reconst["gates"] = dispatcher_scale_reconst.expert_to_gates()
        # `param_name` (not `k`) so the top-k argument is not shadowed.
        for param_name, v in six.iteritems(additional_dispatch_params or {}):
            v = flatten_all_but_last(v)
            expert_kwargs[param_name] = dispatcher.dispatch(v)
            expert_kwargs_moe_loss[param_name] = dispatcher_moe_loss.dispatch(v)
            expert_kwargs_loc[param_name] = dispatcher_loc.dispatch(v)
            expert_kwargs_loc_reconst[param_name] = dispatcher_loc_reconst.dispatch(v)
            expert_kwargs_scale[param_name] = dispatcher_scale.dispatch(v)
            expert_kwargs_scale_reconst[param_name] = dispatcher_scale_reconst.dispatch(v)

        ep = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=None)
        expert_outputs = ep(expert_fn, **expert_kwargs)

        # Run every expert on the full batch to compute statistics of the
        # expert outputs (mean, std, pairwise distances of the means).
        with tf.variable_scope('moe_loss'):
            ep_moe_loss = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=None)
            expert_outputs_moe_loss = ep_moe_loss(expert_fn, **expert_kwargs_moe_loss)

        expert_outputs_moe_loss = tf.map_fn(lambda x: tf.identity(x), elems=expert_outputs_moe_loss, dtype=tf.float32)
        expert_outputs_moe_loss = tf.transpose(expert_outputs_moe_loss, [1, 0, 2])
        expert_outputs_moe_loss_mean = tf.map_fn(lambda x: tf.reduce_mean(x, axis=0), expert_outputs_moe_loss, dtype=tf.float32)

        distance_matrix = pairwise_dist(expert_outputs_moe_loss_mean)

        # regularize std for expert outputs
        expert_outputs_moe_loss_std = tf.map_fn(lambda x: tf.keras.backend.std(x, axis=0), expert_outputs_moe_loss, dtype=tf.float32)

        y_flat = dispatcher.combine(expert_outputs)
        y = reshape_like(y_flat, x)

        loss = tf.zeros((1,), dtype=tf.float32)
        loss_code_reconst = tf.zeros((1,), dtype=tf.float32)
        batch_size = tf.shape(x)[0]
        code_reconst = None
        # Defined up front so the return below is valid on every path
        # (previously unbound unless reconst_code_from_gates was set).
        expert_outputs_loc = None
        expert_outputs_scale = None
        if regularize_laod_balancing:
            loss += loss_coef * cv_squared(load)
        if regularize_std_of_experts:
            loss += tf.reduce_sum(expert_outputs_moe_loss_std)
        if regularize_importance:
            importance = tf.reduce_sum(gates, 0)
            loss += loss_coef * cv_squared(importance)
        if regularize_expert_distance:
            loss += tf.reduce_sum(distance_matrix) / 2
        if regularize_experts_dist_to_mean:
            expert_outputs_moe_loss_mean_flatten = tf.reshape(expert_outputs_moe_loss_mean, [num_experts * num_markers, ])

            expert_outputs_moe_loss_flatten = tf.reshape(tf.transpose(expert_outputs_moe_loss, [1, 0, 2]), [num_experts * num_markers, batch_size])

            expert_outputs_dist_to_mean = tf.map_fn(lambda x: dist_to(x[0], x[1]), (expert_outputs_moe_loss_flatten, expert_outputs_moe_loss_mean_flatten), dtype=tf.float32)

            loss += tf.reduce_sum(expert_outputs_dist_to_mean)
        if reconst_code_from_gates:
            # Per-expert Gaussian parameters computed from the gate softmax.
            with tf.variable_scope('expert_loc'):
                ep_loc = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=tf.AUTO_REUSE)
                expert_outputs_loc = ep_loc(expert_fn_loc, **expert_kwargs_loc)

            with tf.variable_scope('expert_scale'):
                ep_scale = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=tf.AUTO_REUSE)
                expert_outputs_scale = ep_scale(expert_fn_scale, **expert_kwargs_scale)

            loss_code_reconst = negative_log_likelihood_gaussian_mixture(code, gate_logits_softmax, expert_outputs_loc, expert_outputs_scale, num_experts)
            loss += loss_code_reconst

        if regularize_entropy_on_gating_logits:
            softmax_gating_logits = tf.nn.softmax(logits)
            entropy_gating = tf.reduce_sum(-tf.reduce_sum(softmax_gating_logits * tf.log(softmax_gating_logits), axis=1))

            loss += entropy_gating

        # The accumulated loss is scaled as a whole; the load/importance
        # terms above were already multiplied by loss_coef once.
        loss *= loss_coef

        return y, loss, gates, load, expert_outputs, logits, distance_matrix, expert_outputs_moe_loss_mean, expert_outputs_moe_loss_std, code_reconst, loss_code_reconst, expert_outputs_loc, expert_outputs_scale


def negative_log_likelihood_gaussian_mixture(code, gates, experts_loc, experts_scale, num_expert):
    """Negative log-likelihood of `code` under a gate-weighted Gaussian mixture.

    For each expert, the density of a diagonal Gaussian with that expert's
    `loc` / `scale_diag` is evaluated per example; densities are weighted
    by the transposed `gates`, summed over experts, clipped to
    [1e-10, 1e10] and the negative sum of logs is returned.

    Args:
      code: tensor of points to score, one per example.
      gates: mixture weights (presumably [batch, num_expert], since it is
        transposed before weighting -- confirm against the caller).
      experts_loc: indexable of `num_expert` location tensors.
      experts_scale: indexable of `num_expert` diagonal-scale tensors.
      num_expert: an integer, number of mixture components.

    Returns:
      a scalar tensor.
    """
    # Unused placeholder; kept so graph construction is unchanged.
    loss = tf.zeros((1,), dtype=tf.float32)

    pdf_eval = list()
    for i in range(num_expert):
        pdf_eval.append(tf.map_fn(lambda x: tfd.MultivariateNormalDiag(loc=x[0], scale_diag=x[1]).prob(x[2]), (experts_loc[i], experts_scale[i], code), dtype=tf.float32))

    pdf_eval = tf.stack(pdf_eval)

    gate_times_pdf = tf.multiply(pdf_eval, tf.transpose(gates))

    sum_components = tf.reduce_sum(gate_times_pdf, axis=0)
    # Clip before the log to avoid log(0) and overflow.
    sum_components_log = tf.log(tf.clip_by_value(sum_components, 1e-10, 1e10))
    loss = -tf.reduce_sum(sum_components_log)

    return loss


def dist_to(samples, target):
    """Mean absolute difference between `target` and each entry of `samples`."""
    dists = tf.map_fn(lambda x: tf.abs(target - x), samples, dtype=tf.float32)
    return tf.reduce_mean(dists)


def pairwise_dist(X):
    """
    Computes pairwise distance between each pair of points
    Args:
      X - [N,D] matrix representing N D-dimensional vectors
    Returns:
      [N,N] matrix of (squared) Euclidean distances
    """
    x2 = tf.reduce_sum(X * X, 1, True)
    return x2 - 2 * tf.matmul(X, tf.transpose(X)) + tf.transpose(x2)
def local_moe_encoder(x,
                      train,
                      expert_fn,
                      num_experts,
                      k=1,
                      loss_coef=1e-2,
                      pass_x=True,
                      pass_gates=False,
                      additional_dispatch_params=None,
                      noisy_gating=True,
                      noise_eps=1e-2,
                      name=None):
    """Run a local mixture of experts whose experts return (sample, loc, scale).

    Args:
      x: a tensor with shape [... , input_size]
      train: a boolean scalar.
      expert_fn: a function returning three outputs per expert.
      num_experts: an integer - number of experts
      k: an integer - how many experts to use for each batch element
      loss_coef: a scalar - multiplier on the load-balancing loss
      pass_x: a boolean. If true, x will also be dispatched to the experts.
      pass_gates: a boolean. If true, gates will be passed to experts.
      additional_dispatch_params: dict of extra tensors that need to be
        sent to each expert.
      noisy_gating, noise_eps: forwarded to noisy_top_k_gating.
      name: a string

    Returns:
      `(y, loss, gates, load, logits, loc_flat, scale_flat)` where `y` has
      the shape of `x` except for the last dimension, `loss` is
      `loss_coef * cv_squared(load)`, and `loc_flat` / `scale_flat` are the
      gate-combined second and third expert outputs.
    """

    with tf.variable_scope(name, default_name="local_moe"):
        flat_inputs = flatten_all_but_last(x)

        # The gates indicate which batch elements go to which tensors.
        # load is a measure of approximately how many examples go to each
        # expert.
        # NOTE(review): local_moe in this file passes many more positional
        # arguments to noisy_top_k_gating -- confirm this call still matches
        # the current signature.
        gates, load, logits = noisy_top_k_gating(
            flat_inputs,
            num_experts,
            train,
            k,
            initializer=tf.zeros_initializer(),
            noisy_gating=noisy_gating,
            noise_epsilon=noise_eps)
        # Routes batch rows to experts and merges their outputs back.
        dispatcher = SparseDispatcher(num_experts, gates)

        # Keyword arguments handed to every expert.
        fn_kwargs = {}
        if pass_x:
            fn_kwargs["x"] = dispatcher.dispatch(flat_inputs)
        if pass_gates:
            fn_kwargs["gates"] = dispatcher.expert_to_gates()
        extra_params = additional_dispatch_params or {}
        for param_name, param_value in six.iteritems(extra_params):
            fn_kwargs[param_name] = dispatcher.dispatch(
                flatten_all_but_last(param_value))

        experts = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=None)
        expert_outputs, loc, scale = experts(expert_fn, **fn_kwargs)

        y_flat = dispatcher.combine(expert_outputs)
        loc_flat = dispatcher.combine(loc)
        scale_flat = dispatcher.combine(scale)

        y = reshape_like(y_flat, x)

        # Only the load term is used; the importance term is left out.
        loss = loss_coef * cv_squared(load)
        return y, loss, gates, load, logits, loc_flat, scale_flat
class TruncatingDispatcher(object):
  """Helper for implementing a mixture of experts.

  A TruncatingDispatcher is useful when you need to deal with
  fixed-sized Tensors.  As opposed to a SparseDispatcher, which
  produces batches of different sizes for the different experts, the
  TruncatingDispatcher always produces batches of the same given size,
  and the results are returned stacked in one big tensor.

  In the case where an expert is over-capacity, the last items that
  should have gone to that expert are dropped.

  Confusingly, the inputs to a TruncatingDispatcher have both a
  "batch" and a "length" dimension.  Not only does each expert receive
  the same total number of examples, it also receives the same number
  of examples for each element of "batch".  This behavior is necessary
  for applications such as grouped attention, where we have a batch of
  sequences, and we want each sequence to be divided evenly among
  experts.  For simpler applications like mixture-of-experts, you can
  reshape the input so that the "batch" dimension is 1, and only the
  "length" dimension is used.
  """

  @add_name_scope("truncating_dispatcher")
  def __init__(self, requests, expert_capacity):
    """Create a TruncatingDispatcher.

    Args:
      requests: a boolean `Tensor` of shape `[batch, length, num_experts]`.
        Alternatively, a float or int Tensor containing zeros and ones.
      expert_capacity: a Scalar - maximum number of examples per expert per
        batch element.

    Returns:
      a TruncatingDispatcher
    """
    self._requests = tf.to_float(requests)
    self._expert_capacity = expert_capacity
    expert_capacity_f = tf.to_float(expert_capacity)
    self._batch, self._length, self._num_experts = tf.unstack(
        tf.shape(self._requests), num=3)

    # [batch, length, num_experts]
    # Exclusive cumsum along length: how many earlier positions already
    # requested the same expert.
    position_in_expert = tf.cumsum(self._requests, axis=1, exclusive=True)
    # [batch, length, num_experts]
    # A request is kept only while its slot is below the expert's capacity.
    self._gates = self._requests * tf.to_float(
        tf.less(position_in_expert, expert_capacity_f))
    batch_index = tf.reshape(
        tf.to_float(tf.range(self._batch)), [self._batch, 1, 1])
    length_index = tf.reshape(
        tf.to_float(tf.range(self._length)), [1, self._length, 1])
    expert_index = tf.reshape(
        tf.to_float(tf.range(self._num_experts)),
        [1, 1, self._num_experts])
    # position in a Tensor with shape [batch * num_experts * expert_capacity]
    flat_position = (position_in_expert + batch_index *
                     (tf.to_float(self._num_experts) * expert_capacity_f) +
                     expert_index * expert_capacity_f)
    # Tensor of shape [batch * num_experts * expert_capacity].
    # The summed data is (length_index + 1) * gate, so at this point a slot
    # holds the 1-based length position of its item and 0.0 when no kept
    # request landed in it (dropped requests have gate 0 and add nothing).
    self._indices = tf.unsorted_segment_sum(
        data=tf.reshape((length_index + 1.0) * self._gates, [-1]),
        segment_ids=tf.to_int32(tf.reshape(flat_position, [-1])),
        num_segments=self._batch * self._num_experts * expert_capacity)
    self._indices = tf.reshape(
        self._indices, [self._batch, self._num_experts, expert_capacity])
    # Tensors of shape [batch, num_experts, expert_capacity].
    # each element is 0.0 or 1.0
    self._nonpadding = tf.minimum(self._indices, 1.0)
    # Shift back to 0-based; empty slots are clamped to 0 by the relu.
    # each element is an integer in [0, length)
    self._indices = tf.nn.relu(self._indices - 1.0)
    # self._flat_indices is [batch, num_experts, expert_capacity], with values
    # in [0, batch * length)
    self._flat_indices = tf.to_int32(self._indices + (
        tf.reshape(tf.to_float(tf.range(self._batch)), [-1, 1, 1]) *
        tf.to_float(self._length)))
    self._indices = tf.to_int32(self._indices)

  @add_name_scope("truncating_dispatcher_dispatch")
  def dispatch(self, inp):
    """Send the inputs to the experts.

    Args:
      inp: a `Tensor` of shape `[batch, length, depth]`

    Returns:
      a tensor with shape [batch, num_experts, expert_capacity, depth]
    """
    # Flatten batch and length so the flat indices can address rows directly.
    inp = tf.reshape(inp, [self._batch * self._length, -1])
    # [batch, num_experts, expert_capacity, depth]
    ret = tf.gather(inp, self._flat_indices)
    return ret

  @add_name_scope("truncating_dispatcher_combine")
  def combine(self, x):
    """Return the output from the experts.

    When one example goes to multiple experts, the outputs are summed.

    Args:
      x: a Tensor with shape [batch, num_experts, expert_capacity, depth]

    Returns:
      a `Tensor` with shape `[batch, length, depth]`
    """
    depth = tf.shape(x)[-1]
    # Zero out padding slots so they contribute nothing to the sums.
    x *= tf.expand_dims(self._nonpadding, -1)
    ret = tf.unsorted_segment_sum(
        x, self._flat_indices, num_segments=self._batch * self._length)
    ret = tf.reshape(ret, [self._batch, self._length, depth])
    return ret

  def nonpadding(self):
    """Which elements of a dispatched Tensor are not padding.

    Returns:
      a Zero/One float tensor with shape [batch, num_experts, expert_capacity].
    """
    return self._nonpadding

  def gates(self):
    """A Tensor indicating which examples go to which experts.

    Returns:
      A float32 Tensor with shape [batch, length, num_experts], where each value
      is 0.0 or 1.0.
    """
    return self._gates

  def length_coordinate(self):
    """Length coordinate of dispatched tensor.

    Returns:
      a tensor with shape [batch, num_experts, expert_capacity] containing
      integers in the range [0, length)
    """
    return self._indices


def should_generate_summaries():
  """Is this an appropriate context to generate summaries.

  Returns False inside a `while/` name scope or when the current variable
  scope has `reuse` set; True otherwise.

  Returns:
    a boolean
  """
  if "while/" in tf.contrib.framework.get_name_scope():
    # Summaries don't work well within tf.while_loop()
    return False
  if tf.get_variable_scope().reuse:
    # Avoid generating separate summaries for different data shards
    return False
  return True