├── src └── scProjection │ ├── scProjection.egg-info │ ├── top_level.txt │ ├── dependency_links.txt │ ├── SOURCES.txt │ └── PKG-INFO │ ├── __pycache__ │ ├── utils.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ ├── vaeBackend.cpython-35.pyc │ ├── vaeTrain.cpython-35.pyc │ ├── batchBackend.cpython-35.pyc │ ├── deconvModel.cpython-35.pyc │ ├── saving_utils.cpython-35.pyc │ ├── architectures.cpython-35.pyc │ └── deconvBackend.cpython-35.pyc │ ├── __init__.py │ ├── saving_utils.py │ ├── deconvBackend.py │ ├── architectures.py │ ├── utils.py │ ├── batchBackend.py │ ├── deconvModel.py │ ├── vaeBackend.py │ └── vaeTrain.py ├── dist ├── scProjection-0.2.tar.gz ├── scProjection-0.21.tar.gz ├── scProjection-0.22.tar.gz ├── scProjection-0.2-py3-none-any.whl ├── scProjection-0.21-py3-none-any.whl └── scProjection-0.22-py3-none-any.whl ├── pyproject.toml ├── setup.cfg ├── setup.py ├── LICENSE └── README.md /src/scProjection/scProjection.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/scProjection/scProjection.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dist/scProjection-0.2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/dist/scProjection-0.2.tar.gz -------------------------------------------------------------------------------- /dist/scProjection-0.21.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/dist/scProjection-0.21.tar.gz -------------------------------------------------------------------------------- /dist/scProjection-0.22.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/dist/scProjection-0.22.tar.gz -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=54", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /dist/scProjection-0.2-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/dist/scProjection-0.2-py3-none-any.whl -------------------------------------------------------------------------------- /dist/scProjection-0.21-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/dist/scProjection-0.21-py3-none-any.whl -------------------------------------------------------------------------------- /dist/scProjection-0.22-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/dist/scProjection-0.22-py3-none-any.whl -------------------------------------------------------------------------------- 
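The `pyproject.toml` above declares a plain setuptools/wheel build backend, which is the configuration that produces the sdists and wheels listed under `dist/`. As a rough sketch only (it assumes the standard PyPA `build` frontend, which is not bundled with this repository), regenerating those artifacts from a source checkout looks roughly like this:

```python
# Sketch: rebuild the dist/ artifacts from the repository root.
# Assumes the PyPA "build" frontend; it reads pyproject.toml and invokes the
# setuptools backend declared there.
import subprocess

subprocess.run(["python", "-m", "pip", "install", "build"], check=True)
subprocess.run(["python", "-m", "build"], check=True)  # writes dist/*.tar.gz and dist/*.whl
```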
/src/scProjection/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/vaeBackend.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/vaeBackend.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/vaeTrain.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/vaeTrain.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/batchBackend.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/batchBackend.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/deconvModel.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/deconvModel.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/saving_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/saving_utils.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/architectures.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/architectures.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__pycache__/deconvBackend.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quon-titative-biology/scProjection/HEAD/src/scProjection/__pycache__/deconvBackend.cpython-35.pyc -------------------------------------------------------------------------------- /src/scProjection/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | # pylint: disable=unused-import,wildcard-import 6 | from .deconvModel import * 7 | -------------------------------------------------------------------------------- 
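The wildcard import in `__init__.py` above re-exports the contents of `deconvModel.py` as the package's top-level API. For orientation, a minimal Python session might look like the sketch below; this is not a verbatim excerpt: the constructor and keyword names are copied from the R/reticulate walkthrough in the bundled PKG-INFO, and the toy arrays are hypothetical.

```python
import numpy as np
import scProjection

# Hypothetical toy inputs: a labeled single-cell reference and a set of mixture profiles.
component_data  = np.random.rand(300, 500).astype("float32")    # reference cells x genes
component_label = np.array(["H1975", "H2228", "HCC827"] * 100)  # one cell type label per cell
mixture_data    = np.random.rand(50, 500).astype("float32")     # mixture profiles x genes

# Constructor and keyword names mirror the reticulate calls shown in the tutorial below.
model = scProjection.deconvModel(component_data=component_data,
                                 component_label=component_label,
                                 mixture_data=mixture_data)
model.deconvolve(max_steps_component=100,
                 num_latent_dims=64,
                 num_layers=3,
                 hidden_unit_2power=9)

# Per the tutorial, estimated per-mixture cell type proportions and purified expression
# profiles are then collected on model.deconvResults.
```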
/src/scProjection/scProjection.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | pyproject.toml 3 | setup.cfg 4 | setup.py 5 | src/scProjection.egg-info/PKG-INFO 6 | src/scProjection.egg-info/SOURCES.txt 7 | src/scProjection.egg-info/dependency_links.txt 8 | src/scProjection.egg-info/top_level.txt -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = scProjection 3 | version = 0.0.1 4 | author = Nelson Johansen; Gerald Quon 5 | author_email = njjohansen@ucdavis.edu; gquon@ucdavis.edu 6 | description = Projection and Deconvolution using deep hierarchical and generative neural networks. 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = 10 | classifiers = 11 | Programming Language :: Python :: 3 12 | License :: OSI Approved :: MIT License 13 | Operating System :: OS Independent 14 | [options] 15 | package_dir = 16 | = src 17 | packages = find: 18 | python_requires = >=3.6 19 | 20 | [options.packages.find] 21 | where = src 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup(name='scProjection', 4 | version='0.22', 5 | description='Projection and Deconvolution using deep hierarchical and generative neural networks.', 6 | url='https://github.com/ucdavis/quonlab/tree/master/development/deconvAllen', 7 | author='Nelson Johansen, Gerald Quon', 8 | author_email='njjohansen@ucdavis.edu, gquon@ucdavis.edu', 9 | license='MIT', 10 | classifiers=[ 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: MIT License", 13 | "Operating System :: OS Independent", 14 | ], 15 | install_requires=[ 16 | 'tensorflow', 17 | 'tensorflow_probability', 18 | 'scikit-learn', 19 | 'numpy' 20 | ], 21 | package_dir={"": "src"}, 22 | packages=setuptools.find_packages(where="src"), 23 | python_requires=">=3.6") 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scProjection 2 | 3 | Projecting RNA measurements onto single cell atlases to extract cell type-specific expression profiles using scProjection. Refer to our paper: https://www.nature.com/articles/s41467-023-40744-6 4 | 5 | ## Tutorials 6 | 7 | Before starting the tutorials, first follow the install instructions at the bottom of this page. 8 | 9 | [Tutorial 1: Deconvolution of CellBench mixtures](https://github.com/quon-titative-biology/examples/blob/master/scProjection_cellbench/scProjection_cellbench.md) 10 | 11 | [Tutorial 2: Deconvolution of spatial MERFISH data](https://github.com/quon-titative-biology/examples/blob/master/scProjection_spatial/MERFISH_deconv_example.md) 12 | 13 | [Tutorial 3: Projection of pseudo bulk data](https://github.com/quon-titative-biology/examples/tree/master/scProjection_pseudobulk/readme.md) 14 | 15 | [Tutorial 4: Imputation of gene expression patterns of spatial osmFISH data](https://github.com/quon-titative-biology/examples/blob/master/scProjection_imputation/readme.md) 16 | 17 | ## Install scProjection 18 | ```shell 19 | pip3 install scProjection 20 | ``` 21 | The install time should be less than 30 minutes. 22 | ## Package requirements 23 | 24 | scProjection requires Python 3. The following is a guide to installing Python on different operating systems. 25 | 26 | ### (Python) 27 | #### All platforms: 28 | 1. [Download install binaries for Python 3 here](https://www.python.org/downloads/release/) 29 | #### Alternative (On Windows): 30 | 1. Download Python 3 31 | 2. Make sure pip is included in the installation. 32 | 33 | #### Alternative (On Ubuntu): 34 | 1. sudo apt update 35 | 2. sudo apt install python3-dev python3-pip 36 | 37 | #### Alternative (On MacOS): 38 | 1. /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 39 | 2. export PATH="/usr/local/bin:/usr/local/sbin:$PATH" 40 | 3. brew update 41 | 4. brew install python # Python 3 42 | 43 | ## Setup of virtualenv 44 | 45 | scProjection also requires tensorflow, tensorflow-probability, scikit-learn and numpy. It is generally easier to set up the dependencies inside a virtual environment, which can be done as follows: 46 | 47 | ```shell 48 | ## Create the virtual environment 49 | virtualenv -p python3 pyvTf2 50 | 51 | ## Launch the virtual environment 52 | source ./pyvTf2/bin/activate 53 | 54 | ## Setup dependencies 55 | pip3 install tensorflow 56 | pip3 install tensorflow-probability 57 | pip3 install scikit-learn 58 | pip3 install numpy 59 | 60 | ## Install scProjection 61 | pip3 install scProjection 62 | ``` 63 | 64 | ## Updates 65 | #### (3/16/2023) More tutorials have been added. 66 | #### (5/23/2022) Codebase from publication made public. Need to improve user interface with method. 67 | #### (11/9/2022) Added more tutorials with examples running scProjection in both R and Python. 68 | -------------------------------------------------------------------------------- /src/scProjection/saving_utils.py: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import sys, os, glob, gzip, time 7 | from functools import partial 8 | from importlib import import_module 9 | 10 | ## Math 11 | import numpy as np 12 | 13 | ## tensorflow imports 14 | import tensorflow.compat.v1 as tf 15 | tf.disable_v2_behavior() 16 | import tensorflow_probability as tfp 17 | from tensorflow.python.platform import app 18 | from tensorflow.python.platform import flags 19 | 20 | # def save_mixtures(self, sess, resultsObj, FLAGS, VAE, mixture_input_data, datatype, batch_size=100): 21 | # batch_size = np.amin((batch_size, mixture_input_data.shape[0])) 22 | # results=[] 23 | # ## Loop through bulk in batches 24 | # for i in range(0, mixture_input_data.shape[0], batch_size): 25 | # end_itr = np.amin((i + batch_size, mixture_input_data.shape[0])) 26 | # data_batch = mixture_input_data[i:end_itr,:] 27 | # ## Compute mixture from weighted components 28 | # for index, value in enumerate(self.current_celltypes, 0): 29 | # with tf.variable_scope(value, tf.AUTO_REUSE): 30 | # emb = sess.run(VAE[value].emb_saver_sample, feed_dict={VAE[value].fdata_ph: data_batch}) 31 | # purified = sess.run(VAE[value].rec_sampled_saver, feed_dict={VAE[value].emb_sampled_ph: emb[tf.newaxis,]}) 32 | # ## Component specific contribution to mixture profile 33 | # test = sess.run(tf.gather(self.proportions[:,index,tf.newaxis], list(range(i,end_itr)), axis=0)) 34 | # pure_data += sess.run(tf.multiply(purified, tf.gather(self.proportions[:,index,tf.newaxis], list(range(i,end_itr)), axis=0))) 35 | # results.append(pure_data) 36 | # resultsObj.mixture_saver(datatype, results) 37 | 38 | def save_samples(sess, 39 | resultsObj, 40 | FLAGS, 41 | VAE, 42 | component_labels, 43 | stage="prebatch"): 44 | if not os.path.exists(os.path.abspath(FLAGS.logdir + '/' + stage + '/sampled_results')): 45 | os.makedirs(os.path.abspath(FLAGS.logdir + '/' + stage + '/sampled_results')) 46 | ## Save component and mixture results 47 | for scope in np.unique(component_labels): 48 | print("Saving: "+scope) 49 | VAE[scope].save_sampled(sess=sess, 50 | FLAGS=FLAGS, 51 | resultsObj=resultsObj, 52 | datatype=stage, 53 | scope=scope, 54 | dir=stage+"/sampled") 55 | 56 | def save_results(sess, 57 | resultsObj, 58 | FLAGS, 59 | VAE, 60 | component_labels, 61 | component_data, 62 | hvg_masks, 63 | marker_gene_masks, 64 | stage="component", 65 | component_valid_data=None, 66 | component_valid_labels=None, 67 | mixture_input_data=None): 68 | ## Log a bunch of results! 
69 | if FLAGS.log_results is True: 70 | ## Create results directories 71 | if not os.path.exists(os.path.abspath(FLAGS.logdir + '/' + stage + '/component_train_results')): 72 | os.makedirs(os.path.abspath(FLAGS.logdir + '/' + stage + '/component_train_results')) 73 | if not os.path.exists(os.path.abspath(FLAGS.logdir + '/' + stage + '/component_valid_results')): 74 | os.makedirs(os.path.abspath(FLAGS.logdir + '/' + stage + '/component_valid_results')) 75 | if mixture_input_data is not None: 76 | if not os.path.exists(os.path.abspath(FLAGS.logdir + '/' + stage + '/mixture_results')): 77 | os.makedirs(os.path.abspath(FLAGS.logdir + '/' + stage + '/mixture_results')) 78 | ## Save component and mixture results 79 | for scope in np.unique(component_labels): 80 | with tf.variable_scope(scope, tf.AUTO_REUSE): 81 | print("Saving: "+scope) 82 | comp_data = component_data[np.where(component_labels == scope)[0],:] 83 | ## Component vae will only model hvgs specific to cell type 84 | if mixture_input_data is not None: 85 | if hvg_masks is not None: 86 | comp_data = comp_data[:,np.nonzero(hvg_masks[scope])[0]] 87 | mixture_input_data_scope = mixture_input_data[:,np.nonzero(hvg_masks[scope])[0]] 88 | else: 89 | mixture_input_data_scope = mixture_input_data 90 | 91 | if comp_data.shape[0] > 0: 92 | VAE[scope].save_results(sess=sess, 93 | resultsObj=resultsObj, 94 | datatype="component", 95 | mode="train", 96 | stage=stage, 97 | FLAGS=FLAGS, 98 | data=comp_data, 99 | scope=scope, 100 | dir=stage+"/component_train") 101 | if component_valid_data is not None and component_valid_data.shape[0] > 0: 102 | VAE[scope].save_results(sess=sess, 103 | resultsObj=resultsObj, 104 | datatype="component", 105 | mode="test", 106 | stage=stage, 107 | FLAGS=FLAGS, 108 | data=component_valid_data[np.where(component_valid_labels == scope)[0],:], 109 | scope=scope, 110 | dir=stage+"/component_valid") 111 | if mixture_input_data is not None: 112 | VAE[scope].save_results(sess=sess, 113 | resultsObj=resultsObj, 114 | datatype="purified", 115 | mode="train", 116 | stage=stage, 117 | FLAGS=FLAGS, 118 | data=mixture_input_data_scope, 119 | scope=scope, 120 | dir=stage+"/mixture") 121 | -------------------------------------------------------------------------------- /src/scProjection/deconvBackend.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import sys, os 7 | 8 | import tensorflow.compat.v1 as tf 9 | tf.disable_v2_behavior() 10 | import tensorflow_probability as tfp 11 | from tensorflow.python.framework import constant_op 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import init_ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.ops import clip_ops 16 | from tensorflow.python.ops import math_ops 17 | from tensorflow.python.ops import control_flow_ops 18 | from tensorflow.python.ops import variables as tf_variables 19 | from tensorflow.python.training import optimizer as tf_optimizer 20 | from tensorflow.python.training import training_util 21 | 22 | class deconvModel(object): 23 | def __init__(self, FLAGS, VAE, datasets, input_size, output_size, num_samples, reconstruction=None, loadings=None, marker_gene_masks=None, hvg_masks=None, scope="prop_model"): 24 | ## Params for deconvolution model 25 | self.input_size = input_size 26 | self.output_size = output_size 27 | 
self.num_samples = num_samples 28 | self.batch_size = np.amin((FLAGS.batch_size_mixture, self.num_samples)) ## Batch size should not be larger than dataset size 29 | ## Marker gene mask 30 | self.marker_gene_mask = None 31 | ## Organizing name for cmobined model 32 | self.scope = scope 33 | ## Cell types in model 34 | self.current_celltypes = np.unique(datasets) 35 | self.num_component = len(self.current_celltypes) 36 | ## 37 | self.step = tf.Variable(1, name='prop_step', trainable=False, dtype=tf.int32) 38 | 39 | ## Build combined model 40 | with tf.variable_scope(self.scope, tf.AUTO_REUSE): 41 | self.build_dataset(); 42 | self.build_model(FLAGS); 43 | self.deconvolution(FLAGS, VAE); 44 | ## MSE between reconstruction and measured mixture 45 | self.add_prop_loss(self.data_batch, self.mixture_rec) 46 | ## Create training operation 47 | self.create_train_op(FLAGS) 48 | self.create_diagnostics(FLAGS) 49 | 50 | def build_dataset(self): 51 | with tf.name_scope("mixture_data"): 52 | ## Mixture indexing 53 | self.index = tf.range(start=0, limit=self.num_samples, dtype=tf.int32) 54 | ## Dataset 55 | self.mix_data_ph = tf.placeholder(tf.float32, (None, self.input_size)) 56 | self.mix_dataset = tf.data.Dataset.from_tensor_slices((self.mix_data_ph, self.index)).shuffle(self.num_samples).repeat().batch(self.batch_size) 57 | self.mix_iter_data = self.mix_dataset.make_initializable_iterator() 58 | self.mix_data_batch, self.data_index = self.mix_iter_data.get_next() 59 | 60 | def build_model(self, FLAGS): 61 | with tf.name_scope("mixture_proportions"): 62 | ## Mixture weights 63 | self.mixture_weights = tf.get_variable(name="prop_coef", 64 | shape=[self.num_samples, self.num_component], 65 | dtype=tf.float32, 66 | initializer=tf.zeros_initializer, 67 | regularizer=tf.contrib.layers.l1_regularizer(FLAGS.l1_reg_weight) if (FLAGS.mix_weight_reg is True) else None, 68 | trainable=True) 69 | ## Normalize mixture weights 70 | self.proportions = tf.nn.softmax(self.mixture_weights, axis=-1, name="proportions") 71 | ## Hold results of linear reconstitution of mixture 72 | self.mixture_rec = tf.zeros([self.batch_size, self.output_size], tf.float32) 73 | 74 | def deconvolution(self, FLAGS, VAE): 75 | ## Compute mixture from weighted components 76 | for index, value in enumerate(self.current_celltypes, 0): 77 | #with tf.name_scope(value): 78 | with tf.variable_scope(value, tf.AUTO_REUSE): 79 | data_emb = VAE[value].encoder_func(self.mix_data_batch, is_training=False).sample(FLAGS.num_monte_carlo) 80 | data_rec = tf.reduce_mean(VAE[value].decoder_func(data_emb, is_training=False).mean(), axis=0) 81 | ## TPM specific reconstruction 82 | if FLAGS.tpm_softmax is True: 83 | data_rec = tf.nn.softmax(data_rec, axis=-1) * FLAGS.tpm_scale_factor 84 | ## Component specific contribution to mixture profile 85 | pure_data = tf.multiply(data_rec, tf.gather(self.proportions[:,index,tf.newaxis], self.data_index, axis=0)) 86 | ## Proportion specific reconstructions, could be masked 87 | if VAE[value].marker_mask is not None: 88 | self.mixture_rec = tf.add(self.mixture_rec, tf.multiply(pure_data, VAE[value].marker_mask)) 89 | self.data_batch = tf.multiply(self.mix_data_batch, VAE[value].marker_mask) 90 | else: 91 | self.mixture_rec = tf.add(self.mixture_rec, pure_data) 92 | self.data_batch = self.mix_data_batch 93 | tf.summary.histogram(value + "_weights", self.mixture_weights[:,index]) 94 | tf.summary.histogram(value + "_probs", self.proportions[:,index]) 95 | 96 | def add_prop_loss(self, data, rec): 97 | ## Define decoder 
reconstruction loss 98 | self.mse_mixture = tf.losses.mean_squared_error(data, 99 | rec, 100 | weights=1.0, 101 | scope=self.scope) 102 | tf.summary.scalar("MSE", self.mse_mixture) 103 | 104 | def create_train_op(self, FLAGS): 105 | """Create and return training operation.""" 106 | with tf.name_scope('Optimizer_prop'): 107 | ## Set up learning rate 108 | if FLAGS.decay_lr is True: 109 | with tf.name_scope("learning_rate_prop"): 110 | self.prop_learning_rate = tf.maximum( 111 | tf.train.exponential_decay( 112 | FLAGS.proportion_learning_rate, 113 | self.step, ## Account for previous training 114 | FLAGS.decay_step*10, 115 | FLAGS.decay_rate), 116 | FLAGS.min_learning_rate) 117 | else: 118 | self.prop_learning_rate = FLAGS.proportion_learning_rate 119 | 120 | ## Collect prop component loss 121 | self.proportion_loss = self.mse_mixture 122 | ## Minimize loss function 123 | optimizer = tf.train.AdamOptimizer(self.prop_learning_rate) 124 | self.train_proportion_op = optimizer.minimize(loss=self.proportion_loss, global_step=self.step, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)) 125 | ## Monitor 126 | tf.summary.scalar('Learning_Rate', self.prop_learning_rate) 127 | tf.summary.scalar('Loss_Total', self.proportion_loss) 128 | return(self.train_proportion_op) 129 | 130 | def create_diagnostics(self, FLAGS): 131 | ## Diagnostic 132 | if FLAGS.combined_corr_check is True: 133 | self.corr = tfp.stats.correlation(self.mix_data_batch, self.mixture_rec, sample_axis=-1, event_axis=0) 134 | self.corr = tf.reduce_mean(tf.linalg.tensor_diag_part(self.corr), axis=0) 135 | tf.compat.v1.summary.scalar("Mixture_correlation", self.corr) 136 | -------------------------------------------------------------------------------- /src/scProjection/architectures.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import sys, os 7 | 8 | import tensorflow.compat.v1 as tf 9 | tf.disable_v2_behavior() 10 | import tensorflow_probability as tfp 11 | from tensorflow.python.framework import constant_op 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import init_ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.ops import clip_ops 16 | from tensorflow.python.ops import control_flow_ops 17 | from tensorflow.python.ops import variables as tf_variables 18 | from tensorflow.python.platform import tf_logging as logging 19 | from tensorflow.python.summary import summary 20 | from tensorflow.python.training import monitored_session 21 | from tensorflow.python.training import optimizer as tf_optimizer 22 | from tensorflow.python.training import training_util 23 | 24 | ## Simple prior N(0,I) 25 | def vae_prior(ndim, 26 | scope): 27 | with tf.name_scope('prior_Z'): 28 | return tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(ndim), 29 | scale_diag=tf.ones(ndim), 30 | allow_nan_stats=False, 31 | validate_args=True, 32 | name="normal_prior_Z") 33 | 34 | ## Approximation: q(z|x) 35 | def vae_encoder(inputs, 36 | ndim, 37 | FLAGS, 38 | scope, 39 | num_layers=3, ## This number excludes the embedding layer 40 | hidden_unit_2power=9, ## Specify the maximal number (2^X) hidden units. 
41 | l2_weight=1e-4, 42 | dropout_rate=0.3, 43 | batch_norm=True, 44 | batch_renorm=False, 45 | is_training=True): 46 | with tf.name_scope('vae_encoder'): 47 | inputs = tf.cast(inputs, tf.float32) 48 | encoder = tf.layers.dense(inputs=inputs, 49 | units=pow(2,hidden_unit_2power), 50 | activation=tf.nn.relu, 51 | kernel_initializer=tf.glorot_uniform_initializer(), 52 | kernel_regularizer=tf.keras.regularizers.l2(l2_weight), 53 | use_bias=True, 54 | bias_initializer=init_ops.zeros_initializer(), 55 | name='fc1') 56 | if batch_norm: encoder = tf.layers.batch_normalization(encoder, training=is_training, renorm=batch_renorm, name='batch_norm_1') 57 | encoder = tf.layers.dropout(inputs=encoder, rate=dropout_rate, training=is_training, name='drop1') 58 | ## Add additional layers while accounting for input layer 59 | for layer in range(1,num_layers): 60 | encoder = tf.layers.dense(inputs=encoder, 61 | units=pow(2,hidden_unit_2power-layer), 62 | activation=tf.nn.relu, 63 | kernel_initializer=tf.glorot_uniform_initializer(), 64 | kernel_regularizer=tf.keras.regularizers.l2(l2_weight), 65 | use_bias=True, 66 | bias_initializer=init_ops.zeros_initializer(), 67 | name='fc'+str(layer+1)) 68 | if batch_norm: encoder = tf.layers.batch_normalization(encoder, training=is_training, renorm=batch_renorm, name='batch_norm_'+str(layer+1)) 69 | encoder = tf.layers.dropout(inputs=encoder, rate=dropout_rate, training=is_training, name='drop'+str(layer+1)) 70 | ## Embedding layer has user defined number of hidden units (ndim) 71 | encoder = tf.layers.dense(inputs=encoder, 72 | units=2*ndim, 73 | activation=None, 74 | kernel_initializer=tf.glorot_uniform_initializer(), 75 | kernel_regularizer=None, 76 | use_bias=True, 77 | bias_initializer=init_ops.zeros_initializer(), 78 | name='fc'+str(num_layers+1)) 79 | encoder = tfp.distributions.MultivariateNormalDiag( 80 | loc=encoder[..., :ndim], 81 | scale_diag=tf.nn.softplus(encoder[..., ndim:])+1e-8, ## Changed tf.nn.softplas due to numeric issues with large variances (Potentially unused dimensions) 82 | allow_nan_stats=False, 83 | validate_args=False, 84 | name="latent_Z") 85 | return encoder 86 | 87 | ## Reconstruction: p(x|z) 88 | def vae_decoder(inputs, 89 | output_size, 90 | FLAGS, 91 | scope, 92 | l2_weight=1e-4, 93 | hidden_unit_2power=9, 94 | num_layers=3, 95 | dropout_rate=0.3, 96 | batch_norm=True, 97 | batch_renorm=False, 98 | is_training=True): 99 | with tf.name_scope('vae_decoder'): 100 | inputs = tf.cast(inputs, tf.float32) 101 | decoder = tf.layers.dense(inputs=inputs, 102 | units=pow(2,(hidden_unit_2power-num_layers)+1), 103 | activation=tf.nn.relu, 104 | kernel_initializer=tf.glorot_uniform_initializer(), 105 | kernel_regularizer=tf.keras.regularizers.l2(l2_weight), 106 | use_bias=True, 107 | bias_initializer=init_ops.zeros_initializer(), 108 | name='fc1') 109 | if batch_norm: decoder = tf.layers.batch_normalization(decoder, training=is_training, renorm=batch_renorm, name='batch_norm_1') 110 | decoder = tf.layers.dropout(inputs=decoder, rate=dropout_rate, training=is_training, name='drop1') 111 | ## Add additional layers while accounting for input layer 112 | for layer in range(num_layers-2,-1,-1): 113 | decoder = tf.layers.dense(inputs=decoder, 114 | units=pow(2,hidden_unit_2power-layer), 115 | activation=tf.nn.relu, 116 | kernel_initializer=tf.glorot_uniform_initializer(), 117 | kernel_regularizer=tf.keras.regularizers.l2(l2_weight), 118 | use_bias=True, 119 | bias_initializer=init_ops.zeros_initializer(), 120 | name='fc'+str(num_layers-layer)) 121 | if 
batch_norm: decoder = tf.layers.batch_normalization(decoder, training=is_training, renorm=batch_renorm, name='batch_norm_'+str(num_layers-layer)) 122 | decoder = tf.layers.dropout(inputs=decoder, rate=dropout_rate, training=is_training, name='drop'+str(num_layers-layer)) 123 | decoder_mean = tf.layers.dense(inputs=decoder, 124 | units=output_size, 125 | activation=None, 126 | kernel_initializer=tf.glorot_uniform_initializer(), 127 | kernel_regularizer=None, 128 | use_bias=True, 129 | bias_initializer=init_ops.zeros_initializer(), 130 | name='fc'+str(num_layers+1)) 131 | if FLAGS.decoder_variance == "per_sample": 132 | decoder_var = tf.layers.dense(inputs=decoder, 133 | units=1, 134 | activation=tf.nn.relu, 135 | kernel_initializer=tf.glorot_uniform_initializer(), 136 | kernel_regularizer=None, 137 | use_bias=True, 138 | bias_initializer=init_ops.zeros_initializer(), 139 | name='fc'+str(num_layers+1)+"_var") 140 | decoder = tfp.distributions.MultivariateNormalDiag( 141 | loc=decoder_mean, 142 | scale_diag=(tf.ones(output_size)*tf.nn.softplus(decoder_var)+1e-8), 143 | allow_nan_stats=False, 144 | validate_args=False, 145 | name="reconstructed_cell") 146 | else: 147 | decoder_var = tf.ones(1) 148 | decoder = tfp.distributions.MultivariateNormalDiag( 149 | loc=decoder_mean, 150 | scale_diag=tf.ones(output_size), 151 | allow_nan_stats=False, 152 | validate_args=False, 153 | name="reconstructed_cell") 154 | return decoder #, decoder_var 155 | -------------------------------------------------------------------------------- /src/scProjection/utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import sys, os, glob, gzip, time 7 | from functools import partial 8 | from importlib import import_module 9 | 10 | ## Math 11 | import numpy as np 12 | import scipy as sp 13 | 14 | ## tensorflow imports 15 | import tensorflow.compat.v1 as tf 16 | tf.disable_v2_behavior() 17 | import tensorflow_probability as tfp 18 | from tensorflow.python.platform import app 19 | from tensorflow.python.platform import flags 20 | 21 | ## Record evaluation metrics 22 | class modelMetrics(object): 23 | def __init__(self, mode): 24 | self.step = {'train': [], 'test': [], 'train-post': []} 25 | self.loss = {'train': [], 'test': [], 'train-post': []} 26 | self.mse = {'train': [], 'test': [], 'train-post': []} 27 | if mode is "VAE": 28 | self.log_probability = {'train': [], 'test': [], 'train-post': []} 29 | self.kl_divergence = {'train': [], 'test': [], 'train-post': []} 30 | 31 | ## Record results and evaluation metrics 32 | class deconvResult(object): 33 | def __init__(self, celltypeNames): 34 | self.proportions = {} 35 | self.weights = {} 36 | ## Reconstructions and embeddings 37 | self.deconv_data = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 38 | self.deconv_emb = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 39 | self.deconv_logp = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 40 | self.deconv_var = {'component': {'train': 
dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 41 | #self.deconv_inv_data = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 42 | ## Reconstructions and embeddings after correction 43 | self.deconv_data_post = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 44 | self.deconv_emb_post = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 45 | self.deconv_logp_post = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 46 | self.deconv_var_post = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 47 | #self.deconv_inv_data_post = {'component': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}, 'purified': {'train': dict.fromkeys(celltypeNames), 'test': dict.fromkeys(celltypeNames)}} 48 | ## Sampling 49 | self.deconv_samples_emb = {'prebatch': dict.fromkeys(celltypeNames), 'postbatch': dict.fromkeys(celltypeNames)} 50 | self.deconv_samples_rec = {'prebatch': dict.fromkeys(celltypeNames), 'postbatch': dict.fromkeys(celltypeNames)} 51 | ## Mixture reconstruction 52 | self.deconv_mixtures = {'prebatch': None, 'postbatch': None} 53 | ## Train/test metrics 54 | self.component_metrics = dict([(key, modelMetrics(mode = "VAE")) for key in celltypeNames]) 55 | self.proportion_metrics = modelMetrics(mode = "Mixture") 56 | self.mixture_metrics = modelMetrics(mode = "Mixture") 57 | ## Model details 58 | self.celltypes = celltypeNames 59 | self.results_dir = None 60 | self.flags = None 61 | ## mphate specifics 62 | self.m_phate_logging = dict([(key, []) for key in celltypeNames]) 63 | 64 | def update_component_metrics(self, mode, scope, step, loss, mse, logp, kl): 65 | self.component_metrics[scope].step[mode].append(step) 66 | self.component_metrics[scope].loss[mode].append(loss) 67 | self.component_metrics[scope].mse[mode].append(mse) 68 | self.component_metrics[scope].log_probability[mode].append(logp) 69 | self.component_metrics[scope].kl_divergence[mode].append(kl) 70 | 71 | def update_mixture_metrics(self, mode, loss, mse): 72 | self.proportion_metrics.loss[mode].append(loss) 73 | self.proportion_metrics.mse[mode].append(mse) 74 | 75 | def update_mixture_metrics(self, mode, loss, mse): 76 | self.mixture_metrics.loss[mode].append(loss) 77 | self.mixture_metrics.mse[mode].append(mse) 78 | 79 | def update_proportions(self, step, props): 80 | self.proportions[str(step)] = props 81 | #self.weights[str(step)] = weights 82 | 83 | def update_mphate(self, scope, embedding): 84 | self.m_phate_logging[scope].append(embedding) 85 | 86 | def update_saver(self, datatype, mode, stage, scope, rec_data, embedding, log_prob, rec_var): 87 | if stage == "component": 88 | self.deconv_data[datatype][mode][scope] = rec_data 89 | self.deconv_emb[datatype][mode][scope] = embedding 90 | self.deconv_logp[datatype][mode][scope] = 
log_prob 91 | self.deconv_var[datatype][mode][scope] = rec_var 92 | elif stage == "combined": 93 | self.deconv_data_post[datatype][mode][scope] = rec_data 94 | self.deconv_emb_post[datatype][mode][scope] = embedding 95 | self.deconv_logp_post[datatype][mode][scope] = log_prob 96 | self.deconv_var_post[datatype][mode][scope] = rec_var 97 | else: 98 | print("Unknown stage") 99 | 100 | def sample_saver(self, datatype, scope, emb_samples, rec_samples): 101 | self.deconv_samples_emb[datatype][scope] = emb_samples 102 | self.deconv_samples_rec[datatype][scope] = rec_samples 103 | 104 | def mixture_saver(self, datatype, mixture_rec): 105 | self.deconv_mixtures[datatype] = mixture_rec 106 | 107 | def __repr__(self): 108 | return "Results of deconvolution" 109 | 110 | def initialize_model_dir(FLAGS): 111 | print("Running training method: " + FLAGS.training_method) 112 | ## Ensure save path exists 113 | FLAGS.logdir = os.path.abspath(FLAGS.logdir) 114 | ## Setup save path 115 | FLAGS.logdir = FLAGS.logdir + "/model_" + FLAGS.model 116 | print(FLAGS.logdir) 117 | ## Make path 118 | if not os.path.exists(os.path.abspath(FLAGS.logdir)): 119 | os.makedirs(os.path.abspath(FLAGS.logdir)) 120 | return FLAGS 121 | 122 | ## Saves flags to a file 123 | def write_flags(FLAGS): 124 | flag_dict = tf.app.flags.FLAGS.flag_values_dict() 125 | with open(FLAGS.logdir + '/run_flags.txt', 'w+') as file_out: 126 | [file_out.write('{0}\t{1}\n'.format(key, value)) for key, value in flag_dict.items()] 127 | 128 | def summary_and_checkpoint(FLAGS, graph): 129 | ## Tensorboard and model checkpointing 130 | summary_op = tf.summary.merge_all() 131 | ## Write summaries 132 | summary_writer = tf.summary.FileWriter(os.path.abspath(FLAGS.logdir) + "/train", graph, filename_suffix='-train') 133 | summary_writer_test = tf.summary.FileWriter(os.path.abspath(FLAGS.logdir) + "/test", filename_suffix='-test') 134 | ## Model saving every X steps or Y hours 135 | saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1.0) 136 | return summary_op, summary_writer, summary_writer_test, saver 137 | 138 | ## Setup masking and annotations 139 | def setup_annotations_masks(component_label, marker_gene_mask, hvg_mask, max_steps_component, celltypes): 140 | ## Convert old format to new dictionary for celltype labels from reference 141 | if type(component_label) is not dict: 142 | print("Converting component label list to dictionary") 143 | component_labels = {'standard_deconv' : component_label} 144 | else: 145 | component_labels = component_label 146 | 147 | ## Convert old format to dictionary for marker gene mask which is shared across cell types 148 | if type(marker_gene_mask) is not dict and marker_gene_mask is not None: 149 | print("Converting marker mask to dictionary") 150 | marker_gene_masks = {} 151 | for scope in celltypes: 152 | marker_gene_masks.update( {scope : marker_gene_mask} ) 153 | else: 154 | marker_gene_masks = marker_gene_mask 155 | 156 | ## Convert old format to dictionary for marker gene mask which is shared across cell types 157 | if type(hvg_mask) is not dict and hvg_mask is not None: 158 | print("Converting hvg mask to dictionary") 159 | hvg_masks = {} 160 | for scope in celltypes: 161 | hvg_masks.update( {scope : hvg_mask} ) 162 | else: 163 | hvg_masks = hvg_mask 164 | 165 | ## Convert old format to dictionary for marker gene mask which is shared across cell types 166 | if type(max_steps_component) is not dict and max_steps_component is not None: 167 | print("Converting hvg mask to dictionary") 168 | 
max_steps_components = {} 169 | for scope in celltypes: 170 | max_steps_components.update( {scope : max_steps_component} ) 171 | else: 172 | max_steps_components = max_steps_component 173 | 174 | return component_labels['standard_deconv'], marker_gene_masks, hvg_masks, max_steps_components 175 | 176 | def early_stopping(FLAGS, VAE, scope, corr): 177 | ## Early stopping if enabled and minimum number of steps has been achieved 178 | if corr > 0.85: 179 | VAE[scope].patience = VAE[scope].patience + 1 180 | elif corr > 0.90: 181 | VAE[scope].patience = VAE[scope].patience + 10 182 | else: 183 | VAE[scope].patience = VAE[scope].patience - 1 184 | 185 | if VAE[scope].patience >= FLAGS.max_patience: 186 | print("Early stopping triggered for: ", str(scope) + " Corr: " + str(corr)) 187 | VAE[scope].early_stop = True 188 | 189 | def define_hardware(FLAGS): 190 | config = tf.ConfigProto() 191 | config.gpu_options.allow_growth = True 192 | config.log_device_placement = True 193 | config.allow_soft_placement = True 194 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 195 | os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.cuda_device 196 | return config 197 | -------------------------------------------------------------------------------- /src/scProjection/batchBackend.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import sys, os 7 | 8 | import tensorflow.compat.v1 as tf 9 | tf.disable_v2_behavior() 10 | import tensorflow_probability as tfp 11 | from tensorflow.python.framework import constant_op 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import init_ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.ops import clip_ops 16 | from tensorflow.python.ops import math_ops 17 | from tensorflow.python.ops import control_flow_ops 18 | from tensorflow.python.ops import variables as tf_variables 19 | from tensorflow.python.training import optimizer as tf_optimizer 20 | from tensorflow.python.training import training_util 21 | 22 | class batchModel(object): 23 | def __init__(self, FLAGS, VAE, datasets, input_size, output_size, num_samples, proportions, mixture_weights, reconstruction=None, loadings=None, VAE_parameters=None, scope="batch_correction"): 24 | ## Params for deconvolution model 25 | self.input_size = input_size 26 | self.output_size = output_size 27 | self.num_samples = num_samples 28 | self.batch_size = np.amin((FLAGS.batch_size_mixture, self.num_samples)) ## Batch size should not be larger than dataset size 29 | ## Organizing name for cmobined model 30 | self.scope = scope 31 | ## Cell types in model 32 | self.current_celltypes = np.unique(datasets) 33 | self.num_component = len(self.current_celltypes) 34 | ## Normalize mixture weights 35 | self.mixture_weights = mixture_weights 36 | self.proportions = proportions 37 | 38 | ## Combined model steps 39 | self.step = tf.Variable(1, name='batch_step', trainable=False, dtype=tf.int32) 40 | 41 | ## Build combined model 42 | with tf.variable_scope(self.scope, tf.AUTO_REUSE): 43 | ## Create parameters and metadata for batch correction 44 | self.build_dataset(); 45 | self.build_model(); 46 | self.batch_correction(FLAGS, VAE); 47 | ## MSE between reconstruction and measured mixture during proportion est. 
or batch correction 48 | self.add_loss(); 49 | ## Record VAE trainable parameters 50 | self.trainable_params += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "bulk_distribution") 51 | 52 | # ## Corr 53 | # if FLAGS.combined_corr_check is True: 54 | # self.corr = tfp.stats.correlation(self.mix_data_batch, self.mixture_rec, sample_axis=-1, event_axis=0) 55 | # self.corr = tf.reduce_mean(tf.linalg.tensor_diag_part(self.corr), axis=0) 56 | # tf.compat.v1.summary.scalar("Mixture_correlation", self.corr) 57 | 58 | ## Build the optimizer 59 | self.create_train_op(FLAGS) 60 | 61 | def build_dataset(self): 62 | with tf.name_scope("mixture_data"): 63 | ## Mixture indexing 64 | self.index = tf.range(start=0, limit=self.num_samples, dtype=tf.int32) 65 | ## Dataset 66 | self.mix_data_ph = tf.placeholder(tf.float32, (None, self.input_size)) 67 | self.mix_dataset = tf.data.Dataset.from_tensor_slices((self.mix_data_ph, self.index)).shuffle(self.num_samples).repeat().batch(self.batch_size) 68 | self.mix_iter_data = self.mix_dataset.make_initializable_iterator() 69 | self.mix_data_batch, self.data_index = self.mix_iter_data.get_next() 70 | 71 | def build_model(self): 72 | ## Hold results of linear reconstitution of mixture 73 | self.mixture_rec = tf.zeros([self.batch_size, self.output_size], tf.float32) 74 | ## Batch correction loss 75 | self.marker_mse = tf.constant(0., tf.float32) 76 | self.non_marker_mse = tf.constant(0., tf.float32) 77 | self.KL_vae = tf.constant(0., tf.float32) 78 | self.elbo_vae = tf.constant(0., tf.float32) 79 | self.regularization_vae = tf.constant(0., tf.float32) 80 | self.regularization_tied = tf.constant(0., tf.float32) 81 | 82 | def batch_correction(self, FLAGS, VAE): 83 | ## Compute mixture from weighted components 84 | for index, value in enumerate(self.current_celltypes, 0): 85 | with tf.variable_scope(value, tf.AUTO_REUSE): 86 | self.data_emb = VAE[value].encoder_func(self.mix_data_batch, is_training=False).sample(FLAGS.num_monte_carlo) 87 | self.data_dist = VAE[value].decoder_func(self.data_emb, is_training=False) 88 | self.data_rec = tf.reduce_mean(self.data_dist.mean(), axis=0) 89 | 90 | with tf.name_scope("LinearCombo"): 91 | ## Component specific contribution to mixture profile 92 | VAE[value].pure_data = tf.multiply(self.data_rec, tf.gather(self.proportions[:,index,tf.newaxis], self.data_index, axis=0)) 93 | ## Reconstruct mixture celltype-wise 94 | self.mixture_rec = tf.add(self.mixture_rec, VAE[value].pure_data) 95 | ## Record metrics 96 | # if VAE[value].marker_mask is not None: 97 | # self.marker_mse += VAE[value].marker_mse 98 | # self.non_marker_mse += VAE[value].non_marker_mse 99 | self.KL_vae += (VAE[value].avg_KL_div * VAE[value].KL_weight) 100 | self.elbo_vae += VAE[value].elbo 101 | self.regularization_vae += VAE[value].l2_reg 102 | #self.regularization_tied += VAE[value].tied_model_l2 103 | tf.compat.v1.summary.scalar('KL_'+ VAE[value].scope, self.KL_vae) 104 | tf.compat.v1.summary.scalar('ELBO_'+VAE[value].scope, -self.elbo_vae) 105 | tf.compat.v1.summary.scalar('L2_REG_'+ VAE[value].scope, self.regularization_vae) 106 | 107 | ## Gather VAE variances for bulk variance calculation 108 | try: 109 | #self.variances_vae += [tf.ones((10,100))] 110 | self.trainable_params += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, VAE[value].scope) 111 | except AttributeError: 112 | #self.variances_vae = [tf.ones((10,100))] 113 | self.trainable_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, VAE[value].scope) 114 | #tf.summary.histogram(value + 
"_weights", self.mixture_weights[:,index]) 115 | # tf.summary.histogram(value + "_probs", self.proportions[:,index]) 116 | 117 | ## Define bulk distribution for proportion estimation 118 | with tf.name_scope("bulk_distribution"): 119 | # if FLAGS.decoder_variance == "per_sample": 120 | # merged_vae_vars = tf.concat(self.variances_vae, axis=0) 121 | # merged_vae_vars = tf.tile(merged_vae_vars[tf.newaxis,], [self.batch_size,1]) 122 | # else: 123 | # merged_vae_vars = tf.stack(self.variances_vae, axis=-1) 124 | # merged_vae_vars = tf.tile(merged_vae_vars, [self.batch_size,1]) 125 | # print(merged_vae_vars) 126 | # print(merged_vae_vars) 127 | gathered_probs = tf.gather(self.proportions, self.data_index, axis=0) 128 | #print(gathered_probs) 129 | # combined_var_input = tf.concat([gathered_probs, merged_vae_vars], axis=1) 130 | # print(combined_var_input) 131 | # combined_var_input.set_shape((self.batch_size, self.num_component+(self.num_component*self.batch_size))) 132 | # print(combined_var_input) 133 | self.dist_var = tf.layers.dense(inputs=gathered_probs, #combined_var_input, #tf.tile(tf.reshape(merged_vae_vars, [-1])[tf.newaxis,], [self.batch_size,1])], axis=1), #[tf.newaxis,] 134 | units=1, 135 | activation=None, 136 | kernel_initializer=tf.glorot_uniform_initializer(), 137 | kernel_regularizer=tf.nn.relu, 138 | use_bias=True, 139 | bias_initializer=init_ops.zeros_initializer(), 140 | name='fc_var') 141 | self.mixture_rec_dist = tfp.distributions.MultivariateNormalDiag( 142 | loc=self.mixture_rec, 143 | scale_diag=(tf.ones(self.output_size)*tf.nn.softplus(self.dist_var)+1e-8), 144 | allow_nan_stats=False, 145 | validate_args=True, 146 | name="reconstructed_cell_dist") 147 | ## summaries 148 | # tf.compat.v1.summary.histogram("prop_bulk_dist_mean", self.mixture_rec_dist.mean()) 149 | # tf.compat.v1.summary.histogram("prop_bulk_dist_var", self.mixture_rec_dist.variance()) 150 | 151 | ## Bulk mse loss 152 | def add_loss(self): 153 | ## Log probability 154 | self.batch_loss = -tf.reduce_mean(self.mixture_rec_dist.log_prob(self.mix_data_batch)) 155 | tf.summary.scalar("batch_loss", self.batch_loss) 156 | 157 | ## Create operating to adjust VAEs with fixed proportions 158 | def create_train_op(self, FLAGS): 159 | """Create and return training operation.""" 160 | with tf.name_scope('Optimizer_batch'): 161 | ## Set up learning rate 162 | if FLAGS.decay_lr is True: 163 | with tf.name_scope("learning_rate_proportion"): 164 | self.batch_learning_rate = tf.maximum( 165 | tf.train.exponential_decay( 166 | FLAGS.combined_learning_rate, 167 | self.step, 168 | FLAGS.decay_step, 169 | FLAGS.decay_rate), 170 | FLAGS.min_learning_rate) 171 | else: 172 | self.batch_learning_rate = FLAGS.combined_learning_rate 173 | 174 | ## Minimize loss function 175 | optimizer = tf.train.AdamOptimizer(self.batch_learning_rate) 176 | 177 | ## Just optimizing proportions based on marker expr. recontruction of bulk 178 | self.train_batch_loss = self.batch_loss + (-self.elbo_vae) + (self.KL_vae) + tf.reduce_sum(self.regularization_vae) 179 | if FLAGS.tied_model_l2 is True: 180 | self.train_batch_loss += tf.reduce_sum(self.regularization_tied) 181 | 182 | ## If markers are provided then we can include marker vs. 
non-marker mse terms at different magnitudes 183 | # if FLAGS.marker_weight > 1: 184 | # self.train_batch_loss += self.marker_mse + self.non_marker_mse 185 | 186 | ## Batch correction optimization 187 | self.train_batch_op = optimizer.minimize(loss=self.train_batch_loss, global_step=self.step, var_list=self.trainable_params) 188 | 189 | ## Monitor 190 | tf.summary.scalar('Loss_Total', self.train_batch_loss) 191 | -------------------------------------------------------------------------------- /src/scProjection/scProjection.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: scProjection 3 | Version: 0.15 4 | Summary: Projection and Deconvolution using deep heirarchical and generative neural network. 5 | Home-page: https://github.com/ucdavis/quonlab/tree/master/development/deconvAllen 6 | Author: Nelson Johansen, Gerald Quon 7 | Author-email: njjohansen@ucdavis.edu, gquon@ucdavis.edu 8 | License: MIT 9 | Description: # Tutorial: Deconvolution of CellBench human lung adenocarcinoma cell line mixtures. 10 | 11 | This tutorial provides a guided deconvolutions of CellBench mixtures sequenced with CEL-Seq2 or SORT-Seq using a matching single cell RNA dataset sequenced with CEL-Seq. 12 | 13 | ## Setup Python (compatible with cluster environment) 14 | 15 | First we must set up a python virtual environment. Move to the directory where you want to set up the environment, then use virtualenv command. **All of the following commands in this section are performed on the command line:** 16 | 17 |
 virtualenv -p /usr/bin/python3 environment_name
 18 |         
19 | 20 | Now you can freely install any packages using pip install in the activated virtual environment. 21 | 22 |
 pip install package_name
 23 |         
24 | 25 | After setting up the environment, you need to activate it. Remember to activate the relevant environment every time you work on the project. 26 | 27 |
 source environment_name/bin/activate
 28 |         
29 | 30 | Once you are in the virtual environment, install tensorflow version 1.15 and some other pre-reqs: 31 | 32 |
pip3 install tensorflow==1.15rc2
 33 |         pip3 install tfp-nightly
 34 |         pip3 install scikit-learn
 35 |         
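Before moving on, a quick sanity check of those pins from the activated virtualenv can save debugging later (a sketch, not part of the original tutorial):

```python
# The tutorial targets the TF 1.x graph-mode stack used throughout this codebase.
import tensorflow as tf
import tensorflow_probability as tfp

print(tf.__version__)   # expected: 1.15.x for this walkthrough
print(tfp.__version__)
```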
36 | 37 | Now we need to install the deconvolution package: 38 |
 pip3 install /share/quonlab/software/deconvolution/deconv
 39 |         
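It can also help to confirm the deconvolution package imports from the same interpreter before switching to R (again a sketch; `deconv` is the older internal package name used throughout this tutorial, while current releases install as `scProjection`):

```python
# Verify the package is importable from the virtualenv's Python.
import deconv

print(deconv.deconvModel)  # the entry point accessed below from R as deconv$deconvModel
```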
40 | 41 | The remaining sections of this tutorial will be in R. 42 | 43 | ## Deconvolution goals 44 | The following is a walkthrough of `` and has been designed to provide an overview of data preprocessing, deconvolution and visualization of standard outputs. Here, our primary goals include: 45 | 46 | 1. Preprocess both the single cell dataset and mixture dataset in `R`. 47 | 2. Train `` with and without marker genes. 48 | 3. Visualize cell type proportions and cell type-specific expression profiles for each mixture profile. 49 | 50 | ## Data preprocessing 51 | The data matrices for this tutorial can be found at `/share/quonlab/wkdir/njjohans/public/cellbench/`. 52 | 53 | First, we perform standard scRNA preprocessing steps using the `Seurat` package. After preprocessing we reduce to the top 2,000 highly variable genes and identify marker genes for the cell types in our single cell data. 54 | 55 | ```R 56 | library(Seurat) 57 | options(stringsAsFactors=F) 58 | 59 | working.dir = '/share/quonlab/wkdir/njjohans/public/cellbench/' 60 | 61 | load(paste0(working.dir, 'raw_data/sincell_with_class.RData')) 62 | load(paste0(working.dir, 'raw_data/mRNAmix_qc.RData')) 63 | 64 | ## Mixture data 65 | colnames(sce2_qc) = paste0("CELMix-", colnames(sce2_qc)) 66 | colnames(sce8_qc) = paste0("SORTMix-", colnames(sce8_qc)) 67 | 68 | ## Individual data 69 | colnames(sce_sc_CELseq2_qc) = paste0("CEL-", colnames(sce_sc_CELseq2_qc)) 70 | 71 | ## Reduce to the common genes (rows) 72 | common.genes = Reduce(intersect, list(rownames(counts(sce2_qc)), 73 | rownames(counts(sce8_qc)), 74 | rownames(sce_sc_CELseq2_qc))) 75 | 76 | ## Combine data (this is optional). Single cell and mixture data can also be normalized separately. 77 | cellbench_data = cbind(counts(sce2_qc)[common.genes,], 78 | counts(sce8_qc)[common.genes,], 79 | counts(sce_sc_CELseq2_qc)[common.genes,]) 80 | 81 | ################################################################################ 82 | ## Seurat combined normalization 83 | ################################################################################ 84 | cellbenchSeuratObj <- CreateSeuratObject(counts = cellbench_data, project = "DECONV", min.cells = 1) 85 | ## Batch annotation 86 | cellbenchSeuratObj@meta.data$platform = as.factor(c(rep('CELMix-seq', ncol(counts(sce2_qc))), 87 | rep('SORTMix-seq', ncol(counts(sce8_qc))), 88 | rep('CEL-seq', ncol(counts(sce_sc_CELseq2_qc))))) 89 | ## Cell type annotation 90 | cellbenchSeuratObj@meta.data$cell.type = as.factor(c(sce2_qc@colData$mix, 91 | sce8_qc@colData$mix, 92 | sce_sc_CELseq2_qc@colData$cell_line)) 93 | ## Important that log transformation is not performed during this step. 94 | cellbenchSeuratObj <- NormalizeData(cellbenchSeuratObj, normalization.method="RC", scale.factor=1e4) 95 | cellbenchSeuratObj <- ScaleData(cellbenchSeuratObj, do.scale=T, do.center=T, display.progress = T) 96 | cellbenchSeuratObj <- FindVariableFeatures(cellbenchSeuratObj, nfeatures=2000) 97 | ``` 98 | 99 | ## Marker gene selection 100 | Specification of marker genes is essential for successful deconvolution. Marker gene lists from external sources such as previous studies on the tissue of interest or methods including CIBERSORTx are typically the best starting point. Otherwise, use packages such as Seurat to identify marker genes from single cell data. 
101 | 102 | ```R 103 | ################################################################################ 104 | ## Marker genes 105 | ################################################################################ 106 | marker.genes = read.table('/share/quonlab/wkdir/njjohans/public/cellbench/CEL_signature_genes.txt', sep="\t", header=T, stringsAsFactors=F)[,1] 107 | 108 | ## Combine variable features and marker genes. 109 | VariableFeatures(cellbenchSeuratObj) = union(marker.genes, VariableFeatures(cellbenchSeuratObj)) 110 | 111 | ## Create a binary mask for the position of markers in the gene list (optional input to our method). 112 | marker_gene_mask = rep(0, length(VariableFeatures(cellbenchSeuratObj))) 113 | marker_gene_mask[which(VariableFeatures(cellbenchSeuratObj) %in% marker.genes)]=1 114 | ``` 115 | 116 | ## Input overview 117 | Now that we have preprocessed both the scRNA and mixture data, here is a look at the standard inputs: 118 | 1. Single cell RNAseq data **(cells x genes)**. 119 | 2. Mixture RNAseq data **(cells x genes)**. 120 | 3. Cell type labels for the single cell data **(cells x 1)**. A vector of cell type names with no whitespace. 121 | 122 | and optionally: 123 | 124 | 4. Marker gene mask **(genes x 1)**. A binary vector where 1's indicate the position of a marker gene. 125 | 126 | ## Deconvolve mixture profiles with ``! 127 | 128 | ```R 129 | library(reticulate) 130 | deconv = import("deconv") 131 | source('/share/quonlab/software/deconvolution/deconvR/nnDeconvClass.R') 132 | 133 | ## cells x genes 134 | component_data = GetAssayData(cellbenchSeuratObj, 'scale.data')[,which(cellbenchSeuratObj[["platform"]] == 'CEL-seq')] 135 | component_data = t(component_data[VariableFeatures(cellbenchSeuratObj),]) 136 | ## cell type labels for component_data; these should be strings! 137 | component_label = as.character(cellbenchSeuratObj[["cell.type"]][which(cellbenchSeuratObj[["platform"]] == 'CEL-seq'),1]) 138 | ## cells x genes 139 | mixture_data = GetAssayData(cellbenchSeuratObj, 'scale.data')[,which(cellbenchSeuratObj[["platform"]] != 'CEL-seq')] 140 | mixture_data = t(mixture_data[VariableFeatures(cellbenchSeuratObj),]) 141 | 142 | ## Because we are calling Python from R we must be careful about variable types: 143 | ## as.matrix() to ensure data matrices are passed as type matrix 144 | ## as.array() to ensure lists/vectors are passed as type array 145 | ## Integers must be followed by 'L', e.g. 100L. 146 | 147 | ## First we create the deconvolution model: 148 | deconvModel = deconv$deconvModel(component_data = as.matrix(component_data), 149 | component_label = as.array(as.character(component_label)), 150 | mixture_data = as.matrix(mixture_data)) 151 | 152 | ## Now we define the network architecture and run deconvolution! 153 | deconvModel$deconvolve(max_steps_component = 100L, 154 | num_latent_dims = 64L, 155 | num_layers = 3L, 156 | hidden_unit_2power = 9L, 157 | batch_norm_layers = 'True') 158 | 159 | ## Convert the Python class to an R S4 class. Check: str(deconvResults) 160 | deconvResults = convertDeconv(deconvModel) 161 | ``` 162 | 163 | ## Output overview 164 | The following are standard outputs stored in `deconvModel$deconvResults`: 165 | 166 | 1. `proportions` **(cells x celltypes)**. Each row contains the relative proportion of cell types in the corresponding mixture profile. 167 | 2. `weights` **(cells x celltypes)**. Unnormalized (pre-softmax) proportions. 168 | 3. `deconv_data$component` **(cells x genes)**. Reconstructions of the single cell data per cell type. 169 | 4. 
`deconv_data$purified` **(cells x genes)**. Mixture profiles purified to expression consistent with a specific cell type. 170 | 5. The remaining outputs are diagnostics from training/testing used for model evaluation/selection. 171 | 172 | ## Visualize the results of deconvolution 173 | 174 | ```R 175 | library(ComplexHeatmap) 176 | library(circlize) 177 | 178 | ## In the case of CellBench we have labels on the mixture data! 179 | mixture.labels = as.character(cellbenchSeuratObj[["cell.type"]][which(cellbenchSeuratObj[["platform"]] != 'CEL-seq'),1]) 180 | 181 | ## Extract the estimated proportions from the final training step and name the cell type columns 182 | proportions = deconvResults@proportions$`10000` 183 | 184 | ## Create an annotation for our heatmap 185 | row_anno = HeatmapAnnotation( 186 | mixture.type = factor(mixture.labels, levels=1:max(mixture.labels)), 187 | which="row") 188 | 189 | ## Another way to define annotations 190 | colors = c("blue", "green", "purple"); names(colors) = colnames(proportions); 191 | col_anno = HeatmapAnnotation( 192 | cell.type = colnames(proportions), 193 | col = list(cell.type = colors), 194 | which="column") 195 | 196 | ## Plot results 197 | png(paste0("~/proportion_heatmap.png"), width=16, height=16, units='in', res=300) 198 | heatmap = Heatmap(proportions, 199 | top_annotation = col_anno, 200 | col = colorRamp2(c(0, 1), c('white', 'red')), 201 | cluster_rows = T, 202 | cluster_columns = T, 203 | show_row_names = F, 204 | show_column_names = T, 205 | show_row_dend = T, 206 | show_column_dend = F, 207 | column_title = "Cell type proportions", 208 | row_title = "Mixture profiles", 209 | na_col = 'white', 210 | column_names_gp = gpar(fontsize = 24), 211 | column_title_gp = gpar(fontsize = 24), 212 | row_title_gp = gpar(fontsize = 24), 213 | column_names_max_height = unit(20, "cm"), 214 | row_dend_width = unit(3, "cm"), 215 | row_km=4, 216 | show_heatmap_legend=T, 217 | border=T) + row_anno 218 | draw(heatmap) 219 | dev.off() 220 | ``` 221 | ![Proportions](https://github.com/ucdavis/quonlab/blob/master/development/deconvRelease/figures/proportion_heatmap.png) 222 | 223 | ## Masking to marker genes improves deconvolution 224 | 225 | Now lets tell the deconvolution aspect of `` to only utilize marker genes. 226 | 227 | ```R 228 | ## Now we define the network architecture and run deconvolution! 229 | deconvModel$deconvolve(marker_gene_mask = marker_gene_mask, 230 | max_steps_component = 100L, 231 | num_latent_dims = 64L, 232 | num_layers = 3L, 233 | hidden_unit_2power = 9L, 234 | batch_norm_layers = 'True') 235 | 236 | ## Convert the python class to an R S4 class. check: str(deconvResults) 237 | deconvResults = convertDeconv(deconvModel) 238 | ``` 239 | 240 | Platform: UNKNOWN 241 | Classifier: Programming Language :: Python :: 3 242 | Classifier: License :: OSI Approved :: MIT License 243 | Classifier: Operating System :: OS Independent 244 | Requires-Python: >=3.6 245 | Description-Content-Type: text/markdown 246 | -------------------------------------------------------------------------------- /src/scProjection/deconvModel.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | ## Basics 7 | import sys, os, glob, gzip, time, traceback 8 | 9 | ## Verbosity of tensorflow output. 
Filters: (1 INFO) (2 WARNINGS) (3 ERRORS) 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 11 | 12 | ## Math 13 | import numpy as np 14 | 15 | ## tensorflow imports 16 | # import tensorflow as tf 17 | import tensorflow.compat.v1 as tf 18 | tf.disable_v2_behavior() 19 | import logging 20 | import numpy as np 21 | from tensorflow.python.platform import app 22 | from tensorflow.python.platform import flags 23 | from sklearn.model_selection import StratifiedShuffleSplit 24 | 25 | from . import utils 26 | from . import vaeTrain 27 | 28 | class deconvModel(object): 29 | def __init__(self, component_data, component_label, mixture_data, save_label=None): 30 | ## Required parameters 31 | self.component_data = component_data 32 | self.component_label = component_label 33 | self.mixture_data = mixture_data 34 | if save_label is None: 35 | self.save_label = self.component_label 36 | else: 37 | self.save_label = save_label 38 | 39 | ## Deconvolution 40 | self.celltypes = np.unique(self.component_label) 41 | 42 | ## Log model state and save locations 43 | self.results_dir = None 44 | self.flags = None 45 | 46 | ## train, test split 47 | self.train_index = None 48 | self.test_index = None 49 | 50 | ## Initialize results class 51 | self.deconvResults = utils.deconvResult(celltypeNames = self.celltypes) 52 | 53 | def deconvolve( self, 54 | ## Masking 55 | marker_gene_mask = None, 56 | hvg_mask = None, 57 | ## Learning rates (Careful changing these) 58 | component_learning_rate = 1e-3, 59 | proportion_learning_rate = 1e-3, 60 | combined_learning_rate = 1e-4, 61 | decay_lr = False, 62 | ## Network architecture 63 | num_dr = 50, 64 | num_latent_dims = 32, 65 | num_layers = 3, 66 | hidden_unit_2power = 9, 67 | batch_norm_layers = False, 68 | decoder_var = 'per_sample', 69 | latent_x = False, 70 | seed = 1234, 71 | ## Early stopping 72 | early_stopping = True, 73 | early_min_step = 100, 74 | max_patience = 100, 75 | ## VAE training parameters 76 | max_steps_component = 5000, 77 | batch_size_component = 100, 78 | corr_check = False, 79 | ## Combined training parameters 80 | deconvolution = True, 81 | batch_correction = True, 82 | max_steps_combined = 10000, 83 | batch_size_mixture = 100, 84 | max_steps_proportion = 1000, 85 | ## KL term 86 | KL_weight = 1.0, 87 | marker_weight = 1, 88 | ## MCMC 89 | num_monte_carlo = 15, 90 | ## L2 Reg term 91 | tied_model_l2 = True, 92 | ## Training Setup 93 | training_method = 'train', 94 | num_folds = 1, 95 | heldout_percent = 0.33, 96 | ## Hardware params 97 | cuda_device = 0, 98 | ## Logging params 99 | log_step = 5000, 100 | print_step = 100, 101 | log_results = True, 102 | save_results = False, 103 | log_samples = True, 104 | m_phate_logging = False, 105 | ## Path params 106 | logdir = './tmp/', 107 | save_to_object = True, 108 | save_to_disk = False, 109 | model_name = 'default' ): 110 | 111 | ## Try to run deconv, if fail, be sure to clean up tensorflow properly. 
112 | try: 113 | ## Setup annotation and masks 114 | component_labels, \ 115 | marker_gene_masks, \ 116 | hvg_masks, \ 117 | max_steps_components = utils.setup_annotations_masks(self.component_label, 118 | marker_gene_mask, 119 | hvg_mask, 120 | max_steps_component, 121 | self.celltypes) 122 | ## Tensorflows variant of argparse 123 | FLAGS = flags.FLAGS 124 | 125 | ## Result flags 126 | flags.DEFINE_string('logdir', os.path.abspath(logdir), 'Save directory.') 127 | flags.DEFINE_string('model', model_name, 'Name of data to use for saving.') 128 | 129 | ## Saving flags 130 | flags.DEFINE_boolean('log_samples', str(log_samples), 'To log samples from each VAE') 131 | flags.DEFINE_boolean('m_phate_logging', str(m_phate_logging), 'Logging for m-phate viz') 132 | flags.DEFINE_boolean('log_results', str(log_results), 'To log detailed results and model files') 133 | flags.DEFINE_boolean('save_results', str(save_results), 'To save detailed results and model files') 134 | flags.DEFINE_boolean('save_to_object', str(save_to_object), 'Purified data is saved to results object') 135 | flags.DEFINE_boolean('save_to_disk', str(save_to_disk), 'Purified data is saved out to disk') 136 | 137 | ## Reporting 138 | flags.DEFINE_integer('log_step', log_step, 'When to output model summaries.') 139 | flags.DEFINE_integer('print_step', print_step, 'When to print out sumamries.') 140 | flags.DEFINE_boolean('monitor_bulk_mse', False, "Should bulk mse be reported during scVAE training") 141 | 142 | ## Prior 143 | flags.DEFINE_integer('num_dr', num_dr, 'Number of dimensions prior to network.') 144 | 145 | ## Component model architecture flags 146 | flags.DEFINE_integer('latent_dims', str(num_latent_dims), 'dim of latent Z') 147 | flags.DEFINE_float('KL_weight', KL_weight, 'Weight on KL term of elbo.') 148 | flags.DEFINE_boolean('KL_warmup', True, 'Slowly increase KL_weight during training.') 149 | flags.DEFINE_integer('num_monte_carlo', num_monte_carlo, 'number of samples to estimate likelihood') 150 | flags.DEFINE_string('decoder_variance', decoder_var, 'Structure of decoder distribution') ## "per_gene", "per_sample", 'default (unit variance)' 151 | flags.DEFINE_boolean('latent_x', latent_x, 'Should the purified data be treated as latent or not') 152 | 153 | ## Early stopping flags 154 | flags.DEFINE_boolean('early_stopping', early_stopping, 'Should early stopping be performed.') 155 | flags.DEFINE_integer('early_min_step', early_min_step, 'How many steps before early stopping procedure starts.') 156 | flags.DEFINE_integer('max_patience', max_patience, 'How long to wait before early stopping.') 157 | 158 | flags.DEFINE_float('component_learning_rate', component_learning_rate, 'Initial learning rate for VAE training.') 159 | flags.DEFINE_float('proportion_learning_rate', proportion_learning_rate, "Initial learning rate for proportion est.") 160 | flags.DEFINE_float('combined_learning_rate', combined_learning_rate, 'Initial learning rate for VAE correction.') 161 | flags.DEFINE_float('min_learning_rate', 1e-8, 'Minimum learning rate.') 162 | flags.DEFINE_boolean('decay_lr', str(decay_lr), 'Should learning rate be decayed') 163 | flags.DEFINE_float('decay_rate', 0.3, 'How fast to decay learning rate.') 164 | flags.DEFINE_integer('decay_step', 1000, 'Decay interval') 165 | 166 | flags.DEFINE_integer('max_steps_component', str(max(max_steps_components.values())), 'maximum number of traning iterations') 167 | flags.DEFINE_integer('max_steps_combined', str(max_steps_combined), 'maximum number of traning iterations') 168 | 
flags.DEFINE_integer('batch_size_component', str(batch_size_component), 'size of minibatch') 169 | flags.DEFINE_integer('batch_size_mixture', str(batch_size_mixture), 'size of minibatch') 170 | flags.DEFINE_boolean('batch_norm', str(batch_norm_layers), 'To include batch_norm layers in model') 171 | 172 | flags.DEFINE_boolean('kl_analytic', True, 'Built in or manual KL computation, current version of tensorflow probability has a bug. Set to False.') 173 | flags.DEFINE_integer('seed', str(seed), 'random seed for reproducability') 174 | 175 | ## Data specific flags 176 | flags.DEFINE_boolean('deconvolution', deconvolution, 'Should deconvolution be performed') 177 | flags.DEFINE_boolean('rec_project', False, 'Will pc space be used as input') 178 | flags.DEFINE_boolean('L2_norm_data', False, 'Will perform gene-wise l2 normalization.') 179 | 180 | ## Maskings 181 | flags.DEFINE_string('component', None, 'Subset to just one component (for debugging)') 182 | flags.DEFINE_string('component_remove', None, 'Remove specific component from reference data.') 183 | 184 | flags.DEFINE_boolean('tpm_softmax', False, 'Should TPM data be rescaled instead of reconstruction to actual values') 185 | flags.DEFINE_integer('tpm_scale_factor', 10000, 'Scale factor for TPM') 186 | 187 | ## Component model training 188 | flags.DEFINE_boolean('train_component', True, 'Should component network weights be (re)trained.') 189 | flags.DEFINE_integer('num_layers', str(num_layers), "Number of encoder/decoder NN layer.") 190 | flags.DEFINE_integer('hidden_unit_2power', str(hidden_unit_2power), "Starting number of hidden units in the first layer.") 191 | flags.DEFINE_boolean('tied_model_l2', tied_model_l2, 'L2 regularization between prior and current model weights') 192 | 193 | ## Combined model flags 194 | flags.DEFINE_integer('proportion_steps', str(max_steps_proportion), 'How long to train combined model before updating VAE(s)') 195 | flags.DEFINE_boolean('mixture_softmax', True, 'Should softmax be used to compute mixture probabilities.') 196 | flags.DEFINE_boolean('combined_corr_check', corr_check, 'Should pearson correlation of mixture and reconstructed mixture be recorded? (Slow)') 197 | flags.DEFINE_boolean('component_corr_check', corr_check, 'Should pearson correlation of mixture and reconstructed mixture be recorded? 
(Slow)') 198 | 199 | ## Combined loss weights 200 | flags.DEFINE_boolean('batch_correction', batch_correction, 'Should combined network update components') 201 | flags.DEFINE_integer('marker_weight', marker_weight, 'Initial learning rate for VAE training.') 202 | 203 | ## regularizer options (UNUSED) 204 | flags.DEFINE_float('l1_reg_weight', 1e-5, 'Weight of l1 regularizer for combined model.') 205 | flags.DEFINE_string('input_weight_reg', 'None', 'Should the input gene specific weights have a regularizer') ## Sparse adjustment 206 | flags.DEFINE_string('output_weight_reg', 'None', 'Should the output geme specific weights have a regularizer') ## Sparse adjustment 207 | flags.DEFINE_boolean('mix_weight_reg', False, 'Should the mixture weights have a regularizer') ## Mixture weights 208 | 209 | ## validation setup 210 | flags.DEFINE_string('training_method', training_method, 'How many, if any, folds to perform CV') 211 | flags.DEFINE_integer('kfold_validation', str(num_folds), "How many, if any, folds to perform CV") 212 | flags.DEFINE_float('heldout_percent', str(heldout_percent), "How much data to holdout for testing") 213 | 214 | ## HARDWARE 215 | flags.DEFINE_string('cuda_device', str(cuda_device), 'Select the GPU for this job') 216 | 217 | ## Remove some logging 218 | logging.getLogger('tensorflow').setLevel(logging.FATAL) 219 | 220 | ## Hardware configurations 221 | config = utils.define_hardware(FLAGS) 222 | 223 | ## Setup saving directory 224 | FLAGS = utils.initialize_model_dir(FLAGS) 225 | 226 | ## Write out all run options for reproducability 227 | utils.write_flags(FLAGS) 228 | 229 | ## Adjust data based on training mode 230 | if FLAGS.training_method == "train": 231 | deconvResults = vaeTrain.runVAE(resultsObj = self.deconvResults, 232 | FLAGS=FLAGS, 233 | config=config, 234 | component_data=self.component_data, 235 | component_labels=component_labels, 236 | save_labels=self.save_label, 237 | mixture_data=self.mixture_data, 238 | marker_gene_masks=marker_gene_masks, 239 | hvg_masks=hvg_masks, 240 | max_steps_components=max_steps_components, 241 | component_reconstruction_data=None, 242 | mixture_reconstruction_data=self.mixture_data, 243 | component_valid_data=None, 244 | component_valid_labels=None, 245 | mixture_valid_data=None, 246 | mixture_valid_labels=None, 247 | loadings=None) 248 | elif FLAGS.training_method == "validate": 249 | ## Setup kfold iterator object 250 | split_iter = StratifiedShuffleSplit(n_splits=2, test_size=FLAGS.heldout_percent, random_state=FLAGS.seed) 251 | for train_index, test_index in split_iter.split(self.component_data, component_labels): 252 | deconvResults = vaeTrain.runVAE(resultsObj = self.deconvResults, 253 | FLAGS=FLAGS, 254 | config=config, 255 | component_data=self.component_data[train_index,:], 256 | component_labels=component_labels[train_index], 257 | save_labels=self.save_label, 258 | mixture_data=self.mixture_data, 259 | marker_gene_masks=marker_gene_masks, 260 | hvg_masks=hvg_masks, 261 | max_steps_components=max_steps_components, 262 | component_reconstruction_data=None, 263 | mixture_reconstruction_data=self.mixture_data, 264 | component_valid_data=self.component_data[test_index,:], 265 | component_valid_labels=component_labels[test_index], 266 | mixture_valid_data=None, 267 | mixture_valid_labels=None, 268 | loadings=None) 269 | self.train_index = train_index 270 | self.test_index = test_index 271 | else: 272 | print("Mode not recognized, leaving data as is.") 273 | 274 | ## Final logging 275 | self.results_dir = FLAGS.logdir 276 | 
self.flags = tf.app.flags.FLAGS.flag_values_dict() 277 | except: 278 | print("An error occured: ") 279 | traceback.print_exc() 280 | finally: 281 | ## Clean up 282 | FLAGS.remove_flag_values(FLAGS.flag_values_dict()) 283 | 284 | def check_args(self): 285 | print("test") 286 | -------------------------------------------------------------------------------- /src/scProjection/vaeBackend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 Quonlab. 3 | Copyright 2016 Google Inc. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import sys, os 24 | 25 | import tensorflow.compat.v1 as tf 26 | tf.disable_v2_behavior() 27 | import tensorflow_probability as tfp 28 | from tensorflow.python.framework import constant_op 29 | from tensorflow.python.framework import ops 30 | from tensorflow.python.ops import init_ops 31 | from tensorflow.python.ops import array_ops 32 | from tensorflow.python.ops import clip_ops 33 | from tensorflow.python.ops import math_ops 34 | from tensorflow.python.ops import control_flow_ops 35 | from tensorflow.python.ops import variables as tf_variables 36 | from tensorflow.python.training import optimizer as tf_optimizer 37 | from tensorflow.python.training import training_util 38 | 39 | from . 
import architectures 40 | 41 | class modelVAE(object): 42 | def __init__(self, FLAGS, latent_dims, input_size, output_size, num_samples, reconstruction, loadings, scope, marker_masks=None, hvg_masks=None): 43 | ## Info 44 | print("Constructing VAE: "+scope) 45 | print(" |") 46 | 47 | ## Create component VAE model 48 | with tf.variable_scope(scope, tf.AUTO_REUSE): 49 | ## Params for VAE 50 | self.scope = scope 51 | self.latent_dims = latent_dims 52 | self.input_size = input_size 53 | self.output_size = np.count_nonzero(hvg_masks[self.scope]) if hvg_masks is not None else input_size 54 | self.num_samples = num_samples 55 | self.batch_size = np.amin((FLAGS.batch_size_component, self.num_samples)) 56 | 57 | ## Maskings 58 | self.marker_mask = marker_masks[self.scope] if marker_masks is not None else None 59 | self.non_marker_mask = np.ones(self.output_size)-self.marker_mask if marker_masks is not None else None 60 | self.hvg_mask = hvg_masks[self.scope] if hvg_masks is not None else None 61 | 62 | ## Params for early stopping 63 | self.early_stop = False 64 | self.patience = 0 65 | self.best_val_loss = 1e20 66 | 67 | ## L2_tied 68 | self.tied_model_l2 = tf.constant(value=0, dtype=tf.float32) 69 | 70 | ## Architectures (parameterizing dists) 71 | self.encoder_func = tf.make_template('encoder', 72 | architectures.vae_encoder, 73 | FLAGS=FLAGS, 74 | ndim=self.latent_dims, 75 | num_layers=FLAGS.num_layers, 76 | hidden_unit_2power=FLAGS.hidden_unit_2power, 77 | scope=self.scope, 78 | batch_norm=FLAGS.batch_norm, 79 | create_scope_now_=False) 80 | self.decoder_func = tf.make_template('decoder', 81 | architectures.vae_decoder, 82 | FLAGS=FLAGS, 83 | output_size=self.output_size, 84 | num_layers=FLAGS.num_layers, 85 | hidden_unit_2power=FLAGS.hidden_unit_2power, 86 | scope=self.scope, 87 | batch_norm=FLAGS.batch_norm, 88 | create_scope_now_=False) 89 | 90 | self.latent_prior = architectures.vae_prior(ndim=self.latent_dims, scope=self.scope) 91 | self.reconstruction_prior = architectures.vae_prior(ndim=self.output_size, scope=self.scope) 92 | 93 | ## Basics 94 | self.global_step = tf.train.get_or_create_global_step() 95 | self.step = tf.Variable(1, name='step', trainable=False, dtype=tf.int32) 96 | 97 | ## Set up inputs. Mini-batch 98 | with tf.name_scope("component_data"): 99 | self.index = tf.range(start=0, limit=self.num_samples, dtype=tf.int32) 100 | self.data_ph = tf.placeholder(tf.float32, (None, self.input_size)) 101 | self.rec_data_ph = tf.placeholder(tf.float32, (None, self.input_size)) 102 | self.train_dataset = tf.data.Dataset.from_tensor_slices((self.data_ph, self.index)).shuffle(self.num_samples).repeat().batch(self.batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 103 | self.test_dataset = tf.data.Dataset.from_tensor_slices((self.data_ph, self.index)).batch(tf.shape(self.data_ph, out_type=tf.dtypes.int64)[0]).prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 104 | self.iterator = tf.data.Iterator.from_structure(self.train_dataset.output_types, 105 | self.train_dataset.output_shapes) 106 | # self.iterator = self.dataset.make_initializable_iterator() 107 | self.data_batch, self.data_index = self.iterator.get_next() 108 | ## Choice of datasets 109 | self.train_init_op = self.iterator.make_initializer(self.train_dataset) 110 | self.test_init_op = self.iterator.make_initializer(self.test_dataset) 111 | 112 | ## Holds purified bulk data, helper for combined model 113 | with tf.name_scope("purified_data"): 114 | self.pure_data = None 115 | 116 | ## Compute embeddings and logits. 
117 | self.emb = self.encoder_func(inputs=self.data_batch, is_training=True) 118 | self.latent_sample = self.emb.sample(FLAGS.num_monte_carlo) 119 | self.rec = self.decoder_func(inputs=self.latent_sample, is_training=True) 120 | if FLAGS.tpm_softmax is True: 121 | self.rec_data = tf.nn.softmax(tf.reduce_mean(self.rec.mean(), axis=0), axis=-1) * FLAGS.tpm_scale_factor 122 | else: 123 | #self.rec_data = tf.reduce_mean(self.rec.mean(), axis=0) 124 | self.rec_data = tf.reduce_mean(self.rec.mean(), axis=0) 125 | 126 | ## Set up KL weight warmup 127 | with tf.name_scope("KL_warmup"): 128 | if FLAGS.KL_warmup is True: self.KL_weight = tf.cast(tf.minimum( 129 | FLAGS.KL_weight*(tf.cast(self.step, dtype=tf.float32)/(FLAGS.max_steps_component/tf.constant(2.0, dtype=tf.float32))), 130 | FLAGS.KL_weight), 131 | dtype=tf.float32) 132 | else: self.KL_weight = FLAGS.KL_weight 133 | tf.compat.v1.summary.scalar("KL_weight", self.KL_weight) 134 | tf.compat.v1.summary.scalar("training_step", self.step) 135 | ## Monitor patience counter 136 | tf.compat.v1.summary.scalar("Patience", self.patience) 137 | 138 | ## Add VAE specific loss 139 | self.add_ELBO_loss(FLAGS) 140 | 141 | ## Metrics 142 | self.mse = tf.losses.mean_squared_error(self.rec_data, self.data_batch, loss_collection="metrics") 143 | tf.compat.v1.summary.scalar(self.scope+"_component_mse", self.mse) 144 | 145 | ## Correlation for tracking training performance 146 | if FLAGS.component_corr_check is True: 147 | self.corr = tfp.stats.correlation(self.rec_data, self.data_batch, sample_axis=-1, event_axis=0) 148 | self.corr = tf.reduce_mean(tf.linalg.tensor_diag_part(self.corr), axis=0) 149 | tf.compat.v1.summary.histogram(scope+"_component_correlation", self.corr) 150 | 151 | ## Set up learning rate 152 | if FLAGS.decay_lr is True: 153 | with tf.name_scope("learning_rate"): 154 | self.t_learning_rate = tf.maximum( 155 | tf.train.exponential_decay( 156 | FLAGS.component_learning_rate, 157 | self.step, 158 | FLAGS.decay_step, 159 | FLAGS.decay_rate), 160 | FLAGS.min_learning_rate) 161 | else: 162 | self.t_learning_rate = FLAGS.component_learning_rate 163 | 164 | ## Create training operation 165 | self.update_op=tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.scope) 166 | self.create_train_op(self.t_learning_rate, update_ops=self.update_op) 167 | 168 | ## Batch correction loss terms 169 | if self.marker_mask is not None: 170 | self.compute_batch_loss(FLAGS) 171 | 172 | ## Record VAE encoder dist. parameters 173 | # tf.compat.v1.summary.histogram(scope+"_mean", self.emb.mean()) 174 | # tf.compat.v1.summary.histogram(scope+"_variance", self.emb.variance()) 175 | 176 | # ## Record VAE decoder dist. parameters 177 | # tf.compat.v1.summary.histogram(scope+"_mean_decoder", self.rec.mean()) 178 | # tf.compat.v1.summary.histogram(scope+"_variance_decoder", self.rec.variance()) 179 | 180 | ## Define weight saving 181 | #self.record_network_weights() 182 | #self.tied_weight_regularizer() 183 | 184 | ## Define saving operations 185 | self.define_saver_ops(FLAGS=FLAGS, is_training=False) 186 | self.define_inverse_saver_ops(FLAGS=FLAGS, is_training=True) 187 | self.define_summary_ops() 188 | 189 | def add_logp_loss(self, FLAGS): 190 | with tf.variable_scope('LogLikelihood'): 191 | ## `distortion` is the negative log likelihood. 
192 | self.distortion = -self.rec.log_prob(self.data_batch) 193 | self.avg_distortion = tf.reduce_mean(input_tensor=self.distortion) 194 | tf.compat.v1.summary.scalar("Reconstruction", self.avg_distortion) 195 | 196 | def add_KL_loss_encoder(self, FLAGS): 197 | with tf.variable_scope('KL'): 198 | self.rate = tfp.distributions.kl_divergence(self.emb, self.latent_prior, allow_nan_stats=False) 199 | self.avg_KL_div = tf.reduce_mean(input_tensor=self.rate) 200 | tf.compat.v1.summary.scalar("KL_encoder", self.avg_KL_div) 201 | 202 | def add_KL_loss_decoder(self, FLAGS): 203 | with tf.variable_scope('KL_decoder'): 204 | self.rate_decoder = tfp.distributions.kl_divergence(self.rec, self.reconstruction_prior, allow_nan_stats=False) 205 | self.avg_KL_div_decoder = tf.reduce_mean(input_tensor=self.rate_decoder) 206 | tf.compat.v1.summary.scalar("KL_decoder", self.avg_KL_div_decoder) 207 | 208 | def add_ELBO_loss(self, FLAGS, smoothing=0.0): 209 | with tf.variable_scope('ELBO'): 210 | ## Components of elbo 211 | self.add_logp_loss(FLAGS) 212 | self.add_KL_loss_encoder(FLAGS) 213 | ## Per sample loss 214 | if FLAGS.latent_x is True: 215 | self.add_KL_loss_decoder(FLAGS) 216 | elbo_local = -((self.KL_weight * self.rate) + (self.KL_weight * self.rate_decoder) + self.distortion) 217 | else: 218 | elbo_local = -((self.KL_weight * self.rate) + self.distortion) 219 | ## Mean loss for batch 220 | self.elbo = tf.reduce_mean(elbo_local) 221 | tf.losses.add_loss(-self.elbo, 222 | loss_collection=tf.GraphKeys.LOSSES) 223 | tf.compat.v1.summary.scalar("elbo", -self.elbo) 224 | 225 | def create_train_op(self, learning_rate, update_ops=None, check_numerics=True, summarize_gradients=False): 226 | """Create and return training operation.""" 227 | with tf.control_dependencies(update_ops): 228 | with tf.name_scope('Optimizer'): 229 | self.l2_reg = tf.losses.get_regularization_loss(scope=self.scope) 230 | #self.train_loss = tf.add_n(tf.losses.get_losses(loss_collection=tf.GraphKeys.LOSSES, scope=self.scope) + tf.losses.get_regularization_losses(scope=self.scope)) 231 | self.train_loss = tf.add_n(tf.losses.get_losses(loss_collection=tf.GraphKeys.LOSSES, scope=self.scope) + tf.losses.get_regularization_losses(scope=self.scope)) 232 | ## Minimize loss function 233 | optimizer = tf.train.AdamOptimizer(learning_rate) 234 | self.train_op = optimizer.minimize(loss=self.train_loss, global_step=self.step, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)) 235 | ## Monitor 236 | #tf.compat.v1.summary.scalar('Learning_Rate_component', learning_rate) 237 | #tf.compat.v1.summary.scalar('Loss_Total_component', self.train_loss) 238 | #tf.compat.v1.summary.scalar('Loss_Regularization', self.l2_reg) 239 | return(self.train_op) 240 | 241 | ## Operation to compute batch correction loss 242 | def compute_batch_loss(self, FLAGS): 243 | with tf.variable_scope(self.scope, tf.AUTO_REUSE): 244 | ## Reconstruction of markers: marker_mask 245 | marker_measured = tf.multiply(self.data_batch, self.marker_mask) 246 | marker_reconstruction = tf.multiply(tf.reduce_mean(self.rec.mean(), axis=0), self.marker_mask) 247 | ## Reweight marker priority using partial observations 248 | # if FLAGS.marker_weight > 1: 249 | # marker_measured = tf.tile(marker_measured, (FLAGS.marker_weight, 1)) 250 | # marker_reconstruction = tf.tile(marker_reconstruction, (FLAGS.marker_weight, 1)) 251 | self.marker_mse = tf.losses.mean_squared_error(marker_measured, 252 | marker_reconstruction, 253 | weights=1.0, 254 | scope=self.scope) 255 | ## 
Reconstruction of non_markers: np.ones - marker_mask 256 | if np.sum(self.non_marker_mask) > 0: 257 | non_marker_measured = tf.multiply(self.data_batch, self.non_marker_mask) ## zero out marker genes, mirroring the tf.multiply masking used for the marker term above 258 | non_marker_reconstruction = tf.multiply(tf.reduce_mean(self.rec.mean(), axis=0), self.non_marker_mask) 259 | self.non_marker_mse = tf.losses.mean_squared_error(non_marker_measured, 260 | non_marker_reconstruction, 261 | weights=1., 262 | scope=self.scope) 263 | else: 264 | self.non_marker_mse = 0.0 265 | 266 | # ## Operation to compute batch correction loss 267 | # def compute_batch_loss(self, FLAGS): 268 | # with tf.variable_scope(self.scope, tf.AUTO_REUSE): 269 | # ## Reconstruction of markers: marker_mask 270 | # marker_measured = tf.boolean_mask(self.data_batch, self.marker_mask, axis=1) 271 | # marker_reconstruction = tf.boolean_mask(tf.reduce_mean(self.rec.sample(FLAGS.num_monte_carlo)[0], axis=0), self.marker_mask, axis=1) 272 | # ## Reweight marker priority using partial observations 273 | # if FLAGS.marker_weight > 1: 274 | # marker_measured = tf.tile(marker_measured, (FLAGS.marker_weight, 1)) 275 | # marker_reconstruction = tf.tile(marker_reconstruction, (FLAGS.marker_weight, 1)) 276 | # self.marker_mse = tf.losses.mean_squared_error(marker_measured, 277 | # marker_reconstruction, 278 | # weights=1.0, 279 | # scope=self.scope) 280 | # ## Reconstruction of non_markers: np.ones - marker_mask 281 | # if np.sum(self.non_marker_mask) > 0: 282 | # non_marker_measured = tf.boolean_mask(self.data_batch, self.non_marker_mask, axis=1) 283 | # non_marker_reconstruction = tf.boolean_mask(tf.reduce_mean(self.rec.sample(FLAGS.num_monte_carlo)[0], axis=0), self.non_marker_mask, axis=1) 284 | # self.non_marker_mse = tf.losses.mean_squared_error(non_marker_measured, 285 | # non_marker_reconstruction, 286 | # weights=1., 287 | # scope=self.scope) 288 | # else: 289 | # self.non_marker_mse = 0.0 290 | 291 | # def record_network_weights(self): 292 | # ## Get all weights from current VAE and save just their names 293 | # self.weight_names = [tensor.name for tensor in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope) if tensor.name.endswith('kernel:0')] 294 | # self.saved_weights = {} 295 | # ## Holds all the ops required to "fill", fixed weight matrices after training VAEs.
296 | # for name in self.weight_names: 297 | # weight_matrix = tf.get_default_graph().get_tensor_by_name(name) 298 | # self.saved_weights[name] = tf.Variable(tf.zeros(weight_matrix.shape), shape=weight_matrix.shape) 299 | # #tf.compat.v1.summary.histogram("fixed_"+name[:-2], self.saved_weights[name]) 300 | 301 | def assign_network_weights(self, sess, assign=True): 302 | ## Hardcopy the weight matrices from dense layers 303 | for name in self.weight_names: 304 | print("RECORDING: "+name) 305 | if assign is True: 306 | self.saved_weights[name].load(sess.run(name)) 307 | np.savetxt('/home/ucdnjj/matrix_files/'+name[:-2].replace("/","_")+'.csv', sess.run(self.saved_weights[name]), delimiter=',') 308 | else: 309 | np.savetxt('/home/ucdnjj/matrix_files_end/'+name[:-2].replace("/","_")+'.csv', sess.run(self.saved_weights[name]), delimiter=',') 310 | 311 | # def tied_weight_regularizer(self): 312 | # with tf.name_scope('tied_model_l2'): 313 | # for name in self.weight_names: 314 | # self.tied_model_l2 += tf.nn.l2_loss(tf.subtract(tf.get_default_graph().get_tensor_by_name(name), self.saved_weights[name])) 315 | # tf.compat.v1.summary.scalar("tied_model_l2", self.tied_model_l2) 316 | 317 | def define_summary_ops(self): 318 | ## VAE's summary set 319 | self.summary_op = tf.summary.merge_all(scope=self.scope) 320 | 321 | ## Defines copies of the network that take in placeholders for final pass of compelete dataset 322 | def define_saver_ops(self, FLAGS, is_training): 323 | ## Saver helper function for encoder 324 | self.fdata_ph = tf.placeholder(tf.float32, (None, self.input_size)) 325 | self.emb_saver = self.encoder_func(inputs=self.fdata_ph, is_training=is_training) 326 | self.emb_saver_mean = self.emb_saver.mean() 327 | self.emb_saver_var = self.emb_saver.variance() 328 | self.emb_saver_sample = self.emb_saver.sample(FLAGS.num_monte_carlo)[0] 329 | 330 | ## Saver helper function for decoder 331 | self.emb_ph = tf.placeholder(tf.float32, (None, self.latent_dims)) 332 | self.rec_saver = self.decoder_func(inputs=self.emb_ph, is_training=is_training) 333 | self.rec_saver_lp = self.rec_saver.log_prob(self.fdata_ph) 334 | 335 | ## Saver helper function for sampled decoder 336 | self.emb_sampled_ph = tf.placeholder(tf.float32, (None, None, self.latent_dims)) 337 | self.rec = self.decoder_func(inputs=self.emb_sampled_ph, is_training=is_training) 338 | #self.rec_var = tf.reduce_mean(self.rec_var, axis=0) 339 | self.rec_mean = tf.reduce_mean(self.rec.mean(), axis=0) 340 | 341 | def define_inverse_saver_ops(self, FLAGS, is_training=True): 342 | self.inv_emb_saver = self.encoder_func(inputs=self.fdata_ph, is_training=is_training) 343 | self.inv_emb_saver_sample = self.inv_emb_saver.sample(FLAGS.num_monte_carlo)[0] 344 | self.inv_rec = self.decoder_func(inputs=self.emb_sampled_ph, is_training=is_training) 345 | self.inv_rec_mean = tf.reduce_mean(self.inv_rec.mean(), axis=0) 346 | 347 | def mixture_mse_monitor(self, data, FLAGS, is_training=False): 348 | emb_comp_temp = sess.run(self.emb_saver_sample, feed_dict={self.fdata_ph: data}) 349 | reconstruction = sess.run(self.rec_mean, feed_dict={self.emb_ph: emb_comp_temp}) 350 | return sess.run(tf.compat.v1.losses.mean_squared_error(data, reconstruction)) 351 | 352 | ## Evaluate endpoint op in batchs of data. 
353 | def save_sampled(self, sess, resultsObj, datatype, scope, FLAGS, dir, nsample=100): 354 | with tf.variable_scope(self.scope, tf.AUTO_REUSE): 355 | samples = sess.run(self.latent_prior.sample(nsample)) 356 | rec_comp = sess.run(self.rec_sampled_saver, feed_dict={self.emb_sampled_ph: samples[tf.newaxis,]}) 357 | resultsObj.sample_saver(datatype, scope, samples, rec_comp) 358 | if FLAGS.save_results is True: 359 | np.savetxt(FLAGS.logdir + "/"+dir+"_results/reconstructions.csv", rec_comp, delimiter=",") 360 | 361 | ## Save scRNA outputs 362 | def save_results(self, sess, resultsObj, datatype, mode, stage, FLAGS, data, scope, dir, batch_size=100): 363 | ## Handle large mixture input 364 | rec_comp = []; emb_comp = []; rec_comp_lprob = []; var_comp = []; 365 | with tf.variable_scope(self.scope, tf.AUTO_REUSE): 366 | for i in range(0, data.shape[0], batch_size): 367 | data_batch = data[i:i + batch_size] 368 | emb_comp_temp = sess.run(self.emb_saver_sample, feed_dict={self.fdata_ph: data_batch}) 369 | emb_comp.append(emb_comp_temp) 370 | rec_comp_lprob.append(sess.run(self.rec_saver_lp, feed_dict={self.emb_ph: emb_comp_temp, self.fdata_ph: data_batch})) 371 | rec_comp.append(sess.run(self.rec_mean, feed_dict={self.emb_sampled_ph: emb_comp_temp[tf.newaxis,]})) 372 | #var_comp.append(sess.run(self.rec_var, feed_dict={self.emb_sampled_ph: emb_comp_temp[tf.newaxis,]})) 373 | 374 | ## Combined batched results 375 | emb_comp = np.concatenate(emb_comp) 376 | rec_comp_lprob = np.concatenate(rec_comp_lprob) 377 | rec_comp = np.concatenate(rec_comp) 378 | #if FLAGS.decoder_variance != 'fixed': 379 | # var_comp = np.concatenate(var_comp) 380 | ## 381 | # inv_rec_comp = np.concatenate(inv_rec_comp) 382 | 383 | if FLAGS.save_to_object is True: 384 | resultsObj.update_saver(datatype, mode, stage, scope, rec_comp, emb_comp, rec_comp_lprob, var_comp) 385 | 386 | if FLAGS.save_to_disk is True: 387 | print("Not setup") 388 | -------------------------------------------------------------------------------- /src/scProjection/vaeTrain.py: -------------------------------------------------------------------------------- 1 | 2 | #! /usr/bin/env python 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function ## 6 | 7 | import sys, os, glob, gzip, time 8 | from functools import partial 9 | from importlib import import_module 10 | 11 | ## Math 12 | import numpy as np 13 | 14 | ## tensorflow imports 15 | import tensorflow.compat.v1 as tf 16 | tf.disable_v2_behavior() 17 | import tensorflow_probability as tfp 18 | from tensorflow.python.platform import app 19 | from tensorflow.python.platform import flags 20 | 21 | ## 22 | from . import vaeBackend 23 | from . import batchBackend 24 | from . import deconvBackend 25 | from . import utils 26 | from . import saving_utils 27 | 28 | def runVAE(resultsObj, 29 | FLAGS, 30 | config, 31 | component_data, 32 | component_labels, 33 | save_labels, 34 | mixture_data, 35 | marker_gene_masks=None, 36 | hvg_masks=None, 37 | max_steps_components=None, 38 | component_reconstruction_data=None, 39 | mixture_reconstruction_data=None, 40 | component_valid_data=None, 41 | component_valid_labels=None, 42 | mixture_valid_data=None, 43 | mixture_valid_labels=None, 44 | loadings=None): 45 | """Initializes and trains component VAE(s) and combined model. 46 | Args: 47 | FLAGS (flags.FLAGS): Tensorflow flags. 48 | config (ConfigProto): Hardware configuration options. 49 | component_data (matrix): Component VAE input data. 
50 | component_labels (str): Membership vector for each sample in component data. 51 | mixture_data (matrix): Combined model input data to estimate proportions for. 52 | marker_gene_masks (str): Genes to be used during deconvolution. 53 | component_reconstruction_data (matrix): Data to compare against component VAE reconstructions. 54 | mixture_reconstruction_data (matrix): Data to compare against combined model reconstructions. 55 | component_valid_data (matrix): Component VAE validation (test) data. 56 | component_valid_labels (str): Membership vector for each sample in validation (test) component data. 57 | mixture_valid_data (matrix): Component VAE validation (test) data. 58 | mixture_valid_labels (str): Membership vector for each sample in validation (test) component data. 59 | loadings (matrix): Projection matrix to convert from low dimensions to gene expression space. 60 | Returns: 61 | proportions: Mixture proportions for each cell with respect to all component VAE(s). 62 | weights: Unnormalized mixture probabilities. 63 | """ 64 | 65 | # ## Start deconvolution at top of tree (Python dictionary retains insertion order) 66 | # for key in component_labels.keys(): 67 | ## Dictionary containing all VAE models named by unique elements within datasets 68 | VAE = {} 69 | ## Define network structure 70 | graph = tf.Graph() 71 | with graph.as_default(): 72 | global_step_component = tf.train.get_or_create_global_step() 73 | ## Create component VAEs 74 | for scope in np.unique(component_labels): 75 | ## Set up encoder model. q(z|x) 76 | VAE[scope] = vaeBackend.modelVAE(FLAGS, 77 | latent_dims=FLAGS.latent_dims, 78 | input_size=component_data.shape[-1], 79 | output_size=component_data.shape[-1] if FLAGS.rec_project is False else component_reconstruction_data.shape[-1], ## Determines the output size of decoder 80 | num_samples=component_data[np.where(component_labels == scope)[0],:].shape[0], 81 | reconstruction=component_reconstruction_data, 82 | loadings=tf.cast(loadings, tf.float32) if loadings is not None else None, 83 | scope=scope, 84 | marker_masks=marker_gene_masks, 85 | hvg_masks=hvg_masks) 86 | 87 | ## Construct combined model for training/all mixture data which can potentially update VAEs 88 | if mixture_data is not None and FLAGS.deconvolution is True: 89 | 90 | deconvModel = deconvBackend.deconvModel(FLAGS=FLAGS, 91 | VAE=VAE, 92 | datasets=component_labels, 93 | input_size=mixture_data.shape[-1], 94 | output_size=mixture_data.shape[-1], ## Determines the output size of decoder 95 | num_samples=mixture_data.shape[0], 96 | reconstruction=mixture_data, 97 | marker_gene_masks=marker_gene_masks, 98 | loadings=tf.cast(loadings, tf.float32) if loadings is not None else None, 99 | scope="deconvolution") 100 | 101 | batchModel = batchBackend.batchModel(FLAGS=FLAGS, 102 | VAE=VAE, 103 | mixture_weights=deconvModel.mixture_weights, 104 | proportions=deconvModel.proportions, 105 | datasets=component_labels, 106 | input_size=mixture_data.shape[-1], 107 | output_size=mixture_data.shape[-1], ## Determines the output size of decoder 108 | num_samples=mixture_data.shape[0], 109 | reconstruction=mixture_data, 110 | loadings=tf.cast(loadings, tf.float32) if loadings is not None else None, 111 | scope="batch_correction") 112 | 113 | ## Tensorboard summary and model checkpoint saving 114 | summary_op, summary_writer, summary_writer_test, saver = utils.summary_and_checkpoint(FLAGS, graph) 115 | 116 | ## Training scope 117 | with tf.Session(graph=graph, config=config) as sess: 118 | 119 | ## Set the 
logging level for tensorflow to only fatal issues 120 | tf.logging.set_verbosity(tf.logging.FATAL) 121 | 122 | ## Define seed at the graph-level 123 | ## From docs: If the graph-level seed is set, but the operation seed is not: 124 | ## The system deterministically picks an operation seed in conjunction with 125 | ## the graph-level seed so that it gets a unique random sequence. 126 | tf.set_random_seed(FLAGS.seed) 127 | 128 | ## Initialize everything 129 | tf.global_variables_initializer().run() 130 | print("Done random initialization") 131 | 132 | if not os.path.exists(os.path.abspath(FLAGS.logdir)): 133 | os.makedirs(os.path.abspath(FLAGS.logdir)) 134 | 135 | ## Assert that nothing more can be added to the graph 136 | # tf.get_default_graph().finalize() 137 | 138 | ## Track epoch loss 139 | epoch_loss=[] 140 | 141 | ## Don't normalize output 142 | if component_reconstruction_data is None: 143 | component_reconstruction_data = component_data 144 | if mixture_reconstruction_data is None: 145 | mixture_reconstruction_data = mixture_data 146 | 147 | ## Initialize the component VAE Dataset iterators 148 | for scope in np.unique(component_labels): 149 | sess.run(VAE[scope].train_init_op, feed_dict={VAE[scope].data_ph: component_data[np.where(component_labels == scope)[0],:], 150 | VAE[scope].rec_data_ph: component_reconstruction_data[np.where(component_labels == scope)[0],:]}) 151 | 152 | ## Component training! 153 | if FLAGS.train_component is True: 154 | for step in range(1,FLAGS.max_steps_component+1): 155 | for scope in np.unique(component_labels): 156 | if (step <= max_steps_components[scope]) and (VAE[scope].early_stop is not True): 157 | with tf.variable_scope(scope, tf.AUTO_REUSE): 158 | _, _, summaries_train, train_loss, logp, kl_loss, mse_loss = sess.run([VAE[scope].train_op, VAE[scope].update_op, VAE[scope].summary_op, VAE[scope].train_loss, VAE[scope].avg_distortion, VAE[scope].avg_KL_div, VAE[scope].mse], 159 | options=None, 160 | run_metadata=None) 161 | ## Log mphate data 162 | if FLAGS.m_phate_logging is True and step < 500: 163 | embedding = sess.run(VAE[scope].emb_saver_sample, feed_dict={VAE[scope].fdata_ph: mixture_data}) 164 | resultsObj.update_mphate(scope, embedding) 165 | 166 | if step % FLAGS.print_step == 0 or step == FLAGS.max_steps_component: 167 | if FLAGS.component_corr_check is True: 168 | train_corr = sess.run(VAE[scope].corr) 169 | print("(%s) Step %s: %-6.2f MSE: %-8.4f LogP: %-8.4f KLL %-8.4f Corr %-8.4f" % (scope, step, train_loss, mse_loss, logp, kl_loss, train_corr)) 170 | else: 171 | print("(%s) Step %s: %-6.2f MSE: %-8.4f LogP: %-8.4f KLL %-8.4f" % (scope, step, train_loss, mse_loss, logp, kl_loss)) 172 | 173 | if FLAGS.monitor_bulk_mse is True: 174 | bulk_loss = VAE[scope].mixture_mse_monitor(mixture_data, FLAGS) 175 | print("%-6.2f" % bulk_loss) 176 | 177 | if component_valid_data is not None: 178 | comp_valid_data = component_valid_data[np.where(component_valid_labels == scope)[0],:] 179 | if comp_valid_data.shape[0] > 0 : 180 | ## Swap to testing data 181 | sess.run(VAE[scope].test_init_op, feed_dict={VAE[scope].data_ph: comp_valid_data, 182 | VAE[scope].rec_data_ph: comp_valid_data}) 183 | ## Test model 184 | if FLAGS.component_corr_check is True: 185 | summaries_test, test_loss, test_logp, test_kl_loss, test_mse_loss, test_corr = sess.run([VAE[scope].summary_op, VAE[scope].train_loss, VAE[scope].avg_distortion, VAE[scope].avg_KL_div, VAE[scope].mse, VAE[scope].corr]) 186 | else: 187 | summaries_test, test_loss, test_logp, test_kl_loss, 
test_mse_loss, test_corr = sess.run([VAE[scope].summary_op, VAE[scope].train_loss, VAE[scope].avg_distortion, VAE[scope].avg_KL_div, VAE[scope].mse]) 188 | ## Early stopping if enabled and minimum number of steps has been achieved 189 | if FLAGS.early_stopping is True and step >= FLAGS.early_min_step and FLAGS.component_corr_check is True: 190 | utils.early_stopping(FLAGS, VAE, scope, test_corr) 191 | 192 | ## Swap back to training data 193 | sess.run(VAE[scope].train_init_op, feed_dict={VAE[scope].data_ph: component_data[np.where(component_labels == scope)[0],:], 194 | VAE[scope].rec_data_ph: component_reconstruction_data[np.where(component_labels == scope)[0],:]}) 195 | else: 196 | ## Early stopping if enabled and minimum number of steps has been achieved 197 | if FLAGS.early_stopping is True and step >= FLAGS.early_min_step and FLAGS.component_corr_check is True: 198 | train_corr = sess.run(VAE[scope].corr) 199 | utils.early_stopping(FLAGS, VAE, scope, train_corr) 200 | 201 | ## Report loss 202 | if step % FLAGS.log_step == 0 or step == FLAGS.max_steps_component: 203 | summary_writer.add_summary(summaries_train, str(step)) 204 | summary_writer.flush() 205 | 206 | ## Record results 207 | resultsObj.update_component_metrics("train", scope, step, train_loss, mse_loss, logp, kl_loss) 208 | 209 | ## Run test data if available 210 | if component_valid_data is not None: 211 | if comp_valid_data.shape[0] > 0 : 212 | resultsObj.update_component_metrics("test", scope, step, test_loss, test_mse_loss, test_logp, test_kl_loss) 213 | ## Summary reports (Tensorboard), testing summary 214 | summary_writer_test.add_summary(summaries_test, str(step)) 215 | summary_writer_test.flush() 216 | 217 | 218 | ## Save model and summaries 219 | if step % FLAGS.log_step == 0 or step == FLAGS.max_steps_component: 220 | ## Summary reports (Tensorboard) 221 | summary_writer.add_summary(summaries_train, str(step)) 222 | summary_writer.flush() 223 | if component_valid_data is not None: 224 | summary_writer_test.add_summary(summaries_test, str(step)) 225 | summary_writer_test.flush() 226 | ## Write out graph 227 | if step == FLAGS.max_steps_component: 228 | save = saver.save(sess, os.path.abspath(FLAGS.logdir + '/ckpt/model_component.ckpt'), step) 229 | 230 | ## Save reconstruction results prior to mixture_stage training 231 | if FLAGS.log_results is True: 232 | saving_utils.save_results(sess=sess, 233 | resultsObj=resultsObj, 234 | FLAGS=FLAGS, 235 | VAE=VAE, 236 | component_labels=save_labels, 237 | component_data=component_data, 238 | hvg_masks=hvg_masks, 239 | marker_gene_masks=marker_gene_masks, 240 | stage="component", 241 | component_valid_labels=component_valid_labels, 242 | component_valid_data=component_valid_data, 243 | mixture_input_data=mixture_data) 244 | # if FLAGS.log_samples is True: 245 | # saving_utils.save_samples(sess=sess, 246 | # resultsObj=resultsObj, 247 | # FLAGS=FLAGS, 248 | # VAE=VAE, 249 | # component_labels=component_labels, 250 | # stage="prebatch") 251 | else: 252 | ## Load a previously saved model 253 | saver.restore(sess, tf.train.latest_checkpoint(os.path.abspath(FLAGS.logdir + '/ckpt'))) 254 | 255 | ## Proportion training 256 | if mixture_data is not None and FLAGS.deconvolution is True: 257 | 258 | ## Initialize the mixture dataset 259 | sess.run(deconvModel.mix_iter_data.initializer, feed_dict={deconvModel.mix_data_ph: mixture_data}) 260 | 261 | ## Initialize the mixture dataset 262 | sess.run(batchModel.mix_iter_data.initializer, feed_dict={batchModel.mix_data_ph: mixture_data}) 263 
| 264 | for step in range(1,FLAGS.proportion_steps+1): 265 | ## Proportion optimizer 266 | _, summaries, prop_loss = sess.run([deconvModel.train_proportion_op, summary_op, deconvModel.proportion_loss]) 267 | 268 | if step % FLAGS.print_step == 0 or step == FLAGS.proportion_steps: 269 | print("(Proportions) Step: %s MSE: %-8.4f" % (step, prop_loss)) 270 | 271 | ## Report loss 272 | if step % FLAGS.log_step == 0: 273 | summary_writer.add_summary(summaries, str(step+FLAGS.max_steps_component)) 274 | 275 | if step % np.ceil(mixture_data.shape[0]/FLAGS.batch_size_mixture) == 0: 276 | epoch_loss = np.mean(epoch_loss) 277 | resultsObj.update_mixture_metrics("train", epoch_loss, epoch_loss) 278 | epoch_loss = [] ## Reset 279 | else: 280 | epoch_loss.append(prop_loss) 281 | 282 | ## Save model and summaries 283 | if step % FLAGS.log_step == 0 or step == FLAGS.proportion_steps: 284 | ## Summary reports (Tensorboard) 285 | summary_writer.add_summary(summaries, str(step+FLAGS.max_steps_component)) 286 | ## Get estiamted proportions 287 | proportions = sess.run(deconvModel.proportions) 288 | #weights = sess.run(deconvModel.mixture_weights) 289 | resultsObj.update_proportions(step, proportions) 290 | if step == FLAGS.proportion_steps: 291 | ## Write out graph 292 | save = saver.save(sess, os.path.abspath(FLAGS.logdir + '/ckpt/model_combined.ckpt'), step+FLAGS.max_steps_component) 293 | 294 | ## Save mixture probabilities 295 | # proportions = sess.run(deconvModel.proportions) 296 | # np.savetxt(FLAGS.logdir + "/probabilities.csv", 297 | # proportions, 298 | # delimiter=",", 299 | # header=','.join(str(x) for x in np.unique(component_labels)), 300 | # comments='') 301 | 302 | ## Save mixture weights 303 | # weights = sess.run(deconvModel.mixture_weights) 304 | # np.savetxt(FLAGS.logdir + "/weights.csv", 305 | # weights, 306 | # delimiter=",", 307 | # header=','.join(str(x) for x in np.unique(component_labels)), 308 | # comments='') 309 | 310 | # deconvModel.save_mixtures(sess=sess, 311 | # resultsObj=resultsObj, 312 | # FLAGS=FLAGS, 313 | # VAE=VAE, 314 | # mixture_input_data=mixture_data, 315 | # datatype='prebatch', 316 | # batch_size=100) 317 | 318 | ## Mixture batch-correction training 319 | if FLAGS.batch_correction is True: 320 | ## Record the VAE weights after pretraining 321 | # if FLAGS.tied_model_l2 is True and step == FLAGS.max_steps_component: 322 | # for scope in np.unique(component_labels): 323 | # print("Assigning weights") 324 | # VAE[scope].assign_network_weights(sess=sess) 325 | 326 | ## Train! 
327 | for step in range(1,FLAGS.max_steps_combined+1): 328 | ## Combined optimizer 329 | _, summaries, train_loss, mse_loss, marker_mse, non_marker_mse = sess.run([batchModel.train_batch_op, summary_op, batchModel.train_batch_loss, batchModel.batch_loss, batchModel.marker_mse, batchModel.non_marker_mse]) 330 | 331 | if step % FLAGS.print_step == 0 or step == FLAGS.max_steps_combined: 332 | print("(Batch Correct) Step %s: %-6.2f LogP(mixture): %-8.4f MSE(marker): %-8.4f MSE(Non-marker): %-8.4f" % (step, train_loss, mse_loss, marker_mse, non_marker_mse)) 333 | 334 | ## Report loss 335 | if step % FLAGS.log_step == 0: 336 | for scope in np.unique(component_labels): 337 | with tf.variable_scope(scope, tf.AUTO_REUSE): 338 | train_loss, logp, kl_loss, mse_loss = sess.run([VAE[scope].train_loss, VAE[scope].avg_distortion, VAE[scope].avg_KL_div, VAE[scope].mse], options=None, run_metadata=None) 339 | resultsObj.update_component_metrics("train-post", scope, step, train_loss, mse_loss, logp, kl_loss) 340 | #print(sess.run(VAE[scope].data_index)) 341 | summary_writer.add_summary(summaries, str(step+FLAGS.max_steps_component+FLAGS.proportion_steps)) 342 | resultsObj.update_mixture_metrics("train", train_loss, mse_loss) 343 | 344 | ## Save model and summaries 345 | if step % FLAGS.log_step == 0 or step == FLAGS.max_steps_combined: 346 | ## Summary reports (Tensorboard) 347 | summary_writer.add_summary(summaries, str(step+FLAGS.max_steps_component+FLAGS.proportion_steps)) 348 | ## Get estiamted proportions 349 | proportions = sess.run(batchModel.proportions) 350 | #weights = sess.run(batchModel.mixture_weights) 351 | resultsObj.update_proportions(step, proportions) 352 | if step == FLAGS.max_steps_combined: 353 | ## Write out graph 354 | save = saver.save(sess, os.path.abspath(FLAGS.logdir + '/ckpt/model_combined.ckpt'), step+FLAGS.max_steps_component+FLAGS.proportion_steps) 355 | 356 | if FLAGS.log_results is True: 357 | print("Logging combined results") 358 | ## Save reconstruction results prior to mixture_stage training 359 | saving_utils.save_results(sess=sess, 360 | resultsObj=resultsObj, 361 | FLAGS=FLAGS, 362 | VAE=VAE, 363 | component_labels=save_labels, 364 | component_data=component_data, 365 | hvg_masks=hvg_masks, 366 | marker_gene_masks=marker_gene_masks, 367 | stage="combined", 368 | component_valid_labels=component_valid_labels, 369 | component_valid_data=component_valid_data, 370 | mixture_input_data=mixture_data) 371 | 372 | # deconvModel.save_mixtures(sess=sess, 373 | # resultsObj=resultsObj, 374 | # FLAGS=FLAGS, 375 | # VAE=VAE, 376 | # mixture_input_data=mixture_data, 377 | # datatype='postbatch', 378 | # batch_size=100) 379 | 380 | # print(deconvModel.trainable_prop) 381 | # print(deconvModel.trainable_comb) 382 | #print(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) 383 | ## Record the VAE weights after pretraining 384 | # if FLAGS.tied_model_l2 is True: 385 | # for scope in np.unique(component_labels): 386 | # print("Assigning weights") 387 | # VAE[scope].assign_network_weights(sess=sess, assign=False) 388 | ## Return depending on deconvolution 389 | return resultsObj 390 | --------------------------------------------------------------------------------