├── README.md
├── codes
│   ├── config.py
│   ├── demo.py
│   ├── helpers.py
│   └── vasc.py
└── data
    ├── biase.txt
    └── biase_label.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VASC
#### Variational autoencoder for single-cell RNA-seq datasets

Single-cell RNA sequencing (scRNA-seq) is a powerful technique for analyzing transcriptomic heterogeneity at the single-cell level. An important step in studying cell sub-populations and lineages from scRNA-seq data is finding an effective low-dimensional representation and visualization of the original data. scRNA-seq data are "noisier" than traditional bulk RNA-seq data: at the single-cell level, transcriptional fluctuations are much larger than the average over a cell population, and the low amount of RNA transcripts increases the rate of technical dropout events. In this study, we propose VASC (deep Variational Autoencoder for scRNA-seq data), a deep multi-layer generative model for dimension reduction and visualization. It learns nonlinear hierarchical feature representations and explicitly models the dropout events of scRNA-seq data. Tested on more than twenty datasets, VASC shows better performance in most cases and higher stability compared with several common dimension-reduction methods. VASC successfully re-establishes the embryo pre-implantation cell lineage and its associated genes based on the 2D representation of a large-scale scRNA-seq dataset from human embryos.

## Prerequisites
+ Python 3.5+
+ numpy 1.12.1
+ h5py 2.7.0
+ sklearn 0.18.1
+ tensorflow 1.1.0
+ keras 2.0.6

We recommend installing the latest Anaconda distribution from https://www.continuum.io/downloads.

## Codes
Two Python files are included:
- vasc.py: contains the model class `VASC` and the function `vasc`
- helpers.py: auxiliary functions for clustering, evaluation metrics, and plotting

## Demo
demo.py and config.py demonstrate the use of VASC on the bundled Biase dataset; see also the quick-start sketch below.

## Data
A small dataset from Biase et al. is included for demonstration.
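
## Quick start
A minimal sketch of calling `vasc` directly (not taken from the original demo): the random matrix below merely stands in for a real cells-by-genes expression matrix, so the resulting embedding is meaningless; run it from the `codes` directory.

```python
import numpy as np
from vasc import vasc

expr = np.random.rand(100, 1000)   # hypothetical data: 100 cells x 1000 genes
# Returns the 2-D latent coordinates and writes 'demo_2_res.h5'
# containing the training trajectory (keys 'POINTS' and 'LOSS').
res = vasc(expr, latent=2, epoch=100, log=False, scale=True, prefix='demo')
print(res.shape)                   # (100, 2)
```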
--------------------------------------------------------------------------------
/codes/config.py:
--------------------------------------------------------------------------------
config = {
    'epoch': 10000,      # maximum number of training epochs
    'batch_size': 256,   # batch size for stochastic optimization
    'latent': 2,         # dimension of the latent space
    'log': False,        # whether to log-transform the expression values
    'scale': True,       # whether to scale values into [0,1]
    'patience': 50       # epochs without significant loss decrease before stopping
}
--------------------------------------------------------------------------------
/codes/demo.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import numpy as np
from vasc import vasc
from helpers import clustering, measure, print_2D
from config import config

if __name__ == '__main__':
    DATASET = 'biase'  # sys.argv[1]
    PREFIX = 'biase'   # sys.argv[2]

    # Read the expression table: the first row holds the cell names.
    filename = DATASET + '.txt'
    data = open(filename)
    head = data.readline().rstrip().split()

    # Read the cell-to-label mapping.
    label_file = open(DATASET + '_label.txt')
    label_dict = {}
    for line in label_file:
        temp = line.rstrip().split()
        label_dict[temp[0]] = temp[1]
    label_file.close()

    label = []
    for c in head:
        if c in label_dict:
            label.append(label_dict[c])
        else:
            print(c)

    # Map label names to integer ids (and back, for plotting legends).
    label_set = []
    for c in label:
        if c not in label_set:
            label_set.append(c)
    name_map = {value: idx for idx, value in enumerate(label_set)}
    id_map = {idx: value for idx, value in enumerate(label_set)}
    label = np.asarray([name_map[name] for name in label])

    # Remaining rows: one gene per line, first column is the gene name.
    expr = []
    for line in data:
        temp = line.rstrip().split()[1:]
        temp = [float(x) for x in temp]
        expr.append(temp)
    data.close()

    expr = np.asarray(expr).T          # cells x genes
    n_cell, _ = expr.shape
    if n_cell > 150:
        batch_size = config['batch_size']
    else:
        batch_size = 32
    # expr = np.exp(expr) - 1
    # expr = expr / np.max(expr)

    for i in range(1):
        print("Iteration:" + str(i))
        res = vasc(expr, var=False,
                   latent=config['latent'],
                   annealing=False,
                   batch_size=batch_size,
                   prefix=PREFIX,
                   label=label,
                   scale=config['scale'],
                   patience=config['patience']
                   )

    print(res.shape)
    k = len(np.unique(label))
    cl, _ = clustering(res, k=k)
    dm = measure(cl, label)

    # Plot the 2-D visualization of the latent space.
    fig = print_2D(points=res, label=label, id_map=id_map)
    # fig.savefig('embryo.eps')
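
    # Optionally inspect the training trajectory that vasc() saved to
    # '<prefix>_<latent>_res.h5'; the file holds 'POINTS' (latent
    # coordinates recorded at every epoch) and 'LOSS' (per-epoch loss).
    import h5py
    with h5py.File(PREFIX + '_' + str(config['latent']) + '_res.h5', 'r') as f:
        print(f['POINTS'].shape, f['LOSS'][-1])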
--------------------------------------------------------------------------------
/codes/helpers.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib as mpl
# mpl.use('Agg')

import matplotlib.pyplot as plt

# plt.ioff()

import seaborn as sns
from pandas import DataFrame
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
from sklearn.cluster import KMeans
from sklearn.cluster import SpectralClustering
from sklearn.decomposition import PCA
from sklearn.covariance import EllipticEnvelope


markers = {',': 'pixel', 'o': 'circle', '*': 'star', 'v': 'triangle_down',
           '^': 'triangle_up', '<': 'triangle_left', '>': 'triangle_right',
           '1': 'tri_down', '2': 'tri_up', '3': 'tri_left', '4': 'tri_right',
           '8': 'octagon', 's': 'square', 'p': 'pentagon',
           'h': 'hexagon1', 'H': 'hexagon2', '+': 'plus', 'x': 'x', '.': 'point',
           'D': 'diamond', 'd': 'thin_diamond', '|': 'vline', '_': 'hline',
           'P': 'plus_filled', 'X': 'x_filled', 0: 'tickleft',
           1: 'tickright', 2: 'tickup', 3: 'tickdown', 4: 'caretleft', 5: 'caretright',
           6: 'caretup', 7: 'caretdown', 8: 'caretleftbase', 9: 'caretrightbase', 10: 'caretupbase',
           11: 'caretdownbase', 'None': 'nothing', None: 'nothing', ' ': 'nothing', '': 'nothing'}
markers_keys = list(markers.keys())[:20]

font = {'family': 'sans-serif',
        'weight': 'bold',
        'size': 30}

mpl.rc('font', **font)

sns.set_style("ticks")

colors = ["windows blue", "amber",
          "greyish", "faded green",
          "dusty purple", "royal blue", "lilac",
          "salmon", "bright turquoise",
          "dark maroon", "light tan",
          "orange", "orchid",
          "sandy", "topaz",
          "fuchsia", "yellow",
          "crimson", "cream"
          ]
current_palette = sns.xkcd_palette(colors)


def print_2D(points, label, id_map):
    '''
    points: N_samples * 2
    label: (int) N_samples
    id_map: map label id to its name
    '''
    fig = plt.figure()
    n_cell, _ = points.shape
    # Use smaller markers for large datasets.
    if n_cell > 500:
        s = 10
    else:
        s = 20

    ax = plt.subplot(111)
    print(np.unique(label))
    for i in np.unique(label):
        ax.scatter(points[label == i, 0], points[label == i, 1],
                   c=current_palette[i], label=id_map[i], s=s,
                   marker=markers_keys[i])
    # Shrink the axes to make room for the legend below the plot.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])

    ax.legend(scatterpoints=1, loc='upper center',
              bbox_to_anchor=(0.5, -0.08), ncol=6,
              fancybox=True,
              prop={'size': 8}
              )
    sns.despine()
    return fig


def print_heatmap(points, label, id_map):
    '''
    points: N_samples * N_features
    label: (int) N_samples
    id_map: map label id to its name
    '''
    index = [id_map[i] for i in label]
    df = DataFrame(
        points,
        columns=list(range(points.shape[1])),
        index=index
    )
    row_color = [current_palette[i] for i in label]

    cmap = sns.cubehelix_palette(as_cmap=True, rot=-.3, light=1)
    g = sns.clustermap(df, cmap=cmap, row_colors=row_color, col_cluster=False,
                       xticklabels=False, yticklabels=False)  # ,standard_scale=1

    return g.fig
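
# The four external clustering metrics computed below score a predicted
# partition against ground-truth labels; all equal 1.0 for a perfect match
# up to relabeling. For example (hypothetical toy input):
#   measure([1, 1, 0, 0], [0, 0, 1, 1])
# prints NMI, RAND (adjusted Rand index), HOMOGENEITY and COMPLETENESS all
# equal to 1.0, since the two partitions agree after renaming the clusters.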
def measure(predicted, true):
    NMI = normalized_mutual_info_score(true, predicted)
    print("NMI:" + str(NMI))
    RAND = adjusted_rand_score(true, predicted)
    print("RAND:" + str(RAND))
    HOMO = homogeneity_score(true, predicted)
    print("HOMOGENEITY:" + str(HOMO))
    COMPLETENESS = completeness_score(true, predicted)
    print("COMPLETENESS:" + str(COMPLETENESS))
    return {'NMI': NMI, 'RAND': RAND, 'HOMOGENEITY': HOMO, 'COMPLETENESS': COMPLETENESS}


def clustering(points, k=2, name='kmeans'):
    '''
    points: N_samples * N_features
    k: number of clusters
    '''
    if name == 'kmeans':
        kmeans = KMeans(n_clusters=k, n_init=100).fit(points)
        # The silhouette score is undefined when all points fall in one cluster.
        if len(np.unique(kmeans.labels_)) > 1:
            si = silhouette_score(points, kmeans.labels_)
        else:
            si = 0
        print("Silhouette:" + str(si))
        return kmeans.labels_, si

    if name == 'spec':
        spec = SpectralClustering(n_clusters=k, affinity='cosine').fit(points)
        si = silhouette_score(points, spec.labels_)
        print("Silhouette:" + str(si))
        return spec.labels_, si


def cart2polar(points):
    '''
    points: N_samples * 2 (Cartesian x,y coordinates)
    '''
    # Treat each (x, y) pair as the complex number x + iy so that
    # np.abs/np.angle yield the radius and angle of each point.
    c = points[:, 0] + 1j * points[:, 1]
    return np.c_[np.abs(c), np.angle(c)]


def outliers_detection(expr):
    # Flag outlier cells by fitting a robust elliptic envelope to the
    # first two principal components; predict() returns 1 for inliers.
    x = PCA(n_components=2).fit_transform(expr)
    ee = EllipticEnvelope()
    ee.fit(x)
    oo = ee.predict(x)

    return oo

--------------------------------------------------------------------------------
/codes/vasc.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from keras.layers import Input, Dense, Activation, Lambda, RepeatVector, merge, Reshape, Layer, Dropout, BatchNormalization, Permute
import keras.backend as K
from keras.models import Model
from helpers import measure, clustering, print_2D, print_heatmap, cart2polar, outliers_detection
# from keras.utils.vis_utils import plot_model
from keras import regularizers
from keras.utils.layer_utils import print_summary
import numpy as np
import numpy.matlib
from keras.optimizers import RMSprop, Adagrad, Adam
from keras import metrics
from config import config
import h5py

tau = 1.0


def sampling(args):
    epsilon_std = 1.0

    if len(args) == 2:
        # Reparameterization trick: z = mu + sigma * epsilon.
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=K.shape(z_mean),
                                  mean=0.,
                                  stddev=epsilon_std)
        return z_mean + K.exp(z_log_var / 2) * epsilon
    else:
        # Variance is not estimated: use a fixed log-variance of 1.
        z_mean = args[0]
        epsilon = K.random_normal(shape=K.shape(z_mean),
                                  mean=0.,
                                  stddev=epsilon_std)
        return z_mean + K.exp(1.0 / 2) * epsilon


def sampling_gumbel(shape, eps=1e-8):
    # Draw Gumbel(0,1) noise via the inverse-CDF transform of a uniform.
    u = K.random_uniform(shape)
    return -K.log(-K.log(u + eps) + eps)


def compute_softmax(logits, temp):
    # Softmax of (logits + Gumbel noise) / temperature.
    z = logits + sampling_gumbel(K.shape(logits))
    return K.softmax(z / temp)
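
# Gumbel-softmax reparameterization (Jang et al., 2016): adding Gumbel noise
# to the log-probabilities and taking a temperature-scaled softmax gives a
# differentiable approximation of sampling from a categorical distribution.
# As temp -> 0 the samples approach one-hot vectors; larger temp gives
# smoother, more uniform samples. VASC uses this to sample the binary
# dropout/non-dropout indicator for each gene.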
def gumbel_softmax(args):
    logits, temp = args
    return compute_softmax(logits, temp)


class NoiseLayer(Layer):
    '''Randomly zeroes inputs at training time, keeping each with probability *ratio*.'''
    def __init__(self, ratio, **kwargs):
        super(NoiseLayer, self).__init__(**kwargs)
        self.supports_masking = True
        self.ratio = ratio

    def call(self, inputs, training=None):
        def noised():
            return inputs * K.random_binomial(shape=K.shape(inputs),
                                              p=self.ratio
                                              )
        return K.in_train_phase(noised, inputs, training=training)

    def get_config(self):
        config = {'ratio': self.ratio}
        base_config = super(NoiseLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class VASC:
    def __init__(self, in_dim, latent=2, var=False):
        self.in_dim = in_dim
        self.vae = None
        self.ae = None
        self.aux = None
        self.latent = latent
        self.var = var

    def vaeBuild(self):
        var_ = self.var
        in_dim = self.in_dim
        expr_in = Input(shape=(self.in_dim,))

        ##### The first part of the model: recover the expression values.
        h0 = Dropout(0.5)(expr_in)
        ## Encoder layers
        h1 = Dense(units=512, name='encoder_1', kernel_regularizer=regularizers.l1(0.01))(h0)
        h2 = Dense(units=128, name='encoder_2')(h1)
        h2_relu = Activation('relu')(h2)
        h3 = Dense(units=32, name='encoder_3')(h2_relu)
        h3_relu = Activation('relu')(h3)

        z_mean = Dense(units=self.latent, name='z_mean')(h3_relu)
        if self.var:
            z_log_var = Dense(units=self.latent, name='z_log_var')(h3_relu)
            z_log_var = Activation('softplus')(z_log_var)
            ## sampling new samples
            z = Lambda(sampling, output_shape=(self.latent,))([z_mean, z_log_var])
        else:
            z = Lambda(sampling, output_shape=(self.latent,))([z_mean])

        ## Decoder layers
        decoder_h1 = Dense(units=32, name='decoder_1')(z)
        decoder_h1_relu = Activation('relu')(decoder_h1)
        decoder_h2 = Dense(units=128, name='decoder_2')(decoder_h1_relu)
        decoder_h2_relu = Activation('relu')(decoder_h2)
        decoder_h3 = Dense(units=512, name='decoder_3')(decoder_h2_relu)
        decoder_h3_relu = Activation('relu')(decoder_h3)
        expr_x = Dense(units=self.in_dim, activation='sigmoid')(decoder_h3_relu)

        ## Dropout model: p_drop = exp(-x^2) for recovered expression x,
        ## so higher expression implies a lower dropout probability.
        expr_x_drop = Lambda(lambda x: -x ** 2)(expr_x)
        expr_x_drop_p = Lambda(lambda x: K.exp(x))(expr_x_drop)
        expr_x_nondrop_p = Lambda(lambda x: 1 - x)(expr_x_drop_p)
        expr_x_nondrop_log = Lambda(lambda x: K.log(x + 1e-20))(expr_x_nondrop_p)
        expr_x_drop_log = Lambda(lambda x: K.log(x + 1e-20))(expr_x_drop_p)
        expr_x_drop_log = Reshape(target_shape=(self.in_dim, 1))(expr_x_drop_log)
        expr_x_nondrop_log = Reshape(target_shape=(self.in_dim, 1))(expr_x_nondrop_log)
        logits = merge([expr_x_drop_log, expr_x_nondrop_log], mode='concat', concat_axis=-1)

        ## Sample a per-gene dropout indicator with the Gumbel-softmax trick;
        ## the temperature is fed in as a second model input.
        temp_in = Input(shape=(self.in_dim,))
        temp_ = RepeatVector(2)(temp_in)
        temp_ = Permute((2, 1))(temp_)
        samples = Lambda(gumbel_softmax, output_shape=(self.in_dim, 2,))([logits, temp_])
        samples = Lambda(lambda x: x[:, :, 1])(samples)
        samples = Reshape(target_shape=(self.in_dim,))(samples)

        ## Mask the recovered expression with the sampled indicators.
        out = merge([expr_x, samples], mode='mul')
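
        ## The VAE objective: reconstruction cross-entropy plus the KL
        ## divergence between q(z|x) = N(z_mean, exp(z_log_var)) and the
        ## standard normal prior,
        ##   KL = -0.5 * sum(1 + log(var) - mean^2 - var).
        ## When the variance is not estimated (var=False), the log-variance
        ## is fixed at 1, matching sampling() above. The custom layer below
        ## attaches this loss to the model via add_loss().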
        class VariationalLayer(Layer):
            def __init__(self, **kwargs):
                self.is_placeholder = True
                super(VariationalLayer, self).__init__(**kwargs)

            def vae_loss(self, x, x_decoded_mean):
                xent_loss = in_dim * metrics.binary_crossentropy(x, x_decoded_mean)
                if var_:
                    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
                else:
                    kl_loss = -0.5 * K.sum(1 + 1 - K.square(z_mean) - K.exp(1.0), axis=-1)
                return K.mean(xent_loss + kl_loss)

            def call(self, inputs):
                x = inputs[0]
                x_decoded_mean = inputs[1]
                loss = self.vae_loss(x, x_decoded_mean)
                self.add_loss(loss, inputs=inputs)
                # We won't actually use the output.
                return x

        y = VariationalLayer()([expr_in, out])
        vae = Model(inputs=[expr_in, temp_in], outputs=y)

        opt = RMSprop(lr=0.001)
        vae.compile(optimizer=opt, loss=None)

        ## Auxiliary model exposing all intermediate layer outputs.
        ae = Model(inputs=[expr_in, temp_in], outputs=[h1, h2, h3, h2_relu, h3_relu,
                                                       z_mean, z, decoder_h1, decoder_h1_relu,
                                                       decoder_h2, decoder_h2_relu, decoder_h3, decoder_h3_relu,
                                                       samples, out
                                                       ])
        aux = Model(inputs=[expr_in, temp_in], outputs=[out])

        self.vae = vae
        self.ae = ae
        self.aux = aux


def vasc(expr,
         epoch=5000,
         latent=2,
         patience=50,
         min_stop=500,
         batch_size=32,
         var=False,
         prefix='test',
         label=None,
         log=True,
         scale=True,
         annealing=False,
         tau0=1.0,
         min_tau=0.5,
         rep=0):
    '''
    VASC: variational autoencoder for scRNA-seq datasets

    ============
    Parameters:
        expr: expression matrix (n_cells * n_features)
        epoch: maximum number of epochs, default 5000
        latent: dimension of latent variables, default 2
        patience: stop if the loss shows no significant decrease within *patience* epochs, default 50
        min_stop: minimum number of epochs, default 500
        batch_size: batch size for stochastic optimization, default 32
        var: whether to estimate the variance parameters, default False
        prefix: prefix of the file used to store the results, default 'test'
        label: numpy array of true labels, default None
        log: whether log-transformation should be performed, default True
        scale: whether scaling (mapping values into [0,1]) should be performed, default True
        annealing: whether annealing should be performed for the Gumbel approximation, default False
        tau0: initial temperature for annealing, or the fixed temperature without annealing, default 1.0
        min_tau: minimum tau during annealing, default 0.5
        rep: not used

    =============
    Values:
        points: the dimension-*latent* representation
        A file named (*prefix*_*latent*_res.h5) is also written; we prefer this file over the
        single return value for analysing the results. It includes the following keys:
            POINTS: all intermediate latent representations during the iterations
            LOSS: loss values during the training procedure
            RES*i*: i from 0 to 14
                - hidden-layer values, for reference only
        We recommend using POINTS and LOSS to select the final result according to the user's preference.
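
    Example (illustrative only; random data stands in for a real expression matrix):
        >>> import numpy as np
        >>> points = vasc(np.random.rand(100, 1000), latent=2, epoch=100,
        ...               log=False, prefix='demo')
        >>> points.shape
        (100, 2)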
    '''

    expr[expr < 0] = 0.0

    if log:
        expr = np.log2(expr + 1)
    if scale:
        # Scale each cell's values into [0,1] by its maximum.
        for i in range(expr.shape[0]):
            expr[i, :] = expr[i, :] / np.max(expr[i, :])

    # if outliers:
    #     o = outliers_detection(expr)
    #     expr = expr[o == 1, :]
    #     if label is not None:
    #         label = label[o == 1]

    if rep > 0:
        expr_train = np.matlib.repmat(expr, rep, 1)
    else:
        expr_train = np.copy(expr)

    vae_ = VASC(in_dim=expr.shape[1], latent=latent, var=var)
    vae_.vaeBuild()
    # print_summary(vae_.vae)

    points = []
    loss = []
    prev_loss = np.inf
    tau = tau0
    anneal_rate = 0.0003
    for e in range(epoch):
        cur_loss = prev_loss

        # Anneal the Gumbel temperature every 100 epochs.
        if e % 100 == 0 and annealing:
            tau = max(tau0 * np.exp(-anneal_rate * e), min_tau)
            print(tau)

        tau_in = np.ones(expr_train.shape, dtype='float32') * tau

        loss_ = vae_.vae.fit([expr_train, tau_in], expr_train, epochs=1, batch_size=batch_size,
                             shuffle=True, verbose=0
                             )
        train_loss = loss_.history['loss'][0]
        cur_loss = min(train_loss, cur_loss)
        loss.append(train_loss)
        res = vae_.ae.predict([expr, tau_in])
        points.append(res[5])
        if label is not None:
            k = len(np.unique(label))

        if e % patience == 1:
            print("Epoch %d/%d" % (e + 1, epoch))
            print("Loss:" + str(train_loss))
            if abs(cur_loss - prev_loss) < 1 and e > min_stop:
                break
        prev_loss = train_loss
        if label is not None:
            try:
                cl, _ = clustering(res[5], k=k)
                measure(cl, label)
            except Exception:
                print('Clustering error')

    ### Save the results for later analysis.
    points = np.asarray(points)
    aux_res = h5py.File(prefix + '_' + str(latent) + '_res.h5', mode='w')
    aux_res.create_dataset(name='POINTS', data=points)
    aux_res.create_dataset(name='LOSS', data=loss)
    count = 0
    for r in res:
        aux_res.create_dataset(name='RES' + str(count), data=r)
        count += 1
    aux_res.close()

    return res[5]

--------------------------------------------------------------------------------
/data/biase_label.txt:
--------------------------------------------------------------------------------
GSM1377859 zygote
GSM1377860 zygote
GSM1377861 zygote
GSM1377862 zygote
GSM1377863 zygote
GSM1377864 zygote
GSM1377865 zygote
GSM1377866 zygote
GSM1377867 zygote
GSM1377868 2cell
GSM1377869 2cell
GSM1377870 2cell
GSM1377871 2cell
GSM1377872 2cell
GSM1377873 2cell
GSM1377874 2cell
GSM1377875 2cell
GSM1377876 2cell
GSM1377877 2cell
GSM1377878 2cell
GSM1377879 2cell
GSM1377880 2cell
GSM1377881 2cell
GSM1377882 2cell
GSM1377883 2cell
GSM1377884 2cell
GSM1377885 2cell
GSM1377886 2cell
GSM1377887 2cell
GSM1377888 4cell
GSM1377889 4cell
GSM1377890 4cell
GSM1377891 4cell
GSM1377892 4cell
GSM1377893 4cell
GSM1377894 4cell
GSM1377895 4cell
GSM1377896 4cell
GSM1377897 4cell
GSM1377898 4cell
GSM1377899 4cell
GSM1377900 4cell
GSM1377901 4cell
GSM1377902 4cell
GSM1377903 4cell
GSM1377904 4cell
GSM1377905 4cell
GSM1377906 4cell
GSM1377907 4cell
GSM1377908 blast
GSM1377909 blast
GSM1377910 blast
GSM1377911 blast
GSM1377912 blast
GSM1377913 blast
GSM1377914 blast
--------------------------------------------------------------------------------