├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── ci └── conda_requirements.txt ├── examples ├── cf │ ├── check_rhamnolipids.ipynb │ ├── lcms_nt.biom │ ├── metabolite-metadata.txt │ ├── microbe-metadata.txt │ ├── otus_nt.biom │ ├── q2_run.sh │ └── taxonomy.tsv └── soils │ ├── check_soils.ipynb │ ├── metabolites.biom │ ├── microbes.biom │ └── run.sh ├── img ├── biplot.png ├── heatmap.png ├── mmvec.png ├── paired-heatmap.png ├── paired-summary.png ├── single-summary.png └── tensorboard.png ├── mmvec ├── __init__.py ├── heatmap.py ├── multimodal.py ├── q2 │ ├── __init__.py │ ├── _method.py │ ├── _stats.py │ ├── _summary.py │ ├── _transformer.py │ ├── _visualizers.py │ ├── assets │ │ └── index.html │ ├── plugin_setup.py │ └── tests │ │ ├── test_method.py │ │ └── test_visualizers.py ├── tests │ ├── data │ │ ├── ms_hits.txt │ │ ├── otu_hits.txt │ │ ├── soil_metabolites.biom │ │ ├── soil_microbes.biom │ │ ├── x_test.biom │ │ ├── x_train.biom │ │ ├── y_test.biom │ │ └── y_train.biom │ ├── test_heatmap.py │ ├── test_multimodal.py │ └── test_util.py └── util.py ├── scripts └── mmvec └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary files 2 | *~ 3 | \#*# 4 | 5 | *.py[cod] 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Packages 11 | *.egg 12 | *.egg-info 13 | dist 14 | build 15 | eggs 16 | parts 17 | bin 18 | var 19 | sdist 20 | develop-eggs 21 | .installed.cfg 22 | lib 23 | lib64 24 | __pycache__ 25 | 26 | # Installer logs 27 | pip-log.txt 28 | 29 | # Unit test / coverage reports 30 | .coverage 31 | .tox 32 | nosetests.xml 33 | 34 | # Translations 35 | *.mo 36 | 37 | # Mr Developer 38 | .mr.developer.cfg 39 | .project 40 | .pydevproject 41 | 42 | # vi 43 | .*.swp 44 | 45 | # Sphinx builds 46 | doc/source/generated 47 | 48 | # OSX files 49 | .DS_Store 50 | -------------------------------------------------------------------------------- /.travis.yml: 
-------------------------------------------------------------------------------- 1 | # Travis yml file inspired by scikit-bio 2 | # Check on http://lint.travis-ci.org/ after modifying it! 3 | sudo: false 4 | language: python 5 | env: 6 | - PYVERSION=3.5 7 | before_install: 8 | - export MPLBACKEND='Agg' 9 | - wget -q https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 10 | - export MINICONDA_PREFIX="$HOME/miniconda" 11 | - bash miniconda.sh -b -p $MINICONDA_PREFIX 12 | - export PATH="$MINICONDA_PREFIX/bin:$PATH" 13 | - conda config --set always_yes yes 14 | - conda update -q conda 15 | - conda info -a 16 | install: 17 | - wget -q https://raw.githubusercontent.com/qiime2/environment-files/master/2020.5/staging/qiime2-2020.5-py36-linux-conda.yml 18 | - conda env create -q -n test_env --file qiime2-2020.5-py36-linux-conda.yml 19 | - conda install --yes -q -n test_env --file ci/conda_requirements.txt -c biocore 20 | - source activate test_env 21 | - pip install -e . 22 | script: 23 | - make all 24 | notifications: 25 | webhooks: 26 | on_success: change 27 | on_failure: always 28 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # mmvec changelog 2 | 3 | ## Version 1.0.5 4 | - Adding summary commands to diagnose MMvec in the qiime2 interface [#151](https://github.com/biocore/mmvec/pull/151) 5 | 6 | ## Version 1.0.4 (2020-04-24) 7 | # Enhancements 8 | - `equalize_biplot` option has been able to visualize microbes and metabolites on the same scale. 
[#131](https://github.com/biocore/mmvec/pull/131) 9 | 10 | ## Version 1.0.3 (2019-12-12) 11 | # Enhancements 12 | - Tensorflow is now pinned to any version below 2.0 in [#112](https://github.com/biocore/mmvec/pull/112) 13 | - Learning rate defaults have been fixed to `1e-5` in [#110](https://github.com/biocore/mmvec/pull/110) 14 | 15 | # Bug fixes 16 | - Inputs are now expected to be metabolites x microbes in heatmaps [#100](https://github.com/biocore/mmvec/pull/100) 17 | 18 | ## Version 1.0.2 (2019-10-18) 19 | # Bug fixes 20 | - Inputs are now expected to be metabolites x microbes in heatmaps [#100](https://github.com/biocore/mmvec/pull/100) 21 | 22 | # Enhancements 23 | - Ranks are transposed and viewable in qiime metadata tabulate [#99](https://github.com/biocore/mmvec/pull/99) 24 | 25 | # Bug fixes 26 | - Ranks are now calculated consistently between q2 and standalone cli [#99](https://github.com/biocore/mmvec/pull/99) 27 | 28 | ## Version 1.0.0 (2019-09-30) 29 | # Enhancements 30 | - Paired heatmaps are available [#89](https://github.com/biocore/mmvec/pull/89) 31 | - Heatmap tutorials are available [#90](https://github.com/biocore/mmvec/pull/90) 32 | 33 | # Bug fixes 34 | - The ordering of the eigenvalues are now reversed [#92](https://github.com/biocore/mmvec/pull/92) 35 | - The qiime2 assets setup is corrected [#91](https://github.com/biocore/mmvec/pull/91) 36 | 37 | ## Version 0.6.0 (2019-09-05) 38 | 39 | # Enhancements 40 | - Ranks from CLI can now be imported into qiime2 [#84](https://github.com/biocore/mmvec/pull/84) 41 | - Ranks can be visualized as heatmaps [#69](https://github.com/biocore/mmvec/pull/69) 42 | 43 | # Bug fixes 44 | - ConditionalFormat has been fixed [#68](https://github.com/biocore/mmvec/pull/68) 45 | 46 | ## Version 0.4.0 (2019-07-22) 47 | 48 | # Enhancements 49 | - Simpler standalone CLI interface - now all outputs have named rows and columns. 
[#61](https://github.com/biocore/mmvec/pull/61) 50 | 51 | # Bug fixes 52 | - The ranks file is no longer empty. [#61](https://github.com/biocore/mmvec/pull/61) 53 | 54 | ## Version 0.3.0 (2019-06-20) 55 | 56 | Initial beta release. 57 | 58 | # Bug fixes 59 | - Biplots are now being properly centered in the qiime2 interface [#58](https://github.com/biocore/mmvec/pull/58) 60 | 61 | 62 | ## Version 0.2.0 (2019-04-22) 63 | 64 | Initial alpha release. MMvec API, standalone command line interface and qiime2 interface should be stable. 65 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Jamie Morton 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | 3 | TEST_COMMAND = nosetests 4 | help: 5 | @echo 'Use "make test" to run all the unit tests and docstring tests.' 6 | @echo 'Use "make pep8" to validate PEP8 compliance.' 7 | @echo 'Use "make html" to create html documentation with sphinx' 8 | @echo 'Use "make all" to run all the targets listed above.' 9 | test: 10 | $(TEST_COMMAND) 11 | pep8: 12 | pycodestyle mmvec setup.py scripts 13 | flake8 mmvec setup.py scripts 14 | 15 | all: pep8 test 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/biocore/mmvec.svg?branch=master)](https://travis-ci.org/biocore/mmvec) 2 | 3 | # MMvec 4 | Neural networks for estimating microbe-metabolite interactions through their co-occurrence probabilities. 5 | 6 | ![](https://github.com/biocore/mmvec/raw/master/img/mmvec.png "mmvec") 7 | 8 | # Installation 9 | 10 | MMvec can be installed via pypi as follows 11 | 12 | ``` 13 | pip install mmvec 14 | ``` 15 | 16 | If you are planning on using GPUs, be sure to `pip install tensorflow-gpu <= 1.14.0`. 
17 | 18 | MMvec can also be installed via conda as follows 19 | 20 | ``` 21 | conda install mmvec -c conda-forge 22 | ``` 23 | 24 | **Warning** : Note that this option may not work in cluster environments, it may be worthwhile to pip install within a virtual environment. It is possible to pip install mmvec within a conda environment, including qiime2 conda environments. However, pip and conda are known to have compatibility issues, so proceed with caution. 25 | 26 | **Update** : conda has not aged very well since this package was released. Below are updated install instructions using mamba install (without qiime2) 27 | ``` 28 | conda create -n mmvec_env mamba python=3.7 -c conda-forge 29 | conda activate mmvec_env 30 | mamba install mmvec -c conda-forge 31 | ``` 32 | 33 | Finally, MMvec is **only** compatible with qiime2 environments 2020.6 or before. Stay tuned for future updates. 34 | 35 | # Input data 36 | 37 | The two basic tables required to run mmvec are: 38 | 39 | - Metabolite counts (.biom): A table with metabolites in rows and samples in columns. 40 | - Microbe abundance (.biom): A relative abundance table with microbial species in rows and samples in columns. 41 | 42 | # Getting started 43 | 44 | To get started you can run a quick example as follows. This will learn microbe-metabolite vectors (mmvec) 45 | which can be used to estimate microbe-metabolite conditional probabilities that are accurate up to rank. 46 | 47 | ``` 48 | mmvec paired-omics \ 49 | --microbe-file examples/cf/otus_nt.biom \ 50 | --metabolite-file examples/cf/lcms_nt.biom \ 51 | --summary-dir summary 52 | ``` 53 | 54 | While this is running, you can open up another session and run `tensorboard --logdir .` for diagnosis, see FAQs below for more details. 55 | 56 | If you investigate the summary folder, you will notice that there are a number of files deposited. 57 | 58 | See the following url for a more complete tutorial with real datasets. 
59 | 60 | https://github.com/knightlab-analyses/multiomic-cooccurences 61 | 62 | More information can be found under `mmvec --help` 63 | 64 | # Qiime2 plugin 65 | 66 | If you want to run this in a qiime environment, install this in your 67 | qiime2 conda environment (see qiime2 installation instructions [here](https://qiime2.org/)) and run the following 68 | 69 | ``` 70 | pip install git+https://github.com/biocore/mmvec.git 71 | qiime dev refresh-cache 72 | ``` 73 | 74 | This should allow your q2 environment to recognize mmvec. Before we test 75 | the qiime2 plugin, go to the `examples/cf` folder and run the following commands to import an example dataset 76 | 77 | ``` 78 | qiime tools import \ 79 | --input-path otus_nt.biom \ 80 | --output-path otus_nt.qza \ 81 | --type FeatureTable[Frequency] 82 | 83 | qiime tools import \ 84 | --input-path lcms_nt.biom \ 85 | --output-path lcms_nt.qza \ 86 | --type FeatureTable[Frequency] 87 | ``` 88 | 89 | Then you can run mmvec 90 | ``` 91 | qiime mmvec paired-omics \ 92 | --i-microbes otus_nt.qza \ 93 | --i-metabolites lcms_nt.qza \ 94 | --p-summary-interval 1 \ 95 | --output-dir model_summary 96 | ``` 97 | 98 | In the results, there are three files, namely `model_summary/conditional_biplot.qza`, `model_summary/conditionals.qza` and `model_summary/model_stats.qza`. 99 | The conditional biplot is a biplot representation of the 100 | conditional probability matrix so that you can visualize these microbe-metabolite interactions in an exploratory manner. This can be directly visualized in 101 | Emperor as shown below. We also have the estimated conditional probability matrix given in `results/conditionals.qza`, 102 | which can be unzipped to yield a tab-delimited table via `unzip results/conditionals`. Each row can be ranked, 103 | so the top most occurring metabolites for a given microbe can be obtained by identifying the highest co-occurrence probabilities for each microbe. 
104 | 105 | These log conditional probabilities can also be viewed directly with `qiime metadata tabulate`. This can be 106 | created as follows 107 | 108 | ``` 109 | qiime metadata tabulate \ 110 | --m-input-file results/conditionals.qza \ 111 | --o-visualization conditionals-viz.qzv 112 | ``` 113 | 114 | 115 | Then you can run the following to generate an emperor biplot. 116 | 117 | ``` 118 | qiime emperor biplot \ 119 | --i-biplot conditional_biplot.qza \ 120 | --m-sample-metadata-file metabolite-metadata.txt \ 121 | --m-feature-metadata-file taxonomy.tsv \ 122 | --o-visualization emperor.qzv 123 | 124 | ``` 125 | 126 | The resulting biplot should look like something as follows 127 | 128 | ![biplot](https://github.com/biocore/mmvec/raw/master/img/biplot.png "Biplot") 129 | 130 | Here, the metabolites represent points and the arrows represent microbes. The points close together are indicative of metabolites that 131 | frequently co-occur with each other. Furthermore, arrows that have a small angle between them are indicative of microbes that co-occur with each other. 132 | Arrows that point in the same direction as the metabolites are indicative of microbe-metabolite co-occurrences. In the biplot above, the red arrows 133 | correspond to Pseudomonas aeruginosa, and the red points correspond to Rhamnolipids that are likely produced by Pseudomonas aeruginosa. 
134 | 135 | Another way to examine these associations is to build heatmaps of the log 136 | conditional probabilities between observations, using the `heatmap` action: 137 | 138 | ``` 139 | qiime mmvec heatmap \ 140 | --i-ranks ranks.qza \ 141 | --m-microbe-metadata-file taxonomy.tsv \ 142 | --m-microbe-metadata-column Taxon \ 143 | --m-metabolite-metadata-file metabolite-metadata.txt \ 144 | --m-metabolite-metadata-column Compound_Source \ 145 | --p-level 5 \ 146 | --o-visualization ranks-heatmap.qzv 147 | ``` 148 | 149 | This action generates a clustered heatmap displaying the log conditional 150 | probabilities between microbes and metabolites. Larger positive log conditional 151 | probabilities indicate a stronger likelihood of co-occurrence. Low and negative 152 | values indicate no relationship, not necessarily a negative correlation. Rows 153 | (microbial features) can be annotated according to feature metadata, as shown 154 | in this example; we provide a taxonomic classification file and the semicolon- 155 | delimited taxonomic rank (`level`) that should be displayed in the color-coded 156 | margin annotation. Set `level` to `-1` to display the full annotation 157 | (including of non-delimited feature metadata). Separate parameters are 158 | available to annotate the x-axis (metabolites) in a similar fashion. Row and 159 | column clustering can be adjusted using the `method` and `metric` parameters. 160 | This action will generate a heatmap that looks similar to this: 161 | 162 | ![heatmap](https://github.com/biocore/mmvec/raw/master/img/heatmap.png "Heatmap") 163 | 164 | Biplots and heatmaps give a great overview of co-occurrence associations, but 165 | do not provide information about the abundances of these co-occurring features 166 | in each sample. 
This can be done with the `paired-heatmap` action: 167 | 168 | ``` 169 | qiime mmvec paired-heatmap \ 170 | --i-ranks ranks.qza \ 171 | --i-microbes-table otus_nt.qza \ 172 | --i-metabolites-table lcms_nt.qza \ 173 | --m-microbe-metadata-file taxonomy.tsv \ 174 | --m-microbe-metadata-column Taxon \ 175 | --p-features TACGAAGGGTGCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGCGCGTAGGTGGTTCAGCAAGTTGGATGTGAAATCCCCGGGCTCAACCTGGGAACTGCATCCAAAACTACTGAGCTAGAGTACGGTAGAGGGTGGTGGAATTTCCTG \ 176 | --p-features TACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGTGGCTTAACCATAGTAGGCTTTGGAAACTGTTTAACTTGAGTGCAAGAGGGGAGAGTGGAATTCCATGT \ 177 | --p-top-k-microbes 0 \ 178 | --p-normalize rel_row \ 179 | --p-top-k-metabolites 100 \ 180 | --p-level 6 \ 181 | --o-visualization paired-heatmap-top2.qzv 182 | ``` 183 | 184 | This action generates paired heatmaps that are aligned on the y-axis (sample 185 | IDs): the left panel displays the abundances of each selected microbial feature 186 | in each sample, and the right panel displays the abundances of the top k 187 | metabolite features associated with each of these microbes in each sample. 188 | Microbes can be selected automatically using the `top-k-microbes` parameter 189 | (which selects the microbes with the top k highest relative abundances) or they 190 | can be selected by name using the `features` parameter (if using the QIIME 2 191 | plugin command-line interface as shown in this example, multiple features are 192 | selected by passing this parameter multiple times, e.g., `--p-features feature1 193 | --p-features feature2`; for python interfaces, pass a list of features: 194 | `features=[feature1, feature2]`). As with the `heatmap` action, microbial 195 | features can be annotated by passing in `microbe-metadata` and specifying a 196 | taxonomic `level` to display. 
The output looks something like this: 197 | 198 | ![paired-heatmap](https://github.com/biocore/mmvec/raw/master/img/paired-heatmap.png "Paired Heatmap") 199 | 200 | 201 | More information behind the actions and parameters can be found under `qiime mmvec --help` 202 | 203 | # Model diagnostics 204 | 205 | ## QIIME2 Convergence Summaries 206 | 207 | If you are using the qiime2 interface, there won't be a tensorboard interface. 208 | But there will still be training loss curves and cross-validation statistics reported, which are currently not available in the tensorboard interface. To run this with a single model, run the following 209 | 210 | ``` 211 | qiime mmvec summarize-single \ 212 | --i-model-stats model_summary/model_stats.qza \ 213 | --o-visualization model-summary.qzv 214 | ``` 215 | 216 | An example of what this will look like is given as follows 217 | ![single_summary](https://github.com/biocore/mmvec/raw/master/img/single-summary.png "Single Summary") 218 | 219 | ## Null models and QIIME 2 + MMvec 220 | 221 | If you're running mmvec through QIIME 2, the 222 | `qiime mmvec summarize-paired` command allows you to view two sets of 223 | diagnostic plots at once as follows: 224 | 225 | ``` 226 | # Null model with only biases 227 | qiime mmvec paired-omics \ 228 | --i-microbes otus_nt.qza \ 229 | --i-metabolites lcms_nt.qza \ 230 | --p-latent-dim 0 \ 231 | --p-summary-interval 1 \ 232 | --output-dir null_summary 233 | 234 | qiime mmvec summarize-paired \ 235 | --i-model-stats model_summary/model_stats.qza \ 236 | --i-baseline-stats null_summary/model_stats.qza \ 237 | --o-visualization paired-summary.qzv 238 | ``` 239 | 240 | An example of what this will look like is given as follows 241 | ![paired_summary](https://github.com/biocore/mmvec/raw/master/img/paired-summary.png "Paired Summary") 242 | 243 | It is important to note here that the null model has a worse cross-validation error than the first MMvec model we trained. 
However to make the models exactly comparable, the same samples must be used for training and cross-validation. See the `--p-training-column` option to manually specify samples for training and testing. 244 | 245 | These summaries can also be extended to analyze any two models of interest. This can help with picking optimal hyper-parameters. 246 | 247 | ## Interpreting _Q2_ values 248 | The _Q2_ score is adapted from the Partial least squares literature. Here it is given by `Q^2 = 1 - m1/m2`, where `m1` indicates the average absolute model error and `m2` indicates the average absolute null or baseline model error. If _Q2_ is close to 1, that indicates a high predictive accuracy on the cross validation samples. If _Q2_ is low or below zero, that indicates poor predictive accuracy, suggesting possible overfitting. This statistic behaves similarly to the _R2_ classically used in an ordinary linear regression if `--p-formula` is `"1"` in the `m2` model. 249 | 250 | If the _Q2_ score is extremely close to 0 (or negative), this indicates that the model is overfit or that the metadata supplied to the model are not predictive of microbial composition across samples. You can think about this in terms of "how does using the metadata columns in my formula *improve* a model?" If there isn't really an improvement, then you may want to reconsider your formula. 251 | 252 | ... [But as long as your _Q2_ score is above zero, your model is learning something useful](https://forum.qiime2.org/t/songbird-optimizing-the-loss-function/13479/8). 253 | 254 | # FAQs 255 | 256 | **Q**: Looks like there are two different commands, a standalone script and a qiime2 interface. Which one should I use?!? 257 | 258 | **A**: It'll depend on how deep in the weeds you'll want to get. For most intents and purposes, the qiime2 interface will be more practical for most analyses. There are 3 major reasons why the standalone scripts are preferable to the qiime2 interface, namely 259 | 260 | 1. 
Customized acceleration : If you want to bring down your runtime from a few days to a few hours, you may need to compile Tensorflow to handle hardware specific instructions (i.e. GPUs / SIMD instructions). It probably is possible to enable GPU compatibility within a conda environment with some effort, but since conda packages binaries, SIMD instructions will not work out of the box. 261 | 262 | 2. Checkpoints : If you are not sure how long your analysis should run, the standalone script can allow you to record checkpoints, which can allow you to recover your model parameters. This enables you to investigate your model while the model is training. 263 | 264 | 3. More model parameters : The standalone script will return the bias parameters learned for each dataset (i.e. microbe and metabolite abundances). These are stored under the summary directory (specified by `--summary`) under the names `embeddings.csv`. This file will hold the coordinates for the microbes and metabolites, along with biases. There are 4 columns in this file, namely `feature_id`, `axis`, `embed_type` and `values`. `feature_id` is the name of the feature, whether it be a microbe name or a metabolite feature id. `axis` corresponds to the name of the axis, which either corresponds to a PC axis or bias. `embed_type` denotes if the coordinate corresponds to a microbe or metabolite. `values` is the coordinate value for the given `axis`, `embed_type` and `feature_id`. This can be useful for accessing the raw parameters and building custom biplots / ranks visualizations - this also has the advantage of requiring much less memory to manipulate. 265 | 266 | It is also important to note that you don't have to explicitly choose - it is very doable to run the standalone version first, then import those output files into qiime2. 
Importing can be done as follows 267 | 268 | ``` 269 | qiime tools import --input-path conditionals.tsv --output-path conditionals.qza --type FeatureData[Conditional] 270 | 271 | qiime tools import --input-path ordination.txt --output-path ordination.qza --type 'PCoAResults % Properties("biplot")' 272 | ``` 273 | 274 | **Q** : You mentioned that you can use GPUs. How can you do that?? 275 | 276 | **A** : This can be done by running `pip install tensorflow-gpu` in your environment. See details [here](https://www.tensorflow.org/install/gpu). 277 | 278 | At the moment, these capabilities are only available for the standalone CLI due to complications of installation. See the `--arm-the-gpu` option in the standalone interface. 279 | 280 | **Q** : Neural networks scare me - don't they overfit the crap out of your data? 281 | 282 | **A** : Here, we are using shallow neural networks (so only two layers). This falls under the same regime as PCA and SVD. But just as you can overfit PCA/SVD, you can also overfit mmvec. Which is why we have Tensorboard enabled for diagnostics. You can visualize the `cv_rmse` to gauge if there is overfitting -- if your run is strictly decreasing, then that is a sign that you are probably not overfitting. But this is not necessarily indicative that you have reached the optimum -- you want to check to see if `logloss` has reached a plateau as shown above. 283 | 284 | **Q** : I'm confused, what is Tensorboard? 285 | 286 | **A** : Tensorboard is a diagnostic tool that runs in a web browser - note that this is only explicitly supported in the standalone version of mmvec. To open tensorboard, make sure you’re in the mmvec environment and cd into the folder you are running the script above from. Then run: 287 | 288 | ``` 289 | tensorboard --logdir . 290 | ``` 291 | 292 | The returned line will look something like: 293 | 294 | ``` 295 | TensorBoard 1.9.0 at http://Lisas-MacBook-Pro-2.local:6006 (Press CTRL+C to quit) 296 | ``` 297 | Open the website (highlighted in red) in a browser. 
(Hint; if that doesn’t work try putting only the port number (here it is 6006), adding localhost, localhost:6006). Leave this tab alone. Now any mmvec output directories that you add to the folder that tensorflow is running in will be added to the webpage. 298 | 299 | 300 | If working properly, it will look something like this 301 | ![tensorboard](https://github.com/biocore/mmvec/raw/master/img/tensorboard.png "Tensorboard") 302 | 303 | FIRST graph in Tensorflow; 'Prediction accuracy'. Labelled `cv_rmse` 304 | 305 | This is a graph of the prediction accuracy of the model; the model will try to guess the metabolite intensity values for the testing samples that were set aside in the script above, using only the microbe counts in the testing samples. Then it looks at the real values and sees how close it was. 306 | 307 | The second graph is the `likelihood` - if your `likelihood` values are plateaued, that is a sign that you have converged and reached a local minimum. 308 | 309 | The x-axis is the number of iterations (meaning times the model is training across the entire dataset). Every time you iterate across the training samples, you also run the test samples and the averaged results are being plotted on the y-axis. 310 | 311 | 312 | The y-axis is the average number of counts off for each feature. The model is predicting the sequence counts for each feature in the samples that were set aside for testing. So in the graph above it means that, on average, the model is off by ~0.75 intensity units, which is low. However, this is ABSOLUTE error not relative error (unfortunately we don't know how to compute relative errors because of the sparsity in these datasets). 313 | 314 | You can also compare multiple runs with different parameters to see which run performed the best. Useful parameters to note are `--epochs` and `--batch-size`. 
If you are committed to fine-tuning parameters, be sure to look at the `training-column` example to make the testing samples consistent across runs. 315 | 316 | 317 | **Q** : What's up with the `--training-column` argument? 318 | 319 | **A** : That is used for cross-validation if you have a specific reproducibility question that you are interested in answering. It can also make it easier to compare cross validation results across runs. If this is specified, only samples labeled "Train" under this column will be used for building the model and samples labeled "Test" will be used for cross validation. In other words the model will attempt to predict the microbe abundances for the "Test" samples. The resulting prediction accuracy is used to evaluate the generalizability of the model in order to determine if the model is overfitting or not. If this argument is not specified, then 10 random samples will be chosen for the test dataset. If you want to specify more random samples to allocate for cross-validation, the `num-random-test-examples` argument can be specified. 320 | 321 | 322 | **Q** : What sort of parameters should I focus on when picking a good model? 323 | 324 | **A** : There are 3 different parameters to focus on, `input-prior`, `output-prior` and `latent-dim` 325 | 326 | The `--input-prior` and `--output-prior` options specify the width of the prior distribution of the coefficients, where the `--input-prior` is typically specific to microbes and the `--output-prior` is specific to metabolites. 327 | For a prior of 1, this means 99% of entries in the embeddings are within -3 and +3 log fold change. A prior of 0.1 would impose the constraint that 99% of the embeddings are within -0.3 and +0.3 log fold change. 
The higher `--input-prior` and `--output-prior` is, the more parameters can have bigger changes, so you want to keep this relatively small for small experimental studies, particularly if there are less than 20 samples (we have not been able to run MMvec on a study with fewer than 12 samples without overfitting). 328 | If you see overfitting (accuracy and fit increasing over iterations in tensorboard) you may consider reducing the `--input-prior` and `--output-prior` in order to reduce the parameter space. 329 | 330 | Another parameter worth thinking about is `--latent-dim`, which controls the number of dimensions used to approximate the conditional probability matrix. This also specifies the dimensions of the microbe/metabolite embeddings that are stored in the biplot file. The more dimensions this has, the more accurate the embeddings can be -- but the higher the chance of overfitting there is. The rule of thumb to follow is in order to fit these models, you need at least 10 times as many samples as there are latent dimensions (this is following a similar rule of thumb for fitting straight lines). So if you have 100 samples, you should definitely not have a latent dimension of more than 10. Furthermore, you can still overfit certain microbes and metabolites. For example, you are fitting a model with those 100 samples and just 1 latent dimension, you can still easily overfit microbes and metabolites that appear in less than 10 samples -- so even fitting models with just 1 latent dimension will require some microbes and metabolites that appear in less than 10 samples to be filtered out. 331 | 332 | 333 | **Q** : What does a good model fit look like?? 334 | 335 | **A** : Again the numbers vary greatly by dataset. But you want to see the both the `logloss` and `cv_rmse` curves decaying, and plateau as close to zero as possible. 336 | 337 | **Q** : Should we filter low abundance microbes and metabolites? 
338 | 339 | **A** : A rule of thumb that we recommend is to filter out microbes and metabolites that appear in less than 10 samples. The rationale here is that it isn't practical to fit a line with less than 10 samples. By default we filter out microbes that appear in less than 10 samples; this can be controlled by the `--min-feature-count` option. 340 | 341 | **Q** : How long should I expect this program to run? 342 | 343 | **A** : Both `epochs` and `batch-size` contribute to determining how long the algorithm will run, namely 344 | 345 | **Number of iterations = `epoch #` multiplied by the ( Total # of microbial reads / `batch-size` parameter)** 346 | 347 | This also depends on if your program will converge. The `learning-rate` specifies the resolution, smaller step size = smaller resolution, which will increase the accuracy, but may take longer to converge. You may need to consult with Tensorboard to make sure that your model fit is sane. See this paper for more details on gradient descent: https://arxiv.org/abs/1609.04747 348 | 349 | 350 | If you are running this on a CPU, 16 cores, a run that reaches convergence should take about 1 day. 351 | If you have a GPU - you may be able to get this down to a few hours. However, some finetuning of the `batch-size` parameter may be required -- instead of having a small `batch-size` < 100, you'll want to bump up the `batch-size` to between 1000 and 10000 to fully leverage the speedups available on the GPU. 352 | 353 | As a good reference, the cystic fibrosis dataset can be processed within 10 minutes on a single CPU and within 1 minute on a GPU. 354 | 355 | **Q** : Can I run the standalone version of mmvec and import those outputs to visualize in qiime2? 356 | 357 | **A** : Yes you can! If you ran the standalone `mmvec paired-omics` command and you specified your ranks and ordination to be stored under `conditionals.tsv` and `ordination.txt`, you can import those as qiime2 Artifacts as follows. 
358 | 359 | ``` 360 | qiime tools import --input-path conditionals.tsv --output-path ranks.qza --type "FeatureData[Conditional]" 361 | qiime tools import --input-path ordination.txt --output-path biplot.qza --type "PCoAResults % Properties('biplot')" 362 | ``` 363 | 364 | **Q** : Can MMvec handle small sample studies? 365 | 366 | **A** : We have ran MMvec with published studies as few as 19 samples. However running MMvec in these small sample regimes requires careful tuning of `--latent-dimension` in addition to the `--input-prior` and `--output-prior` commands. The [desert biocrust experiment](https://github.com/biocore/mmvec/tree/master/examples/soils) maybe a good dataset to refer to when analyzing these sorts of datasets. It is important to note that we have not been able to run MMvec on fewer than 12 samples. 367 | 368 | Credits to Lisa Marotz ([@lisa55asil](https://github.com/lisa55asil)), Yoshiki Vazquez-Baeza ([@ElDeveloper](https://github.com/ElDeveloper)), Julia Gauglitz ([@jgauglitz](https://github.com/jgauglitz)) and Nickolas Bokulich ([@nbokulich](https://github.com/nbokulich)) for their README contributions. 369 | 370 | **Q** You mentioned that MMvec learns co-occurrence probabilities. How can I extract these probabilities? 371 | 372 | **A** MMvec will output a file of co-occurrence probabilities, where the rows are metabolites and columns are microbes. You can extract the co-occurrence probabilities by applying a softmax transform along the columns. 
In Python, this is done as follows
"metabolite-metadata.txt \u001b[34mtesting\u001b[m\u001b[m\r\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "!ls" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "\u001b[34mlatent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95\u001b[m\u001b[m\r\n", 37 | "latent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_embedding.txt\r\n", 38 | "latent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_ordination.txt\r\n", 39 | "latent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_ranks.txt\r\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "!ls testing" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# Standalone check" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "import pandas as pd\n", 61 | "import numpy as np\n", 62 | "fname = 'latent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_ranks.txt'\n", 63 | "ranks = pd.read_csv(f'summary/{fname}', sep='\\t', index_col=0)\n", 64 | "microbe_metadata = pd.read_csv('microbe-metadata.txt', sep='\\t', index_col=0)\n", 65 | "metabolite_metadata = pd.read_csv('metabolite-metadata.txt', sep='\\t', index_col=0)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "microbe_metadata = microbe_metadata.loc[ranks.columns]\n", 75 | "i = microbe_metadata.Taxon.apply(lambda x: 'Pseudomonas' in x)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "pseudomonas = microbe_metadata.loc[i].index" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 6, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": 
[ 93 | "metabolite_metadata = metabolite_metadata.dropna(subset=['expert_annotation'])" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "19" 105 | ] 106 | }, 107 | "execution_count": 7, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "np.sum(ranks.loc[metabolite_metadata.index, pseudomonas[0]] > 0)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "# qiime2 check" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 8, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stderr", 130 | "output_type": "stream", 131 | "text": [ 132 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 133 | " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", 134 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 135 | " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", 136 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 137 | " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", 138 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 
1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 139 | " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", 140 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 141 | " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", 142 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 143 | " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", 144 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 145 | " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", 146 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 147 | " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", 148 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 149 | " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", 150 | 
"/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 151 | " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", 152 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 153 | " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", 154 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 155 | " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "import qiime2\n", 161 | "ranks = qiime2.Artifact.load('ranks.qza').view(pd.DataFrame)\n", 162 | "microbe_metadata = pd.read_csv('microbe-metadata.txt', sep='\\t', index_col=0)\n", 163 | "metabolite_metadata = pd.read_csv('metabolite-metadata.txt', sep='\\t', index_col=0)\n", 164 | "microbe_metadata = microbe_metadata.loc[ranks.columns]\n", 165 | "i = microbe_metadata.Taxon.apply(lambda x: 'Pseudomonas' in x)\n", 166 | "pseudomonas = microbe_metadata.loc[i].index" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 9, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "metabolite_metadata = metabolite_metadata.dropna(subset=['expert_annotation'])" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 10, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | 
"19" 187 | ] 188 | }, 189 | "execution_count": 10, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "np.sum(ranks.loc[metabolite_metadata.index, pseudomonas[0]] > 0)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [] 204 | } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "Python 3", 209 | "language": "python", 210 | "name": "python3" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.6.7" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 2 227 | } 228 | -------------------------------------------------------------------------------- /examples/cf/lcms_nt.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/examples/cf/lcms_nt.biom -------------------------------------------------------------------------------- /examples/cf/otus_nt.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/examples/cf/otus_nt.biom -------------------------------------------------------------------------------- /examples/cf/q2_run.sh: -------------------------------------------------------------------------------- 1 | qiime tools import --input-path otus_nt.biom --output-path otus_nt.qza --type FeatureTable[Frequency] 2 | qiime tools import --input-path lcms_nt.biom --output-path lcms_nt.qza --type FeatureTable[Frequency] 3 | 4 | qiime mmvec paired-omics \ 5 | --i-microbes otus_nt.qza \ 6 | --i-metabolites lcms_nt.qza \ 7 | 
--p-epochs 100 \ 8 | --p-learning-rate 1e-3 \ 9 | --o-conditionals ranks.qza \ 10 | --o-model-stats stats.qza \ 11 | --o-conditional-biplot biplot.qza \ 12 | --p-summary-interval 1 \ 13 | --p-equalize-biplot \ 14 | --verbose 15 | 16 | qiime emperor biplot \ 17 | --i-biplot biplot.qza \ 18 | --m-sample-metadata-file metabolite-metadata.txt \ 19 | --m-feature-metadata-file microbe-metadata.txt \ 20 | --p-number-of-features 50 \ 21 | --o-visualization emperor.qzv \ 22 | --p-ignore-missing-samples 23 | 24 | qiime mmvec heatmap \ 25 | --i-ranks ranks.qza \ 26 | --o-visualization heatmap.qzv 27 | -------------------------------------------------------------------------------- /examples/soils/check_soils.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# standalone check" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "\u001b[34mlatent_dim_1_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95\u001b[m\u001b[m\r\n", 31 | "latent_dim_1_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_embedding.txt\r\n", 32 | "latent_dim_1_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_ordination.txt\r\n", 33 | "latent_dim_1_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95_ranks.txt\r\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "!ls summarydir" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "datadir = 
'summarydir/latent_dim_1_input_prior_1.00_output_prior_1.00_beta1_0.90_beta2_0.95'\n", 48 | "ranks = pd.read_csv(datadir + '_ranks.txt', index_col=0, sep='\\t')" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "featureid\n", 60 | "(2,3-dihydroxy-3-methylbutanoate) -3.987261\n", 61 | "(2,5-diaminohexanoate) -1.352668\n", 62 | "(3-hydroxypyridine) -0.020257\n", 63 | "(3-methyladenine) 0.959734\n", 64 | "(4-oxoproline) 2.986923\n", 65 | "Name: rplo 1 (Cyanobacteria), dtype: float64" 66 | ] 67 | }, 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "ranks['rplo 1 (Cyanobacteria)'].head()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "microcoleus_metabolites = {'(3-methyladenine)', '7-methyladenine', '4-guanidinobutanoate', 'uracil',\n", 84 | " 'xanthine', 'hypoxanthine', '(N6-acetyl-lysine)', 'cytosine',\n", 85 | " 'N-acetylornithine', 'N-acetylornithine', 'succinate', \n", 86 | " 'adenosine', 'guanine', 'adenine'}" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "idx = ranks['rplo 1 (Cyanobacteria)'] > 0 \n", 96 | "detected_molecules = set(ranks.index[idx])" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 7, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "{'(3-methyladenine)',\n", 108 | " '(N6-acetyl-lysine)',\n", 109 | " '4-guanidinobutanoate',\n", 110 | " '7-methyladenine',\n", 111 | " 'N-acetylornithine',\n", 112 | " 'adenine',\n", 113 | " 'adenosine',\n", 114 | " 'cytosine',\n", 115 | " 'guanine',\n", 116 | " 'hypoxanthine',\n", 117 | " 'succinate',\n", 118 | " 'uracil',\n", 119 | " 'xanthine'}" 120 | ] 121 | }, 122 | 
"execution_count": 7, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "detected_molecules & microcoleus_metabolites" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 8, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "13" 140 | ] 141 | }, 142 | "execution_count": 8, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "len(microcoleus_metabolites)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 9, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "13" 160 | ] 161 | }, 162 | "execution_count": 9, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "len(detected_molecules & microcoleus_metabolites)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 10, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "text/plain": [ 179 | "featureid\n", 180 | "xanthine 0.642930\n", 181 | "(N6-acetyl-lysine) 4.409032\n", 182 | "succinate 0.878566\n", 183 | "guanine 3.086299\n", 184 | "adenine 4.947557\n", 185 | "N-acetylornithine 1.247694\n", 186 | "7-methyladenine 0.232607\n", 187 | "cytosine 3.205279\n", 188 | "hypoxanthine 0.661717\n", 189 | "4-guanidinobutanoate 3.998861\n", 190 | "(3-methyladenine) 0.959734\n", 191 | "adenosine 4.981767\n", 192 | "uracil 1.782586\n", 193 | "Name: rplo 1 (Cyanobacteria), dtype: float64" 194 | ] 195 | }, 196 | "execution_count": 10, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "ranks['rplo 1 (Cyanobacteria)'].loc[microcoleus_metabolites]" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 11, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "assert len(detected_molecules & microcoleus_metabolites) == 13" 212 | ] 
213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "# qiime2 check" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 12, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "import qiime2\n" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 13, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stderr", 237 | "output_type": "stream", 238 | "text": [ 239 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 240 | " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", 241 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 242 | " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", 243 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 244 | " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", 245 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 246 | " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", 247 | 
"/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 248 | " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", 249 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 250 | " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", 251 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 252 | " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", 253 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 254 | " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", 255 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 256 | " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", 257 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future 
version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 258 | " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", 259 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 260 | " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", 261 | "/Users/jmorton/miniconda3/envs/qiime2-2019.7/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", 262 | " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "ranks = qiime2.Artifact.load('ranks.qza').view(pd.DataFrame)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 14, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "idx = ranks['rplo 1 (Cyanobacteria)'] > 0 \n", 277 | "detected_molecules = set(ranks.index[idx])\n", 278 | "assert len(detected_molecules & microcoleus_metabolites) == 13" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 15, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "featureid\n", 290 | "xanthine 0.742247\n", 291 | "(N6-acetyl-lysine) 4.348102\n", 292 | "succinate 0.720436\n", 293 | "guanine 2.874898\n", 294 | "adenine 4.852340\n", 295 | "N-acetylornithine 1.147904\n", 296 | "7-methyladenine 0.340129\n", 297 | "cytosine 3.000772\n", 298 | "hypoxanthine 0.522730\n", 299 | "4-guanidinobutanoate 3.888838\n", 300 | "(3-methyladenine) 0.750547\n", 301 | "adenosine 4.898254\n", 302 | "uracil 1.723071\n", 303 | "Name: rplo 1 (Cyanobacteria), dtype: float64" 304 | ] 305 | }, 306 | 
"execution_count": 15, 307 | "metadata": {}, 308 | "output_type": "execute_result" 309 | } 310 | ], 311 | "source": [ 312 | "ranks['rplo 1 (Cyanobacteria)'].loc[microcoleus_metabolites]" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [] 321 | } 322 | ], 323 | "metadata": { 324 | "kernelspec": { 325 | "display_name": "Python 3", 326 | "language": "python", 327 | "name": "python3" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.6.7" 340 | } 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 2 344 | } 345 | -------------------------------------------------------------------------------- /examples/soils/metabolites.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/examples/soils/metabolites.biom -------------------------------------------------------------------------------- /examples/soils/microbes.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/examples/soils/microbes.biom -------------------------------------------------------------------------------- /examples/soils/run.sh: -------------------------------------------------------------------------------- 1 | mmvec paired-omics\ 2 | --microbe-file microbes.biom \ 3 | --metabolite-file metabolites.biom \ 4 | --num-testing-examples 1 \ 5 | --min-feature-count 0 \ 6 | --latent-dim 1 \ 7 | --learning-rate 1e-3 \ 8 | --epochs 3000 9 | 10 | qiime tools import --input-path microbes.biom --output-path microbes.biom.qza --type 
FeatureTable[Frequency] 11 | qiime tools import --input-path metabolites.biom --output-path metabolites.biom.qza --type FeatureTable[Frequency] 12 | 13 | qiime mmvec paired-omics \ 14 | --i-microbes microbes.biom.qza \ 15 | --i-metabolites metabolites.biom.qza \ 16 | --p-epochs 100 \ 17 | --p-learning-rate 1e-3 \ 18 | --o-conditionals ranks.qza \ 19 | --o-conditional-biplot biplot.qza \ 20 | --verbose 21 | 22 | -------------------------------------------------------------------------------- /img/biplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/biplot.png -------------------------------------------------------------------------------- /img/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/heatmap.png -------------------------------------------------------------------------------- /img/mmvec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/mmvec.png -------------------------------------------------------------------------------- /img/paired-heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/paired-heatmap.png -------------------------------------------------------------------------------- /img/paired-summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/paired-summary.png -------------------------------------------------------------------------------- /img/single-summary.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/single-summary.png -------------------------------------------------------------------------------- /img/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/img/tensorboard.png -------------------------------------------------------------------------------- /mmvec/__init__.py: -------------------------------------------------------------------------------- 1 | from .heatmap import _heatmap_choices, _cmaps 2 | 3 | __version__ = "1.0.6" 4 | 5 | __all__ = ['_heatmap_choices', '_cmaps'] 6 | -------------------------------------------------------------------------------- /mmvec/heatmap.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | import warnings 6 | 7 | 8 | _heatmap_choices = { 9 | 'metric': {'braycurtis', 'canberra', 'chebyshev', 'cityblock', 10 | 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 11 | 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 12 | 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 13 | 'sokalsneath', 'sqeuclidean', 'yule'}, 14 | 'method': {'single', 'complete', 'average', 'weighted', 'centroid', 15 | 'median', 'ward'}} 16 | 17 | _cmaps = { 18 | 'heatmap': [ 19 | 'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 20 | 'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic', 21 | 'viridis', 'plasma', 'inferno', 'magma', 'cividis'], 22 | 'margins': [ 23 | 'cubehelix', 'Pastel1', 'Pastel2', 'Paired', 24 | 'Accent', 'Dark2', 'Set1', 'Set2', 'Set3', 25 | 'tab10', 'tab20', 'tab20b', 'tab20c', 26 | 'viridis', 'plasma', 'inferno', 'magma', 'cividis']} 27 | 28 | 29 | 
def ranks_heatmap(ranks, microbe_metadata=None, metabolite_metadata=None,
                  method='average', metric='euclidean',
                  color_palette='seismic', margin_palette='cubehelix',
                  x_labels=False, y_labels=False, level=3):
    '''
    Generate clustermap of microbe X metabolite conditional probabilities.

    Parameters
    ----------
    ranks: pd.DataFrame of conditional probabilities.
        Microbes (rows) X metabolites (columns).
    microbe_metadata: pd.Series of microbe metadata for annotating plots
    metabolite_metadata: pd.Series of metabolite metadata for annotating plots
    method: str
        Hierarchical clustering method used in clustermap.
    metric: str
        Hierarchical clustering distance metric used in clustermap.
    color_palette: str
        Color palette for clustermap.
    margin_palette: str
        Name of color palette to use for annotating metadata
        along margin(s) of clustermap.
    x_labels: bool
        Plot x-axis (metabolite) labels?
    y_labels: bool
        Plot y-axis (microbe) labels?
    level: int
        taxonomic level for annotating clustermap. Set to -1 if not parsing
        semicolon-delimited taxonomies or wish to print entire annotation.

    Returns
    -------
    sns.clustermap
    '''
    # subset microbe metadata based on rows/columns
    # NOTE: both helpers warn that ranks and metadata are filtered to the
    # intersection of their IDs.
    if microbe_metadata is not None:
        microbe_metadata, ranks, row_colors, row_class_colors = \
            _process_microbe_metadata(
                ranks, microbe_metadata, level, margin_palette)
    else:
        row_colors = None

    # subset metabolite metadata based on rows/columns
    if metabolite_metadata is not None:
        metabolite_metadata, ranks, col_colors, col_class_colors = \
            _process_metabolite_metadata(
                ranks, metabolite_metadata, margin_palette)
    else:
        col_colors = None

    # Generate heatmap
    # center=0 keeps the diverging palette symmetric around zero probability
    hotmap = sns.clustermap(ranks, cmap=color_palette, center=0,
                            col_colors=col_colors, row_colors=row_colors,
                            figsize=(12, 12), method=method, metric=metric,
                            cbar_kws={'label': 'Log Conditional\nProbability'})

    # add legends
    # zero-size bars exist only to create a legend handle per metadata class
    if col_colors is not None:
        for label in col_class_colors.keys():
            hotmap.ax_col_dendrogram.bar(
                0, 0, color=col_class_colors[label], label=label, linewidth=0)
        hotmap.ax_col_dendrogram.legend(
            title=metabolite_metadata.name, ncol=5, bbox_to_anchor=(0.9, 0.95),
            bbox_transform=plt.gcf().transFigure)
    if row_colors is not None:
        for label in row_class_colors.keys():
            hotmap.ax_row_dendrogram.bar(
                0, 0, color=row_class_colors[label], label=label, linewidth=0)
        hotmap.ax_row_dendrogram.legend(
            title=microbe_metadata.name, ncol=1, bbox_to_anchor=(0.2, 0.7),
            bbox_transform=plt.gcf().transFigure)

    # toggle axis labels
    if not x_labels:
        hotmap.ax_heatmap.set_xticklabels('')
    if not y_labels:
        hotmap.ax_heatmap.set_yticklabels('')

    # widen left margin so the row-dendrogram legend is not clipped
    plt.subplots_adjust(left=0.2)
    return hotmap


def paired_heatmaps(ranks, microbes_table, metabolites_table, microbe_metadata,
                    features=None, top_k_microbes=2, top_k_metabolites=50,
                    keep_top_samples=True, level=-1,
normalize='log10', 114 | color_palette='magma'): 115 | ''' 116 | Creates paired heatmaps of microbe abundances and metabolite abundances. 117 | 118 | Parameters 119 | ---------- 120 | ranks: pd.DataFrame of conditional probabilities. 121 | Microbes (rows) X metabolites (columns). 122 | microbes_table: biom.Table 123 | Microbe feature abundances per sample. 124 | metabolites_table: biom.Table 125 | Metabolite feature abundances per sample. 126 | microbe_metadata: pd.Series 127 | Microbe metadata for annotating plots 128 | features: list 129 | Select microbial feature IDs to display on paired heatmap. 130 | top_k_microbes: int 131 | Select top k microbes with highest abundances to display on heatmap. 132 | top_k_metabolites: int 133 | Select top k metabolites associated with each of the chosen features to 134 | display on heatmap. 135 | keep_top_samples: bool 136 | Toggle whether to display only samples in which selected microbes are 137 | the most abundant ASV. 138 | level: int 139 | taxonomic level for annotating clustermap. 140 | Set to -1 if not parsing semicolon-delimited taxonomies. 141 | normalize: str 142 | Column normalization strategy to use for heatmaps. Must 143 | be "log10", "z_score", or None 144 | color_palette: str 145 | Color palette for heatmaps. 146 | ''' 147 | if top_k_microbes is features is None: 148 | raise ValueError('Must select features by name and/or use the ' 149 | 'top_k_microbes parameter to select features to ' 150 | 'include in the heatmap.') 151 | 152 | # validate microbes 153 | if features is not None: 154 | microbe_ids = set(microbes_table.ids('observation')) 155 | missing_microbes = set(features) - microbe_ids 156 | if len(missing_microbes) > 0: 157 | raise ValueError('features must represent feature IDs in ' 158 | 'microbes_table. 
Missing microbe(s): {0}'.format( 159 | missing_microbes)) 160 | else: 161 | features = [] 162 | 163 | microbes_table = microbes_table.to_dataframe().T 164 | metabolites_table = metabolites_table.to_dataframe().T 165 | 166 | # optionally normalize tables 167 | if normalize != 'None': 168 | microbes_table = _normalize_table(microbes_table, normalize) 169 | metabolites_table = _normalize_table(metabolites_table, normalize) 170 | cbar_label = normalize + ' Frequency' 171 | else: 172 | cbar_label = 'Frequency' 173 | 174 | # find top k microbes (highest relative abundances) 175 | if top_k_microbes is not None: 176 | # select top relative abundances 177 | top_microbes = microbes_table.apply( 178 | lambda x: x / x.sum(), axis=1).sum().sort_values(ascending=False) 179 | # TODO: add option for selecting top_k_microbes by rank 180 | # top_microbes = ranks.max(axis=1).sort_values(ascending=False) 181 | top_microbes = top_microbes[:top_k_microbes].index 182 | # merge top k microbes with selected features 183 | # use list comprehension instead of casting as set to preserve order. 
184 | features = features + [f for f in top_microbes if f not in features] 185 | 186 | # select samples in which microbes are most abundant feature 187 | if keep_top_samples: 188 | select_microbes = microbes_table[microbes_table.apply( 189 | pd.Series.idxmax, axis=1).isin(features)] 190 | 191 | # filter select microbes from microbe table and sort by abundance 192 | sort_orders = [False] + [True] * (len(features) - 1) 193 | select_microbes = select_microbes[features] 194 | select_microbes = select_microbes.sort_values( 195 | features, ascending=sort_orders) 196 | 197 | # find top K metabolites (highest positive ranks) for each microbe 198 | if top_k_metabolites != 'all': 199 | top_metabolites = dict.fromkeys(m for x in [ranks.loc[f].sort_values( 200 | ascending=False)[:top_k_metabolites].index 201 | for f in features] for m in x).keys() 202 | select_metabolites = metabolites_table[top_metabolites] 203 | else: 204 | select_metabolites = metabolites_table 205 | 206 | # align sample IDs across tables 207 | select_microbes, select_metabolites = select_microbes.align( 208 | select_metabolites, join='inner', axis=0) 209 | 210 | # optionally annotate microbe data with taxonomy 211 | if microbe_metadata is not None: 212 | annotations = microbe_metadata.reindex(select_microbes.columns) 213 | # parse semicolon-delimited taxonomy 214 | if level > -1: 215 | annotations = _parse_taxonomy_strings(annotations, level) 216 | else: 217 | annotations = select_microbes.columns 218 | 219 | # generate heatmaps 220 | heatmaps, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) 221 | 222 | sns.heatmap(select_microbes.values, cmap=color_palette, 223 | cbar_kws={'label': cbar_label}, ax=axes[0], 224 | xticklabels=annotations, yticklabels=False, robust=True) 225 | sns.heatmap(select_metabolites.values, cmap=color_palette, 226 | cbar_kws={'label': cbar_label}, ax=axes[1], 227 | xticklabels=False, yticklabels=False, robust=True) 228 | axes[0].set_title('Microbe abundances') 229 | 
axes[0].set_ylabel('Samples') 230 | axes[0].set_xlabel('Microbes') 231 | axes[1].set_title('Metabolite abundances') 232 | axes[1].set_xlabel('Metabolites') 233 | return select_microbes, select_metabolites, heatmaps 234 | 235 | 236 | def _parse_heatmap_metadata_annotations(metadata_column, margin_palette): 237 | ''' 238 | Transform feature or sample metadata into color vector for annotating 239 | margin of clustermap. 240 | Parameters 241 | ---------- 242 | metadata_column: pd.Series of metadata for annotating plots 243 | margin_palette: str 244 | Name of color palette to use for annotating metadata 245 | along margin(s) of clustermap. 246 | Returns 247 | ------- 248 | Returns vector of colors for annotating clustermap and dict mapping colors 249 | to classes. 250 | ''' 251 | # Create a categorical palette to identify md col 252 | metadata_column = metadata_column.astype(str) 253 | col_names = sorted(metadata_column.unique()) 254 | 255 | # Select Color palette 256 | if margin_palette == 'colorhelix': 257 | col_palette = sns.cubehelix_palette( 258 | len(col_names), start=2, rot=3, dark=0.3, light=0.8, reverse=True) 259 | else: 260 | col_palette = sns.color_palette(margin_palette, len(col_names)) 261 | class_colors = dict(zip(col_names, col_palette)) 262 | 263 | # Convert the palette to vectors that will be drawn on the matrix margin 264 | col_colors = metadata_column.map(class_colors) 265 | 266 | return col_colors, class_colors 267 | 268 | 269 | def _parse_taxonomy_strings(taxonomy_series, level): 270 | ''' 271 | taxonomy_series: pd.Series of semicolon-delimited taxonomy strings 272 | level: int 273 | taxonomic level for annotating clustermap. 
274 | Returns 275 | ------- 276 | Returns a pd.Series of taxonomy names at specified level, 277 | or terminal annotation 278 | ''' 279 | return taxonomy_series.apply(lambda x: x.split(';')[:level][-1].strip()) 280 | 281 | 282 | def _process_microbe_metadata(ranks, microbe_metadata, level, margin_palette): 283 | _warn_metadata_filtering('microbe') 284 | microbe_metadata, ranks = microbe_metadata.align( 285 | ranks, join='inner', axis=0) 286 | # parse semicolon-delimited taxonomy 287 | if level > -1: 288 | microbe_metadata = _parse_taxonomy_strings(microbe_metadata, level) 289 | # map metadata categories to row colors 290 | row_colors, row_class_colors = _parse_heatmap_metadata_annotations( 291 | microbe_metadata, margin_palette) 292 | 293 | return microbe_metadata, ranks, row_colors, row_class_colors 294 | 295 | 296 | def _process_metabolite_metadata(ranks, metabolite_metadata, margin_palette): 297 | _warn_metadata_filtering('metabolite') 298 | _ids = set(metabolite_metadata.index) & set(ranks.columns) 299 | ranks = ranks[sorted(_ids)] 300 | metabolite_metadata = metabolite_metadata.reindex(ranks.columns) 301 | # map metadata categories to column colors 302 | col_colors, col_class_colors = _parse_heatmap_metadata_annotations( 303 | metabolite_metadata, margin_palette) 304 | 305 | return metabolite_metadata, ranks, col_colors, col_class_colors 306 | 307 | 308 | def _warn_metadata_filtering(metadata_type): 309 | warning = ('Conditional probabilities table and {0} metadata will be ' 310 | 'filtered to contain only the intersection of IDs in each. If ' 311 | 'this behavior is undesired, ensure that all {0} IDs are ' 312 | 'present in both the table and the metadata ' 313 | 'file'.format(metadata_type)) 314 | warnings.warn(warning, UserWarning) 315 | 316 | 317 | def _normalize_table(table, method): 318 | ''' 319 | Normalize column data in a dataframe for plotting in clustermap. 320 | 321 | table: pd.DataFrame 322 | Input data. 
323 | method: str 324 | Normalization method to use. 325 | 326 | Returns normalized table as pd.DataFrame 327 | ''' 328 | if 'col' in method: 329 | axis = 0 330 | elif 'row' in method: 331 | axis = 1 332 | if 'z_score' in method: 333 | res = table.apply(lambda x: (x - x.mean()) / x.std(), axis=axis) 334 | elif 'rel' in method: 335 | res = table.apply(lambda x: x / x.sum(), axis=axis) 336 | elif method == 'log10': 337 | res = table.apply(lambda x: np.log10(x + 1)) 338 | return res.fillna(0) 339 | -------------------------------------------------------------------------------- /mmvec/multimodal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from tqdm import tqdm 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.contrib.distributions import Multinomial, Normal 7 | import datetime 8 | 9 | 10 | class MMvec(object): 11 | 12 | def __init__(self, u_mean=0, u_scale=1, v_mean=0, v_scale=1, 13 | batch_size=50, latent_dim=3, 14 | learning_rate=0.1, beta_1=0.8, beta_2=0.9, 15 | clipnorm=10., device_name='/cpu:0', save_path=None): 16 | """ Build a tensorflow model for microbe-metabolite vectors 17 | 18 | Returns 19 | ------- 20 | loss : tf.Tensor 21 | The log loss of the model. 
22 | 23 | Notes 24 | ----- 25 | To enable a GPU, set the device to '/device:GPU:x' 26 | where x is 0 or greater 27 | """ 28 | p = latent_dim 29 | self.device_name = device_name 30 | if save_path is None: 31 | basename = "logdir" 32 | suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S") 33 | save_path = "_".join([basename, suffix]) 34 | 35 | self.p = p 36 | self.u_mean = u_mean 37 | self.u_scale = u_scale 38 | self.v_mean = v_mean 39 | self.v_scale = v_scale 40 | self.batch_size = batch_size 41 | self.latent_dim = latent_dim 42 | 43 | self.learning_rate = learning_rate 44 | self.beta_1 = beta_1 45 | self.beta_2 = beta_2 46 | self.clipnorm = clipnorm 47 | self.save_path = save_path 48 | 49 | def __call__(self, session, trainX, trainY, testX, testY): 50 | """ Initialize the actual graph 51 | 52 | Parameters 53 | ---------- 54 | session : tf.Session 55 | Tensorflow session 56 | trainX : sparse array in coo format 57 | Test input OTU table, where rows are samples and columns are 58 | observations 59 | trainY : np.array 60 | Test output metabolite table 61 | testX : sparse array in coo format 62 | Test input OTU table, where rows are samples and columns are 63 | observations. This is mainly for cross validation. 64 | testY : np.array 65 | Test output metabolite table. This is mainly for cross validation. 
66 | """ 67 | self.session = session 68 | self.nnz = len(trainX.data) 69 | self.d1 = trainX.shape[1] 70 | self.d2 = trainY.shape[1] 71 | self.cv_size = len(testX.data) 72 | 73 | # keep the multinomial sampling on the cpu 74 | # https://github.com/tensorflow/tensorflow/issues/18058 75 | with tf.device('/cpu:0'): 76 | X_ph = tf.SparseTensor( 77 | indices=np.array([trainX.row, trainX.col]).T, 78 | values=trainX.data, 79 | dense_shape=trainX.shape) 80 | Y_ph = tf.constant(trainY, dtype=tf.float32) 81 | 82 | X_holdout = tf.SparseTensor( 83 | indices=np.array([testX.row, testX.col]).T, 84 | values=testX.data, 85 | dense_shape=testX.shape) 86 | Y_holdout = tf.constant(testY, dtype=tf.float32) 87 | 88 | total_count = tf.reduce_sum(Y_ph, axis=1) 89 | batch_ids = tf.multinomial( 90 | tf.log(tf.reshape(X_ph.values, [1, -1])), 91 | self.batch_size) 92 | batch_ids = tf.squeeze(batch_ids) 93 | X_samples = tf.gather(X_ph.indices, 0, axis=1) 94 | X_obs = tf.gather(X_ph.indices, 1, axis=1) 95 | sample_ids = tf.gather(X_samples, batch_ids) 96 | 97 | Y_batch = tf.gather(Y_ph, sample_ids) 98 | X_batch = tf.gather(X_obs, batch_ids) 99 | 100 | with tf.device(self.device_name): 101 | self.qUmain = tf.Variable( 102 | tf.random_normal([self.d1, self.p]), name='qU') 103 | self.qUbias = tf.Variable( 104 | tf.random_normal([self.d1, 1]), name='qUbias') 105 | self.qVmain = tf.Variable( 106 | tf.random_normal([self.p, self.d2-1]), name='qV') 107 | self.qVbias = tf.Variable( 108 | tf.random_normal([1, self.d2-1]), name='qVbias') 109 | 110 | qU = tf.concat( 111 | [tf.ones([self.d1, 1]), self.qUbias, self.qUmain], axis=1) 112 | qV = tf.concat( 113 | [self.qVbias, tf.ones([1, self.d2-1]), self.qVmain], axis=0) 114 | 115 | # regression coefficents distribution 116 | Umain = Normal(loc=tf.zeros([self.d1, self.p]) + self.u_mean, 117 | scale=tf.ones([self.d1, self.p]) * self.u_scale, 118 | name='U') 119 | Ubias = Normal(loc=tf.zeros([self.d1, 1]) + self.u_mean, 120 | scale=tf.ones([self.d1, 1]) * 
self.u_scale, 121 | name='biasU') 122 | 123 | Vmain = Normal(loc=tf.zeros([self.p, self.d2-1]) + self.v_mean, 124 | scale=tf.ones([self.p, self.d2-1]) * self.v_scale, 125 | name='V') 126 | Vbias = Normal(loc=tf.zeros([1, self.d2-1]) + self.v_mean, 127 | scale=tf.ones([1, self.d2-1]) * self.v_scale, 128 | name='biasV') 129 | 130 | du = tf.gather(qU, X_batch, axis=0, name='du') 131 | dv = tf.concat([tf.zeros([self.batch_size, 1]), 132 | du @ qV], axis=1, name='dv') 133 | 134 | tc = tf.gather(total_count, sample_ids) 135 | Y = Multinomial(total_count=tc, logits=dv, name='Y') 136 | num_samples = trainX.shape[0] 137 | norm = num_samples / self.batch_size 138 | logprob_vmain = tf.reduce_sum( 139 | Vmain.log_prob(self.qVmain), name='logprob_vmain') 140 | logprob_vbias = tf.reduce_sum( 141 | Vbias.log_prob(self.qVbias), name='logprob_vbias') 142 | logprob_umain = tf.reduce_sum( 143 | Umain.log_prob(self.qUmain), name='logprob_umain') 144 | logprob_ubias = tf.reduce_sum( 145 | Ubias.log_prob(self.qUbias), name='logprob_ubias') 146 | logprob_y = tf.reduce_sum(Y.log_prob(Y_batch), name='logprob_y') 147 | self.log_loss = - ( 148 | logprob_y * norm + 149 | logprob_umain + logprob_ubias + 150 | logprob_vmain + logprob_vbias 151 | ) 152 | 153 | # keep the multinomial sampling on the cpu 154 | # https://github.com/tensorflow/tensorflow/issues/18058 155 | with tf.device('/cpu:0'): 156 | # cross validation 157 | with tf.name_scope('accuracy'): 158 | cv_batch_ids = tf.multinomial( 159 | tf.log(tf.reshape(X_holdout.values, [1, -1])), 160 | self.cv_size) 161 | cv_batch_ids = tf.squeeze(cv_batch_ids) 162 | X_cv_samples = tf.gather(X_holdout.indices, 0, axis=1) 163 | X_cv = tf.gather(X_holdout.indices, 1, axis=1) 164 | cv_sample_ids = tf.gather(X_cv_samples, cv_batch_ids) 165 | 166 | Y_cvbatch = tf.gather(Y_holdout, cv_sample_ids) 167 | X_cvbatch = tf.gather(X_cv, cv_batch_ids) 168 | holdout_count = tf.reduce_sum(Y_cvbatch, axis=1) 169 | cv_du = tf.gather(qU, X_cvbatch, axis=0, 
name='cv_du') 170 | pred = tf.reshape( 171 | holdout_count, [-1, 1]) * tf.nn.softmax( 172 | tf.concat([tf.zeros([ 173 | self.cv_size, 1]), 174 | cv_du @ qV], axis=1, name='pred') 175 | ) 176 | 177 | self.cv = tf.reduce_mean( 178 | tf.squeeze(tf.abs(pred - Y_cvbatch)) 179 | ) 180 | 181 | # keep all summaries on the cpu 182 | with tf.device('/cpu:0'): 183 | tf.summary.scalar('logloss', self.log_loss) 184 | tf.summary.scalar('cv_rmse', self.cv) 185 | tf.summary.histogram('qUmain', self.qUmain) 186 | tf.summary.histogram('qVmain', self.qVmain) 187 | tf.summary.histogram('qUbias', self.qUbias) 188 | tf.summary.histogram('qVbias', self.qVbias) 189 | self.merged = tf.summary.merge_all() 190 | 191 | self.writer = tf.summary.FileWriter( 192 | self.save_path, self.session.graph) 193 | 194 | with tf.device(self.device_name): 195 | with tf.name_scope('optimize'): 196 | optimizer = tf.train.AdamOptimizer( 197 | self.learning_rate, beta1=self.beta_1, beta2=self.beta_2) 198 | 199 | gradients, self.variables = zip( 200 | *optimizer.compute_gradients(self.log_loss)) 201 | self.gradients, _ = tf.clip_by_global_norm( 202 | gradients, self.clipnorm) 203 | self.train = optimizer.apply_gradients( 204 | zip(self.gradients, self.variables)) 205 | 206 | tf.global_variables_initializer().run() 207 | 208 | def ranks(self): 209 | modelU = np.hstack( 210 | (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) 211 | modelV = np.vstack( 212 | (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) 213 | 214 | res = np.hstack((np.zeros((self.U.shape[0], 1)), modelU @ modelV)) 215 | res = res - res.mean(axis=1).reshape(-1, 1) 216 | return res 217 | 218 | def fit(self, epoch=10, summary_interval=1000, checkpoint_interval=3600, 219 | testX=None, testY=None): 220 | """ Fits the model. 
221 | 222 | Parameters 223 | ---------- 224 | epoch : int 225 | Number of epochs to train 226 | summary_interval : int 227 | Number of seconds until a summary is recorded 228 | checkpoint_interval : int 229 | Number of seconds until a checkpoint is recorded 230 | 231 | Returns 232 | ------- 233 | loss: float 234 | log likelihood loss. 235 | cv : float 236 | cross validation loss 237 | """ 238 | iterations = epoch * self.nnz // self.batch_size 239 | losses, cvs = [], [] 240 | cv = None 241 | last_checkpoint_time = 0 242 | last_summary_time = 0 243 | saver = tf.train.Saver() 244 | now = time.time() 245 | for i in tqdm(range(0, iterations)): 246 | if now - last_summary_time > summary_interval: 247 | 248 | res = self.session.run( 249 | [self.train, self.merged, self.log_loss, self.cv, 250 | self.qUmain, self.qUbias, 251 | self.qVmain, self.qVbias] 252 | ) 253 | train_, summary, loss, cv, rU, rUb, rV, rVb = res 254 | self.writer.add_summary(summary, i) 255 | last_summary_time = now 256 | else: 257 | res = self.session.run( 258 | [self.train, self.log_loss, 259 | self.qUmain, self.qUbias, 260 | self.qVmain, self.qVbias] 261 | ) 262 | train_, loss, rU, rUb, rV, rVb = res 263 | losses.append(loss) 264 | cvs.append(cv) 265 | cv = None 266 | 267 | # checkpoint model 268 | now = time.time() 269 | if now - last_checkpoint_time > checkpoint_interval: 270 | saver.save(self.session, 271 | os.path.join(self.save_path, "model.ckpt"), 272 | global_step=i) 273 | last_checkpoint_time = now 274 | 275 | self.U = rU 276 | self.V = rV 277 | self.Ubias = rUb 278 | self.Vbias = rVb 279 | 280 | return losses, cvs 281 | -------------------------------------------------------------------------------- /mmvec/q2/__init__.py: -------------------------------------------------------------------------------- 1 | from ._stats import (Conditional, ConditionalDirFmt, ConditionalFormat, 2 | MMvecStats, MMvecStatsFormat, MMvecStatsDirFmt) 3 | from ._method import paired_omics 4 | from ._visualizers 
import heatmap, paired_heatmap 5 | from ._summary import summarize_single, summarize_paired 6 | 7 | 8 | __all__ = ['paired_omics', 9 | 'Conditional', 'ConditionalFormat', 'ConditionalDirFmt', 10 | 'MMvecStats', 'MMvecStatsFormat', 'MMvecStatsDirFmt', 11 | 'heatmap', 'paired_heatmap', 12 | 'summarize_single', 'summarize_paired'] 13 | -------------------------------------------------------------------------------- /mmvec/q2/_method.py: -------------------------------------------------------------------------------- 1 | import biom 2 | import pandas as pd 3 | import numpy as np 4 | import tensorflow as tf 5 | from skbio import OrdinationResults 6 | import qiime2 7 | from qiime2.plugin import Metadata 8 | from mmvec.multimodal import MMvec 9 | from mmvec.util import split_tables 10 | from scipy.sparse import coo_matrix 11 | from scipy.sparse.linalg import svds 12 | 13 | 14 | def paired_omics(microbes: biom.Table, 15 | metabolites: biom.Table, 16 | metadata: Metadata = None, 17 | training_column: str = None, 18 | num_testing_examples: int = 5, 19 | min_feature_count: int = 10, 20 | epochs: int = 100, 21 | batch_size: int = 50, 22 | latent_dim: int = 3, 23 | input_prior: float = 1, 24 | output_prior: float = 1, 25 | learning_rate: float = 1e-3, 26 | equalize_biplot: float = False, 27 | arm_the_gpu: bool = False, 28 | summary_interval: int = 1) -> ( 29 | pd.DataFrame, OrdinationResults, qiime2.Metadata 30 | ): 31 | 32 | if metadata is not None: 33 | metadata = metadata.to_dataframe() 34 | 35 | if arm_the_gpu: 36 | # pick out the first GPU 37 | device_name = '/device:GPU:0' 38 | else: 39 | device_name = '/cpu:0' 40 | 41 | # Note: there are a couple of biom -> pandas conversions taking 42 | # place here. This is currently done on purpose, since we 43 | # haven't figured out how to handle sparse matrix multiplication 44 | # in the context of this algorithm. That is a future consideration. 
45 | res = split_tables( 46 | microbes, metabolites, 47 | metadata=metadata, training_column=training_column, 48 | num_test=num_testing_examples, 49 | min_samples=min_feature_count) 50 | 51 | (train_microbes_df, test_microbes_df, 52 | train_metabolites_df, test_metabolites_df) = res 53 | 54 | train_microbes_coo = coo_matrix(train_microbes_df.values) 55 | test_microbes_coo = coo_matrix(test_microbes_df.values) 56 | 57 | with tf.Graph().as_default(), tf.Session() as session: 58 | model = MMvec( 59 | latent_dim=latent_dim, 60 | u_scale=input_prior, v_scale=output_prior, 61 | batch_size=batch_size, 62 | device_name=device_name, 63 | learning_rate=learning_rate) 64 | model(session, 65 | train_microbes_coo, train_metabolites_df.values, 66 | test_microbes_coo, test_metabolites_df.values) 67 | 68 | loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval) 69 | ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, 70 | columns=train_metabolites_df.columns) 71 | if latent_dim > 0: 72 | u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) 73 | else: 74 | # fake it until you make it 75 | u, s, v = svds(ranks - ranks.mean(axis=0), k=1) 76 | 77 | ranks = ranks.T 78 | ranks.index.name = 'featureid' 79 | s = s[::-1] 80 | u = u[:, ::-1] 81 | v = v[::-1, :] 82 | if equalize_biplot: 83 | microbe_embed = u @ np.sqrt(np.diag(s)) 84 | metabolite_embed = v.T @ np.sqrt(np.diag(s)) 85 | else: 86 | microbe_embed = u @ np.diag(s) 87 | metabolite_embed = v.T 88 | 89 | pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] 90 | features = pd.DataFrame( 91 | microbe_embed, columns=pc_ids, 92 | index=train_microbes_df.columns) 93 | samples = pd.DataFrame( 94 | metabolite_embed, columns=pc_ids, 95 | index=train_metabolites_df.columns) 96 | short_method_name = 'mmvec biplot' 97 | long_method_name = 'Multiomics mmvec biplot' 98 | eigvals = pd.Series(s, index=pc_ids) 99 | proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) 100 | biplot = 
OrdinationResults( 101 | short_method_name, long_method_name, eigvals, 102 | samples=samples, features=features, 103 | proportion_explained=proportion_explained) 104 | 105 | its = np.arange(len(loss)) 106 | convergence_stats = pd.DataFrame( 107 | { 108 | 'loss': loss, 109 | 'cross-validation': cv, 110 | 'iteration': its 111 | } 112 | ) 113 | 114 | convergence_stats.index.name = 'id' 115 | convergence_stats.index = convergence_stats.index.astype(np.str) 116 | 117 | c = convergence_stats['loss'].astype(np.float) 118 | convergence_stats['loss'] = c 119 | 120 | c = convergence_stats['cross-validation'].astype(np.float) 121 | convergence_stats['cross-validation'] = c 122 | 123 | c = convergence_stats['iteration'].astype(np.int) 124 | convergence_stats['iteration'] = c 125 | 126 | return ranks, biplot, qiime2.Metadata(convergence_stats) 127 | -------------------------------------------------------------------------------- /mmvec/q2/_stats.py: -------------------------------------------------------------------------------- 1 | from qiime2.plugin import SemanticType, model 2 | from q2_types.feature_data import FeatureData 3 | from q2_types.sample_data import SampleData 4 | 5 | 6 | Conditional = SemanticType('Conditional', 7 | variant_of=FeatureData.field['type']) 8 | 9 | 10 | class ConditionalFormat(model.TextFileFormat): 11 | def validate(*args): 12 | pass 13 | 14 | 15 | ConditionalDirFmt = model.SingleFileDirectoryFormat( 16 | 'ConditionalDirFmt', 'conditionals.tsv', ConditionalFormat) 17 | 18 | 19 | # songbird stats summarizing loss and cv error 20 | MMvecStats = SemanticType('MMvecStats', 21 | variant_of=SampleData.field['type']) 22 | 23 | 24 | class MMvecStatsFormat(model.TextFileFormat): 25 | def validate(*args): 26 | pass 27 | 28 | 29 | MMvecStatsDirFmt = model.SingleFileDirectoryFormat( 30 | 'MMvecStatsDirFmt', 'stats.tsv', MMvecStatsFormat) 31 | -------------------------------------------------------------------------------- /mmvec/q2/_summary.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import qiime2 3 | import pandas as pd 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def _convergence_plot(model, baseline, ax0, ax1): 9 | iterations = np.array(model['iteration']) 10 | cv_model = model.dropna() 11 | ax0.plot(cv_model['iteration'][1:], 12 | np.array(cv_model['cross-validation'].values)[1:], 13 | label='model') 14 | ax0.set_ylabel('Cross validation score', fontsize=14) 15 | ax0.set_xlabel('# Iterations', fontsize=14) 16 | 17 | ax1.plot(iterations[1:], 18 | np.array(model['loss'])[1:], label='model') 19 | ax1.set_ylabel('Loss', fontsize=14) 20 | ax1.set_xlabel('# Iterations', fontsize=14) 21 | 22 | if baseline is not None: 23 | iterations = baseline['iteration'] 24 | cv_baseline = baseline.dropna() 25 | ax0.plot(cv_baseline['iteration'][1:], 26 | np.array(cv_baseline['cross-validation'].values)[1:], 27 | label='baseline') 28 | ax0.set_ylabel('Cross validation score', fontsize=14) 29 | ax0.set_xlabel('# Iterations', fontsize=14) 30 | ax0.legend() 31 | 32 | ax1.plot(iterations[1:], 33 | np.array(baseline['loss'])[1:], label='baseline') 34 | ax1.set_ylabel('Loss', fontsize=14) 35 | ax1.set_xlabel('# Iterations', fontsize=14) 36 | ax1.legend() 37 | 38 | 39 | def _summarize(output_dir: str, model: pd.DataFrame, 40 | baseline: pd.DataFrame = None): 41 | 42 | """ Helper method for generating summary pages 43 | Parameters 44 | ---------- 45 | output_dir : str 46 | Name of output directory 47 | model : pd.DataFrame 48 | Model summary with column names 49 | ['loss', 'cross-validation'] 50 | baseline : pd.DataFrame 51 | Baseline model summary with column names 52 | ['loss', 'cross-validation']. Defaults to None (i.e. if only a single 53 | set of model stats will be summarized). 54 | Note 55 | ---- 56 | There may be synchronizing issues if different summary intervals 57 | were used between analyses. 
For predictable results, try to use the 58 | same summary interval. 59 | """ 60 | fig, ax = plt.subplots(2, 1, figsize=(10, 10)) 61 | if baseline is None: 62 | _convergence_plot(model, None, ax[0], ax[1]) 63 | q2 = None 64 | else: 65 | 66 | _convergence_plot(model, baseline, ax[0], ax[1]) 67 | 68 | # this provides a pseudo-r2 commonly provided in the context 69 | # of logistic / multinomail model (proposed by Cox & Snell) 70 | # http://www3.stat.sinica.edu.tw/statistica/oldpdf/a16n39.pdf 71 | end = min(10, len(model.index)) 72 | # trim only the last 10 numbers 73 | 74 | # compute a q2 score, which is commonly used in 75 | # partial least squares for cross validation 76 | cv_model = model.dropna() 77 | cv_baseline = baseline.dropna() 78 | 79 | l0 = np.mean(cv_baseline['cross-validation'][-end:]) 80 | lm = np.mean(cv_model['cross-validation'][-end:]) 81 | q2 = 1 - lm / l0 82 | 83 | plt.tight_layout() 84 | fig.savefig(os.path.join(output_dir, 'convergence-plot.svg')) 85 | fig.savefig(os.path.join(output_dir, 'convergence-plot.pdf')) 86 | 87 | index_fp = os.path.join(output_dir, 'index.html') 88 | with open(index_fp, 'w') as index_f: 89 | index_f.write('\n') 90 | index_f.write('

Convergence summary

\n') 91 | index_f.write( 92 | "

If you don't see anything in these plots, you probably need " 93 | "to decrease your --p-summary-interval. Try setting " 94 | "--p-summary-interval 1, which will record the loss at " 95 | "every second.

\n" 96 | ) 97 | 98 | if q2 is not None: 99 | index_f.write( 100 | '

' 101 | '' 102 | 'Pseudo Q-squared: %f

\n' % q2 103 | ) 104 | 105 | index_f.write( 106 | 'convergence_plots' 107 | ) 108 | index_f.write('') 109 | index_f.write('Download as PDF
\n') 110 | 111 | 112 | def summarize_single(output_dir: str, model_stats: qiime2.Metadata): 113 | _summarize(output_dir, model_stats.to_dataframe()) 114 | 115 | 116 | def summarize_paired(output_dir: str, 117 | model_stats: qiime2.Metadata, 118 | baseline_stats: qiime2.Metadata): 119 | _summarize(output_dir, 120 | model_stats.to_dataframe(), 121 | baseline_stats.to_dataframe()) 122 | -------------------------------------------------------------------------------- /mmvec/q2/_transformer.py: -------------------------------------------------------------------------------- 1 | import qiime2 2 | import pandas as pd 3 | 4 | from mmvec.q2 import ConditionalFormat, MMvecStatsFormat 5 | from mmvec.q2.plugin_setup import plugin 6 | 7 | 8 | @plugin.register_transformer 9 | def _1(ff: ConditionalFormat) -> pd.DataFrame: 10 | df = pd.read_csv(str(ff), sep='\t', comment='#', skip_blank_lines=True, 11 | header=0, index_col=0) 12 | return df 13 | 14 | 15 | @plugin.register_transformer 16 | def _2(df: pd.DataFrame) -> ConditionalFormat: 17 | ff = ConditionalFormat() 18 | df.to_csv(str(ff), sep='\t', header=True, index=True) 19 | return ff 20 | 21 | 22 | @plugin.register_transformer 23 | def _3(ff: ConditionalFormat) -> qiime2.Metadata: 24 | return qiime2.Metadata.load(str(ff)) 25 | 26 | 27 | @plugin.register_transformer 28 | def _4(obj: qiime2.Metadata) -> MMvecStatsFormat: 29 | ff = MMvecStatsFormat() 30 | obj.save(str(ff)) 31 | return ff 32 | 33 | 34 | @plugin.register_transformer 35 | def _5(ff: MMvecStatsFormat) -> qiime2.Metadata: 36 | return qiime2.Metadata.load(str(ff)) 37 | -------------------------------------------------------------------------------- /mmvec/q2/_visualizers.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import pandas as pd 3 | import qiime2 4 | import biom 5 | import pkg_resources 6 | import q2templates 7 | from mmvec.heatmap import ranks_heatmap, paired_heatmaps 8 | 9 | 10 | TEMPLATES 
= pkg_resources.resource_filename('mmvec.q2', 'assets') 11 | 12 | 13 | def heatmap(output_dir: str, 14 | ranks: pd.DataFrame, 15 | microbe_metadata: qiime2.CategoricalMetadataColumn = None, 16 | metabolite_metadata: qiime2.CategoricalMetadataColumn = None, 17 | method: str = 'average', 18 | metric: str = 'euclidean', 19 | color_palette: str = 'seismic', 20 | margin_palette: str = 'cubehelix', 21 | x_labels: bool = False, 22 | y_labels: bool = False, 23 | level: int = -1, 24 | row_center: bool = True) -> None: 25 | if microbe_metadata is not None: 26 | microbe_metadata = microbe_metadata.to_series() 27 | if metabolite_metadata is not None: 28 | metabolite_metadata = metabolite_metadata.to_series() 29 | ranks = ranks.T 30 | 31 | if row_center: 32 | ranks = ranks - ranks.mean(axis=0) 33 | 34 | hotmap = ranks_heatmap(ranks, microbe_metadata, metabolite_metadata, 35 | method, metric, color_palette, margin_palette, 36 | x_labels, y_labels, level) 37 | 38 | hotmap.savefig(join(output_dir, 'heatmap.pdf'), bbox_inches='tight') 39 | hotmap.savefig(join(output_dir, 'heatmap.png'), bbox_inches='tight') 40 | 41 | index = join(TEMPLATES, 'index.html') 42 | q2templates.render(index, output_dir, context={ 43 | 'title': 'Rank Heatmap', 44 | 'pdf_fp': 'heatmap.pdf', 45 | 'png_fp': 'heatmap.png'}) 46 | 47 | 48 | def paired_heatmap(output_dir: str, 49 | ranks: pd.DataFrame, 50 | microbes_table: biom.Table, 51 | metabolites_table: biom.Table, 52 | features: str = None, 53 | top_k_microbes: int = 2, 54 | keep_top_samples: bool = True, 55 | microbe_metadata: qiime2.CategoricalMetadataColumn = None, 56 | normalize: str = 'log10', 57 | color_palette: str = 'magma', 58 | top_k_metabolites: int = 50, 59 | level: int = -1, 60 | row_center: bool = True) -> None: 61 | if microbe_metadata is not None: 62 | microbe_metadata = microbe_metadata.to_series() 63 | 64 | ranks = ranks.T 65 | 66 | if row_center: 67 | ranks = ranks - ranks.mean(axis=0) 68 | 69 | select_microbes, select_metabolites, 
hotmaps = paired_heatmaps( 70 | ranks, microbes_table, metabolites_table, microbe_metadata, features, 71 | top_k_microbes, top_k_metabolites, keep_top_samples, level, normalize, 72 | color_palette) 73 | 74 | hotmaps.savefig(join(output_dir, 'heatmap.pdf'), bbox_inches='tight') 75 | hotmaps.savefig(join(output_dir, 'heatmap.png'), bbox_inches='tight') 76 | select_microbes.to_csv(join(output_dir, 'select_microbes.tsv'), sep='\t') 77 | select_metabolites.to_csv( 78 | join(output_dir, 'select_metabolites.tsv'), sep='\t') 79 | 80 | index = join(TEMPLATES, 'index.html') 81 | q2templates.render(index, output_dir, context={ 82 | 'title': 'Paired Feature Abundance Heatmaps', 83 | 'pdf_fp': 'heatmap.pdf', 84 | 'png_fp': 'heatmap.png', 85 | 'table1_fp': 'select_microbes.tsv', 86 | 'download1_text': 'Download microbe abundances as TSV', 87 | 'table2_fp': 'select_metabolites.tsv', 88 | 'download2_text': 'Download top k metabolite abundances as TSV'}) 89 | -------------------------------------------------------------------------------- /mmvec/q2/assets/index.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} 2 | 3 | {% block title %}rhapsody : {{ title }}{% endblock %} 4 | 5 | {% block fixed %}{% endblock %} 6 | 7 | {% block content %} 8 | 9 |
10 |

{{ title }}

11 | 26 |
27 | 28 | {% endblock %} 29 | -------------------------------------------------------------------------------- /mmvec/q2/plugin_setup.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright (c) 2016--, gneiss development team. 3 | # 4 | # Distributed under the terms of the Modified BSD License. 5 | # 6 | # The full license is in the file COPYING.txt, distributed with this software. 7 | # ---------------------------------------------------------------------------- 8 | import importlib 9 | import qiime2.plugin 10 | import qiime2.sdk 11 | from mmvec import __version__, _heatmap_choices, _cmaps 12 | from qiime2.plugin import (Str, Properties, Int, Float, Metadata, Bool, 13 | MetadataColumn, Categorical, Range, Choices, List) 14 | from q2_types.feature_table import FeatureTable, Frequency 15 | from q2_types.feature_data import FeatureData 16 | from q2_types.sample_data import SampleData 17 | from q2_types.ordination import PCoAResults 18 | from mmvec.q2 import ( 19 | Conditional, ConditionalFormat, ConditionalDirFmt, 20 | MMvecStats, MMvecStatsFormat, MMvecStatsDirFmt, 21 | paired_omics, heatmap, paired_heatmap, summarize_single, summarize_paired 22 | ) 23 | 24 | plugin = qiime2.plugin.Plugin( 25 | name='mmvec', 26 | version=__version__, 27 | website="https://github.com/biocore/mmvec", 28 | short_description='Plugin for performing microbe-metabolite ' 29 | 'co-occurence analysis.', 30 | description='This is a QIIME 2 plugin supporting microbe-metabolite ' 31 | 'co-occurence analysis using mmvec.', 32 | package='mmvec') 33 | 34 | plugin.methods.register_function( 35 | function=paired_omics, 36 | inputs={'microbes': FeatureTable[Frequency], 37 | 'metabolites': FeatureTable[Frequency]}, 38 | parameters={ 39 | 'metadata': Metadata, 40 | 'training_column': Str, 41 | 'num_testing_examples': Int, 42 | 'min_feature_count': Int, 43 | 'epochs': Int, 44 | 
'batch_size': Int, 45 | 'arm_the_gpu': Bool, 46 | 'latent_dim': Int, 47 | 'input_prior': Float, 48 | 'output_prior': Float, 49 | 'learning_rate': Float, 50 | 'equalize_biplot': Bool, 51 | 'summary_interval': Int 52 | }, 53 | outputs=[ 54 | ('conditionals', FeatureData[Conditional]), 55 | ('conditional_biplot', PCoAResults % Properties('biplot')), 56 | ('model_stats', SampleData[MMvecStats]), 57 | ], 58 | input_descriptions={ 59 | 'microbes': 'Input table of microbial counts.', 60 | 'metabolites': 'Input table of metabolite intensities.', 61 | }, 62 | output_descriptions={ 63 | 'conditionals': 'Mean-centered Conditional log-probabilities.', 64 | 'conditional_biplot': 'Biplot of microbe-metabolite vectors.', 65 | }, 66 | parameter_descriptions={ 67 | 'metadata': 'Sample metadata table with covariates of interest.', 68 | 'training_column': "The metadata column specifying which " 69 | "samples are for training/testing. " 70 | "Entries must be marked `Train` for training " 71 | "examples and `Test` for testing examples. ", 72 | 'num_testing_examples': "The number of random examples to select " 73 | "if `training_column` isn't specified.", 74 | 'epochs': 'The total number of iterations over the entire dataset.', 75 | 'equalize_biplot': 'Biplot arrows and points are on the same scale.', 76 | 'batch_size': 'The number of samples to be evaluated per ' 77 | 'training iteration.', 78 | 'arm_the_gpu': 'Specifies whether or not to use the GPU.', 79 | 'input_prior': 'Width of normal prior for the microbial ' 80 | 'coefficients. Smaller values will regularize ' 81 | 'parameters towards zero. Values must be greater ' 82 | 'than 0.', 83 | 'output_prior': 'Width of normal prior for the metabolite ' 84 | 'coefficients. Smaller values will regularize ' 85 | 'parameters towards zero. Values must be greater ' 86 | 'than 0.', 87 | 'learning_rate': 'Gradient descent decay rate.' 
88 | }, 89 | name='Microbe metabolite vectors', 90 | description="Performs bi-loglinear multinomial regression and calculates " 91 | "the conditional probability ranks of metabolite " 92 | "co-occurence given the microbe presence.", 93 | citations=[] 94 | ) 95 | 96 | plugin.visualizers.register_function( 97 | function=heatmap, 98 | inputs={'ranks': FeatureData[Conditional]}, 99 | parameters={ 100 | 'microbe_metadata': MetadataColumn[Categorical], 101 | 'metabolite_metadata': MetadataColumn[Categorical], 102 | 'method': Str % Choices(_heatmap_choices['method']), 103 | 'metric': Str % Choices(_heatmap_choices['metric']), 104 | 'color_palette': Str % Choices(_cmaps['heatmap']), 105 | 'margin_palette': Str % Choices(_cmaps['margins']), 106 | 'x_labels': Bool, 107 | 'y_labels': Bool, 108 | 'level': Int % Range(-1, None), 109 | 'row_center': Bool, 110 | }, 111 | input_descriptions={'ranks': 'Conditional probabilities.'}, 112 | parameter_descriptions={ 113 | 'microbe_metadata': 'Optional microbe metadata for annotating plots.', 114 | 'metabolite_metadata': 'Optional metabolite metadata for annotating ' 115 | 'plots.', 116 | 'method': 'Hierarchical clustering method used in clustermap.', 117 | 'metric': 'Distance metric used in clustermap.', 118 | 'color_palette': 'Color palette for clustermap.', 119 | 'margin_palette': 'Name of color palette to use for annotating ' 120 | 'metadata along margin(s) of clustermap.', 121 | 'x_labels': 'Plot x-axis (metabolite) labels?', 122 | 'y_labels': 'Plot y-axis (microbe) labels?', 123 | 'level': 'taxonomic level for annotating clustermap. Set to -1 if not ' 124 | 'parsing semicolon-delimited taxonomies or wish to print ' 125 | 'entire annotation.', 126 | 'row_center': 'Center conditional probability table ' 127 | 'around average row.' 
128 | }, 129 | name='Conditional probability heatmap', 130 | description="Generate heatmap depicting mmvec conditional probabilities.", 131 | citations=[] 132 | ) 133 | 134 | plugin.visualizers.register_function( 135 | function=paired_heatmap, 136 | inputs={'ranks': FeatureData[Conditional], 137 | 'microbes_table': FeatureTable[Frequency], 138 | 'metabolites_table': FeatureTable[Frequency]}, 139 | parameters={ 140 | 'microbe_metadata': MetadataColumn[Categorical], 141 | 'features': List[Str], 142 | 'top_k_microbes': Int % Range(0, None), 143 | 'color_palette': Str % Choices(_cmaps['heatmap']), 144 | 'normalize': Str % Choices(['log10', 'z_score_col', 'z_score_row', 145 | 'rel_row', 'rel_col', 'None']), 146 | 'top_k_metabolites': Int % Range(1, None) | Str % Choices(['all']), 147 | 'keep_top_samples': Bool, 148 | 'level': Int % Range(-1, None), 149 | 'row_center': Bool, 150 | }, 151 | input_descriptions={'ranks': 'Conditional probabilities.', 152 | 'microbes_table': 'Microbial feature abundances.', 153 | 'metabolites_table': 'Metabolite feature abundances.'}, 154 | parameter_descriptions={ 155 | 'microbe_metadata': 'Optional microbe metadata for annotating plots.', 156 | 'features': 'Microbial feature IDs to display in heatmap. Use this ' 157 | 'parameter to include named feature IDs in the heatmap. ' 158 | 'Can be used in conjunction with top_k_microbes, in which ' 159 | 'case named features will be displayed first, then top ' 160 | 'microbial features in order of log conditional ' 161 | 'probability maximum values.', 162 | 'top_k_microbes': 'Select top k microbes (those with the highest ' 163 | 'relative abundances) to display on the heatmap. 
' 164 | 'Set to "all" to display all metabolites.', 165 | 'color_palette': 'Color palette for clustermap.', 166 | 'normalize': 'Optionally normalize heatmap values by columns or rows.', 167 | 'top_k_metabolites': 'Select top k metabolites associated with each ' 168 | 'of the chosen features to display on heatmap.', 169 | 'keep_top_samples': 'Display only samples in which at least one of ' 170 | 'the selected microbes is the most abundant ' 171 | 'feature.', 172 | 'level': 'taxonomic level for annotating clustermap. Set to -1 if not ' 173 | 'parsing semicolon-delimited taxonomies or wish to print ' 174 | 'entire annotation.', 175 | 'row_center': 'Center conditional probability table ' 176 | 'around average row.' 177 | }, 178 | name='Paired feature abundance heatmaps', 179 | description="Generate paired heatmaps that depict microbial and " 180 | "metabolite feature abundances. The left panel displays the " 181 | "abundance of each selected microbial feature in each sample. " 182 | "The right panel displays the abundances of the top k " 183 | "metabolites most highly correlated with these microbes in " 184 | "each sample. The y-axis (sample axis) is shared between each " 185 | "panel.", 186 | citations=[] 187 | ) 188 | 189 | 190 | plugin.visualizers.register_function( 191 | function=summarize_single, 192 | inputs={ 193 | 'model_stats': SampleData[MMvecStats] 194 | }, 195 | parameters={}, 196 | input_descriptions={ 197 | 'model_stats': ( 198 | "Summary information produced by running " 199 | "`qiime mmvec paired-omics`." 200 | ) 201 | }, 202 | parameter_descriptions={ 203 | }, 204 | name='MMvec summary statistics', 205 | description=( 206 | "Visualize the convergence statistics from running " 207 | "`qiime mmvec paired-omics`, giving insight " 208 | "into how the model fit to your data." 
209 | ) 210 | ) 211 | 212 | plugin.visualizers.register_function( 213 | function=summarize_paired, 214 | inputs={ 215 | 'model_stats': SampleData[MMvecStats], 216 | 'baseline_stats': SampleData[MMvecStats] 217 | }, 218 | parameters={}, 219 | input_descriptions={ 220 | 221 | 'model_stats': ( 222 | "Summary information for the reference model, produced by running " 223 | "`qiime mmvec paired-omics`." 224 | ), 225 | 'baseline_stats': ( 226 | "Summary information for the baseline model, produced by running " 227 | "`qiime mmvec paired-omics`." 228 | ) 229 | 230 | }, 231 | parameter_descriptions={ 232 | }, 233 | name='Paired MMvec summary statistics', 234 | description=( 235 | "Visualize the convergence statistics from two MMvec models, " 236 | "giving insight into how the models fit to your data. " 237 | "The produced visualization includes a 'pseudo-Q-squared' value." 238 | ) 239 | ) 240 | 241 | # Register types 242 | plugin.register_formats(MMvecStatsFormat, MMvecStatsDirFmt) 243 | plugin.register_semantic_types(MMvecStats) 244 | plugin.register_semantic_type_to_format( 245 | SampleData[MMvecStats], MMvecStatsDirFmt) 246 | 247 | plugin.register_formats(ConditionalFormat, ConditionalDirFmt) 248 | plugin.register_semantic_types(Conditional) 249 | plugin.register_semantic_type_to_format( 250 | FeatureData[Conditional], ConditionalDirFmt) 251 | 252 | importlib.import_module('mmvec.q2._transformer') 253 | -------------------------------------------------------------------------------- /mmvec/q2/tests/test_method.py: -------------------------------------------------------------------------------- 1 | import biom 2 | import unittest 3 | import numpy as np 4 | import tensorflow as tf 5 | from mmvec.q2._method import paired_omics 6 | from mmvec.util import random_multimodal 7 | from skbio.stats.composition import clr_inv 8 | from scipy.stats import spearmanr 9 | import numpy.testing as npt 10 | 11 | 12 | class TestMMvec(unittest.TestCase): 13 | 14 | def setUp(self): 15 | 
np.random.seed(1) 16 | res = random_multimodal( 17 | num_microbes=8, num_metabolites=8, num_samples=150, 18 | latent_dim=2, sigmaQ=2, 19 | microbe_total=1000, metabolite_total=10000, seed=1 20 | ) 21 | (self.microbes, self.metabolites, self.X, self.B, 22 | self.U, self.Ubias, self.V, self.Vbias) = res 23 | n, d1 = self.microbes.shape 24 | n, d2 = self.metabolites.shape 25 | 26 | self.microbes = biom.Table(self.microbes.values.T, 27 | self.microbes.columns, 28 | self.microbes.index) 29 | self.metabolites = biom.Table(self.metabolites.values.T, 30 | self.metabolites.columns, 31 | self.metabolites.index) 32 | U_ = np.hstack( 33 | (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) 34 | V_ = np.vstack( 35 | (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) 36 | 37 | uv = U_ @ V_ 38 | h = np.zeros((d1, 1)) 39 | self.exp_ranks = clr_inv(np.hstack((h, uv))) 40 | 41 | def test_fit(self): 42 | np.random.seed(1) 43 | tf.reset_default_graph() 44 | tf.set_random_seed(0) 45 | latent_dim = 2 46 | res_ranks, res_biplot, _ = paired_omics( 47 | self.microbes, self.metabolites, 48 | epochs=1000, latent_dim=latent_dim, 49 | min_feature_count=1, learning_rate=0.1 50 | ) 51 | res_ranks = clr_inv(res_ranks.T) 52 | s_r, s_p = spearmanr(np.ravel(res_ranks), np.ravel(self.exp_ranks)) 53 | 54 | self.assertGreater(s_r, 0.5) 55 | self.assertLess(s_p, 1e-2) 56 | 57 | # make sure the biplot is of the correct dimensions 58 | npt.assert_allclose( 59 | res_biplot.samples.shape, 60 | np.array([self.microbes.shape[0], latent_dim])) 61 | npt.assert_allclose( 62 | res_biplot.features.shape, 63 | np.array([self.metabolites.shape[0], latent_dim])) 64 | 65 | # make sure that the biplot has the correct ordering 66 | self.assertGreater(res_biplot.proportion_explained[0], 67 | res_biplot.proportion_explained[1]) 68 | self.assertGreater(res_biplot.eigvals[0], 69 | res_biplot.eigvals[1]) 70 | 71 | def test_equalize_sv(self): 72 | np.random.seed(1) 73 | tf.reset_default_graph() 74 | tf.set_random_seed(0) 
75 | latent_dim = 2 76 | res_ranks, res_biplot, _ = paired_omics( 77 | self.microbes, self.metabolites, 78 | epochs=1000, latent_dim=latent_dim, 79 | min_feature_count=1, learning_rate=0.1, 80 | equalize_biplot=True 81 | ) 82 | # make sure the biplot is of the correct dimensions 83 | npt.assert_allclose( 84 | res_biplot.samples.shape, 85 | np.array([self.microbes.shape[0], latent_dim])) 86 | npt.assert_allclose( 87 | res_biplot.features.shape, 88 | np.array([self.metabolites.shape[0], latent_dim])) 89 | 90 | # make sure that the biplot has the correct ordering 91 | self.assertGreater(res_biplot.proportion_explained[0], 92 | res_biplot.proportion_explained[1]) 93 | self.assertGreater(res_biplot.eigvals[0], 94 | res_biplot.eigvals[1]) 95 | 96 | 97 | if __name__ == "__main__": 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /mmvec/q2/tests/test_visualizers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pandas as pd 3 | from qiime2 import Artifact, CategoricalMetadataColumn 4 | from qiime2.plugins import mmvec 5 | import biom 6 | import numpy as np 7 | 8 | 9 | # these tests just make sure the visualizer runs; nuts + bolts are tested in 10 | # the main package. 
11 | class TestHeatmap(unittest.TestCase): 12 | 13 | def setUp(self): 14 | _ranks = pd.DataFrame([[4.1, 1.3, 2.1], [0.1, 0.3, 0.2], 15 | [2.2, 4.3, 3.2], [-6.3, -4.4, 2.1]], 16 | index=pd.Index([c for c in 'ABCD'], name='id'), 17 | columns=['m1', 'm2', 'm3']).T 18 | self.ranks = Artifact.import_data('FeatureData[Conditional]', _ranks) 19 | self.taxa = CategoricalMetadataColumn(pd.Series([ 20 | 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' 21 | 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', 22 | 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', 23 | 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' 24 | 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', 25 | 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' 26 | 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina'], 27 | index=pd.Index([c for c in 'ABCD'], name='feature-id'), 28 | name='Taxon')) 29 | self.metabolites = CategoricalMetadataColumn(pd.Series([ 30 | 'amino acid', 'carbohydrate', 'drug metabolism'], 31 | index=pd.Index(['m1', 'm2', 'm3'], name='feature-id'), 32 | name='Super Pathway')) 33 | 34 | def test_heatmap_default(self): 35 | mmvec.actions.heatmap(self.ranks, self.taxa, self.metabolites) 36 | 37 | def test_heatmap_no_metadata(self): 38 | mmvec.actions.heatmap(self.ranks) 39 | 40 | def test_heatmap_one_metadata(self): 41 | mmvec.actions.heatmap(self.ranks, self.taxa, None) 42 | 43 | def test_heatmap_no_taxonomy_parsing(self): 44 | mmvec.actions.heatmap(self.ranks, self.taxa, None, level=-1) 45 | 46 | def test_heatmap_plot_axis_labels(self): 47 | mmvec.actions.heatmap(self.ranks, x_labels=True, y_labels=True) 48 | 49 | 50 | class TestPairedHeatmap(unittest.TestCase): 51 | 52 | def setUp(self): 53 | _ranks = pd.DataFrame([[4.1, 1.3, 2.1], [0.1, 0.3, 0.2], 54 | [2.2, 4.3, 3.2], [-6.3, -4.4, 2.1]], 55 | index=pd.Index([c for c in 'ABCD'], name='id'), 56 | columns=['m1', 'm2', 'm3']).T 57 | self.ranks = 
Artifact.import_data('FeatureData[Conditional]', _ranks) 58 | self.taxa = CategoricalMetadataColumn(pd.Series([ 59 | 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' 60 | 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', 61 | 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', 62 | 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' 63 | 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', 64 | 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' 65 | 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina'], 66 | index=pd.Index([c for c in 'ABCD'], name='feature-id'), 67 | name='Taxon')) 68 | metabolites = biom.Table( 69 | np.array([[9, 8, 2], [2, 1, 2], [9, 4, 5], [8, 8, 7]]), 70 | sample_ids=['s1', 's2', 's3'], 71 | observation_ids=['m1', 'm2', 'm3', 'm4']) 72 | self.metabolites = Artifact.import_data( 73 | 'FeatureTable[Frequency]', metabolites) 74 | microbes = biom.Table( 75 | np.array([[1, 2, 3], [3, 6, 3], [1, 9, 9], [8, 8, 7]]), 76 | sample_ids=['s1', 's2', 's3'], observation_ids=[i for i in 'ABCD']) 77 | self.microbes = Artifact.import_data( 78 | 'FeatureTable[Frequency]', microbes) 79 | 80 | def test_paired_heatmaps_single_feature(self): 81 | mmvec.actions.paired_heatmap( 82 | self.ranks, self.microbes, self.metabolites, features=['C'], 83 | microbe_metadata=self.taxa) 84 | 85 | def test_paired_heatmaps_multifeature(self): 86 | mmvec.actions.paired_heatmap( 87 | self.ranks, self.microbes, self.metabolites, features=['A', 'C']) 88 | 89 | def test_paired_heatmaps_fail_on_unknown_feature(self): 90 | with self.assertRaisesRegex(ValueError, "must represent feature IDs"): 91 | mmvec.actions.paired_heatmap( 92 | self.ranks, self.microbes, self.metabolites, 93 | features=['A', 'barf']) 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /mmvec/tests/data/ms_hits.txt: 
-------------------------------------------------------------------------------- 1 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 2 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 3 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 4 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 5 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 6 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 7 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 8 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 9 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 10 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 11 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 12 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 13 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 14 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 15 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 16 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 17 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 18 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 19 | 1.000000000000000056e-01 
9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 20 | 1.000000000000000056e-01 9.000000000000000222e-01 0.000000000000000000e+00 0.000000000000000000e+00 21 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 22 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 23 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 24 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 25 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 26 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 27 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 28 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 29 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 30 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 31 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 32 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 33 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 34 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 35 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 36 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 37 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 38 | 2.000000000000000111e-01 
8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 39 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 40 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 41 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 42 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 43 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 44 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 45 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 46 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 47 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 48 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 49 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 50 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 51 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 52 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 53 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 54 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 55 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 56 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 57 | 2.999999999999999889e-01 
6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 58 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 59 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 60 | 2.999999999999999889e-01 6.999999999999999556e-01 0.000000000000000000e+00 0.000000000000000000e+00 61 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 62 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 63 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 64 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 65 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 66 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 67 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 68 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 69 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 70 | 5.000000000000000000e-01 5.000000000000000000e-01 0.000000000000000000e+00 0.000000000000000000e+00 71 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 72 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 73 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 74 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 75 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 76 | 2.000000000000000111e-01 
5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 77 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 78 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 79 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 80 | 2.000000000000000111e-01 5.999999999999999778e-01 1.000000000000000056e-01 1.000000000000000056e-01 81 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 82 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 83 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 84 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 85 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 86 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 87 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 88 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 89 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 90 | 1.000000000000000056e-01 8.000000000000000444e-01 0.000000000000000000e+00 1.000000000000000056e-01 91 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 92 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 93 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 94 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 95 | 2.000000000000000111e-01 
8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 96 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 97 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 98 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 99 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 100 | 2.000000000000000111e-01 8.000000000000000444e-01 0.000000000000000000e+00 0.000000000000000000e+00 101 | -------------------------------------------------------------------------------- /mmvec/tests/data/otu_hits.txt: -------------------------------------------------------------------------------- 1 | 0.000000000000000000e+00 2 | 0.000000000000000000e+00 3 | 0.000000000000000000e+00 4 | 0.000000000000000000e+00 5 | 0.000000000000000000e+00 6 | 1.000000000000000000e+00 7 | 1.000000000000000000e+00 8 | 1.000000000000000000e+00 9 | 1.000000000000000000e+00 10 | 1.000000000000000000e+00 11 | 0.000000000000000000e+00 12 | 0.000000000000000000e+00 13 | 0.000000000000000000e+00 14 | 0.000000000000000000e+00 15 | 0.000000000000000000e+00 16 | 1.000000000000000000e+00 17 | 1.000000000000000000e+00 18 | 1.000000000000000000e+00 19 | 1.000000000000000000e+00 20 | 1.000000000000000000e+00 21 | 0.000000000000000000e+00 22 | 0.000000000000000000e+00 23 | 0.000000000000000000e+00 24 | 0.000000000000000000e+00 25 | 0.000000000000000000e+00 26 | 0.000000000000000000e+00 27 | 1.000000000000000000e+00 28 | 1.000000000000000000e+00 29 | 1.000000000000000000e+00 30 | 1.000000000000000000e+00 31 | 0.000000000000000000e+00 32 | 0.000000000000000000e+00 33 | 0.000000000000000000e+00 34 | 1.000000000000000000e+00 35 | 1.000000000000000000e+00 36 | 1.000000000000000000e+00 37 | 1.000000000000000000e+00 38 | 1.000000000000000000e+00 39 | 1.000000000000000000e+00 40 | 1.000000000000000000e+00 41 | 
0.000000000000000000e+00 42 | 0.000000000000000000e+00 43 | 0.000000000000000000e+00 44 | 1.000000000000000000e+00 45 | 1.000000000000000000e+00 46 | 1.000000000000000000e+00 47 | 1.000000000000000000e+00 48 | 1.000000000000000000e+00 49 | 1.000000000000000000e+00 50 | 1.000000000000000000e+00 51 | 0.000000000000000000e+00 52 | 0.000000000000000000e+00 53 | 0.000000000000000000e+00 54 | 0.000000000000000000e+00 55 | 1.000000000000000000e+00 56 | 1.000000000000000000e+00 57 | 1.000000000000000000e+00 58 | 1.000000000000000000e+00 59 | 1.000000000000000000e+00 60 | 1.000000000000000000e+00 61 | 0.000000000000000000e+00 62 | 0.000000000000000000e+00 63 | 0.000000000000000000e+00 64 | 0.000000000000000000e+00 65 | 1.000000000000000000e+00 66 | 1.000000000000000000e+00 67 | 1.000000000000000000e+00 68 | 1.000000000000000000e+00 69 | 1.000000000000000000e+00 70 | 1.000000000000000000e+00 71 | 0.000000000000000000e+00 72 | 0.000000000000000000e+00 73 | 0.000000000000000000e+00 74 | 0.000000000000000000e+00 75 | 0.000000000000000000e+00 76 | 0.000000000000000000e+00 77 | 0.000000000000000000e+00 78 | 1.000000000000000000e+00 79 | 1.000000000000000000e+00 80 | 1.000000000000000000e+00 81 | 0.000000000000000000e+00 82 | 0.000000000000000000e+00 83 | 0.000000000000000000e+00 84 | 0.000000000000000000e+00 85 | 0.000000000000000000e+00 86 | 0.000000000000000000e+00 87 | 1.000000000000000000e+00 88 | 1.000000000000000000e+00 89 | 1.000000000000000000e+00 90 | 1.000000000000000000e+00 91 | 0.000000000000000000e+00 92 | 0.000000000000000000e+00 93 | 0.000000000000000000e+00 94 | 0.000000000000000000e+00 95 | 0.000000000000000000e+00 96 | 0.000000000000000000e+00 97 | 0.000000000000000000e+00 98 | 1.000000000000000000e+00 99 | 1.000000000000000000e+00 100 | 1.000000000000000000e+00 101 | -------------------------------------------------------------------------------- /mmvec/tests/data/soil_metabolites.biom: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/mmvec/tests/data/soil_metabolites.biom -------------------------------------------------------------------------------- /mmvec/tests/data/soil_microbes.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/mmvec/tests/data/soil_microbes.biom -------------------------------------------------------------------------------- /mmvec/tests/data/x_test.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/mmvec/tests/data/x_test.biom -------------------------------------------------------------------------------- /mmvec/tests/data/x_train.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/mmvec/tests/data/x_train.biom -------------------------------------------------------------------------------- /mmvec/tests/data/y_test.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/mmvec/tests/data/y_test.biom -------------------------------------------------------------------------------- /mmvec/tests/data/y_train.biom: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biocore/mmvec/88ca33b408a85b6bf90fae06982936247b860272/mmvec/tests/data/y_train.biom -------------------------------------------------------------------------------- /mmvec/tests/test_heatmap.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 
import pandas as pd 3 | from mmvec.heatmap import ( 4 | _parse_taxonomy_strings, _parse_heatmap_metadata_annotations, 5 | _process_microbe_metadata, _process_metabolite_metadata, 6 | _normalize_table) 7 | import pandas.util.testing as pdt 8 | 9 | 10 | class TestParseTaxonomyStrings(unittest.TestCase): 11 | 12 | def setUp(self): 13 | self.taxa = pd.Series([ 14 | 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' 15 | 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', 16 | 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', 17 | 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' 18 | 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', 19 | 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' 20 | 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina', 21 | 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' 22 | 'o__Rickettsiales; f__mitochondria; g__Pavlova; s__lutheri', 23 | 'k__Archaea; p__[Parvarchaeota]; c__[Parvarchaea]; o__WCHD3-30', 24 | 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' 25 | 'o__Sphingomonadales; f__Sphingomonadaceae'], 26 | index=pd.Index([c for c in 'ABCDEFG'], name='feature-id'), 27 | name='Taxon') 28 | self.exp = pd.Series( 29 | ['s__', 'o__Streptophyta', 's__biternata', 'g__Methanosarcina', 30 | 's__lutheri', 'o__WCHD3-30', 'f__Sphingomonadaceae'], 31 | index=pd.Index([c for c in 'ABCDEFG'], name='feature-id'), 32 | name='Taxon') 33 | 34 | def test_parse_taxonomy_strings(self): 35 | exp = pd.Series(['p__Proteobacteria', 'p__Cyanobacteria', 36 | 'p__Proteobacteria', 'p__Euryarchaeota', 37 | 'p__Proteobacteria', 'p__[Parvarchaeota]', 38 | 'p__Proteobacteria'], 39 | index=pd.Index([c for c in 'ABCDEFG'], 40 | name='feature-id'), name='Taxon') 41 | obs = _parse_taxonomy_strings(self.taxa, level=2) 42 | pdt.assert_series_equal(exp, obs) 43 | 44 | def test_parse_taxonomy_strings_baserank(self): 45 | exp = pd.Series(['k__Bacteria', 'k__Bacteria', 'k__Bacteria', 46 | 
'k__Archaea', 'k__Bacteria', 'k__Archaea', 47 | 'k__Bacteria'], 48 | index=pd.Index([c for c in 'ABCDEFG'], 49 | name='feature-id'), name='Taxon') 50 | obs = _parse_taxonomy_strings(self.taxa, level=1) 51 | pdt.assert_series_equal(exp, obs) 52 | 53 | def test_parse_taxonomy_strings_toprank(self): 54 | # expect top rank even if level is higher than depth of top rank 55 | obs = _parse_taxonomy_strings(self.taxa, level=7) 56 | pdt.assert_series_equal(self.exp, obs) 57 | 58 | def test_parse_taxonomy_strings_rank_out_of_range_is_top(self): 59 | # expect top rank even if level is higher than depth of top rank 60 | obs = _parse_taxonomy_strings(self.taxa, level=9) 61 | pdt.assert_series_equal(self.exp, obs) 62 | 63 | 64 | class TestHeatmapAnnotation(unittest.TestCase): 65 | 66 | def setUp(self): 67 | self.taxonomy = pd.Series( 68 | ['k__Bacteria', 'k__Archaea', 'k__Bacteria', 'k__Archaea'], 69 | index=pd.Index([c for c in 'ABCD'], name='id'), name='Taxon') 70 | 71 | def test_parse_heatmap_metadata_annotations_colorhelix(self): 72 | exp_cols = pd.Series( 73 | [[0.8377187772618228, 0.7593149036488329, 0.9153517040128891], 74 | [0.2539759281991313, 0.3490084835469758, 0.14482988411775732], 75 | [0.8377187772618228, 0.7593149036488329, 0.9153517040128891], 76 | [0.2539759281991313, 0.3490084835469758, 0.14482988411775732]], 77 | index=pd.Index([c for c in 'ABCD'], name='id'), name='Taxon') 78 | exp_classes = {'k__Archaea': [0.2539759281991313, 0.3490084835469758, 79 | 0.14482988411775732], 80 | 'k__Bacteria': [0.8377187772618228, 0.7593149036488329, 81 | 0.9153517040128891]} 82 | cols, classes = _parse_heatmap_metadata_annotations( 83 | self.taxonomy, 'colorhelix') 84 | pdt.assert_series_equal(exp_cols, cols) 85 | self.assertDictEqual(exp_classes, classes) 86 | 87 | def test_parse_heatmap_metadata_annotations_magma(self): 88 | exp_cols = pd.Series( 89 | [(0.944006, 0.377643, 0.365136), (0.445163, 0.122724, 0.506901), 90 | (0.944006, 0.377643, 0.365136), (0.445163, 0.122724, 
0.506901)], 91 | index=pd.Index([c for c in 'ABCD'], name='id'), name='Taxon') 92 | exp_classes = {'k__Archaea': (0.445163, 0.122724, 0.506901), 93 | 'k__Bacteria': (0.944006, 0.377643, 0.365136)} 94 | cols, classes = _parse_heatmap_metadata_annotations( 95 | self.taxonomy, 'magma') 96 | pdt.assert_series_equal(exp_cols, cols) 97 | self.assertDictEqual(exp_classes, classes) 98 | 99 | 100 | class TestMetadataProcessing(unittest.TestCase): 101 | 102 | def setUp(self): 103 | self.taxonomy = pd.Series( 104 | ['k__Bacteria', 'k__Archaea', 'k__Bacteria'], 105 | index=pd.Index([c for c in 'ABC']), name='Taxon') 106 | self.metabolites = pd.Series([ 107 | 'amino acid', 'carbohydrate', 'drug metabolism'], 108 | index=pd.Index(['a', 'b', 'c']), name='Super Pathway') 109 | self.ranks = pd.DataFrame( 110 | [[4, 1, 2, 3], [1, 2, 1, 2], [2, 4, 3, 1], [6, 4, 2, 3]], 111 | index=pd.Index([c for c in 'ABCD']), columns=[c for c in 'abcd']) 112 | 113 | # test that metadata processing works, filters ranks, and works in sequence 114 | def test_process_metadata(self): 115 | # filter on taxonomy, taxonomy parser/annotation tested above 116 | with self.assertWarnsRegex(UserWarning, "microbe IDs are present"): 117 | res = _process_microbe_metadata( 118 | self.ranks, self.taxonomy, -1, 'magma') 119 | ranks_filtered = pd.DataFrame( 120 | [[4, 1, 2, 3], [1, 2, 1, 2], [2, 4, 3, 1]], 121 | index=pd.Index([c for c in 'ABC']), columns=[c for c in 'abcd']) 122 | pdt.assert_frame_equal(ranks_filtered, res[1]) 123 | # filter on metabolites, annotation tested above 124 | with self.assertWarnsRegex(UserWarning, "metabolite IDs are present"): 125 | res = _process_metabolite_metadata( 126 | ranks_filtered, self.metabolites, 'magma') 127 | ranks_filtered = ranks_filtered[[c for c in 'abc']] 128 | pdt.assert_frame_equal(ranks_filtered, res[1]) 129 | 130 | 131 | class TestNormalize(unittest.TestCase): 132 | 133 | def setUp(self): 134 | self.tab = pd.DataFrame({'a': [1, 2, 3], 'b': [3, 4, 3]}) 135 | 136 | 
def test_normalize_table_log10(self): 137 | res = _normalize_table(self.tab, 'log10') 138 | exp = pd.DataFrame( 139 | {'a': {0: 0.3010299956639812, 1: 0.47712125471966244, 140 | 2: 0.6020599913279624}, 141 | 'b': {0: 0.6020599913279624, 1: 0.6989700043360189, 142 | 2: 0.6020599913279624}}) 143 | pdt.assert_frame_equal(res, exp) 144 | 145 | def test_normalize_table_z_score_col(self): 146 | res = _normalize_table(self.tab, 'z_score_col') 147 | exp = pd.DataFrame({'a': {0: -1.0, 1: 0.0, 2: 1.0}, 148 | 'b': {0: -0.577350269189626, 1: 1.154700538379251, 149 | 2: -0.577350269189626}}) 150 | pdt.assert_frame_equal(res, exp) 151 | 152 | def test_normalize_table_rel_col(self): 153 | res = _normalize_table(self.tab, 'rel_col') 154 | exp = pd.DataFrame({'a': {0: 0.16666666666666666, 155 | 1: 0.3333333333333333, 2: 0.5}, 156 | 'b': {0: 0.3, 1: 0.4, 2: 0.3}}) 157 | pdt.assert_frame_equal(res, exp) 158 | 159 | def test_normalize_table_z_score_row(self): 160 | res = _normalize_table(self.tab, 'z_score_row') 161 | exp = pd.DataFrame({'a': {0: -0.7071067811865475, 162 | 1: -0.7071067811865475, 2: 0.0}, 163 | 'b': {0: 0.7071067811865475, 1: 0.7071067811865475, 164 | 2: 0.0}}) 165 | pdt.assert_frame_equal(res, exp) 166 | 167 | def test_normalize_table_rel_row(self): 168 | res = _normalize_table(self.tab, 'rel_row') 169 | exp = pd.DataFrame({'a': {0: 0.25, 1: 0.3333333333333333, 2: 0.5}, 170 | 'b': {0: 0.75, 1: 0.6666666666666666, 2: 0.5}}) 171 | pdt.assert_frame_equal(res, exp) 172 | -------------------------------------------------------------------------------- /mmvec/tests/test_multimodal.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import shutil 3 | import unittest 4 | import numpy as np 5 | import pandas as pd 6 | from biom import load_table 7 | from skbio.stats.composition import clr_inv as softmax 8 | from skbio.util import get_data_path 9 | from scipy.stats import spearmanr 10 | from scipy.sparse import coo_matrix 11 
| from scipy.spatial.distance import pdist 12 | from mmvec.multimodal import MMvec 13 | from mmvec.util import random_multimodal 14 | from tensorflow import set_random_seed 15 | import tensorflow as tf 16 | 17 | 18 | class TestMMvec(unittest.TestCase): 19 | def setUp(self): 20 | # build small simulation 21 | np.random.seed(1) 22 | res = random_multimodal( 23 | num_microbes=8, num_metabolites=8, num_samples=150, 24 | latent_dim=2, sigmaQ=2, 25 | microbe_total=1000, metabolite_total=10000, seed=1 26 | ) 27 | (self.microbes, self.metabolites, self.X, self.B, 28 | self.U, self.Ubias, self.V, self.Vbias) = res 29 | num_train = 10 30 | self.trainX = self.microbes.iloc[:-num_train] 31 | self.testX = self.microbes.iloc[-num_train:] 32 | self.trainY = self.metabolites.iloc[:-num_train] 33 | self.testY = self.metabolites.iloc[-num_train:] 34 | 35 | def tearDown(self): 36 | # remove all log directories 37 | for r in glob.glob("logdir*"): 38 | shutil.rmtree(r) 39 | 40 | def test_fit(self): 41 | np.random.seed(1) 42 | tf.reset_default_graph() 43 | n, d1 = self.trainX.shape 44 | n, d2 = self.trainY.shape 45 | with tf.Graph().as_default(), tf.Session() as session: 46 | set_random_seed(0) 47 | model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2) 48 | model(session, 49 | coo_matrix(self.trainX.values), self.trainY.values, 50 | coo_matrix(self.testX.values), self.testY.values) 51 | model.fit(epoch=1000) 52 | 53 | U_ = np.hstack( 54 | (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) 55 | V_ = np.vstack( 56 | (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) 57 | 58 | u_r, u_p = spearmanr(pdist(model.U), pdist(self.U)) 59 | v_r, v_p = spearmanr(pdist(model.V.T), pdist(self.V.T)) 60 | 61 | res = softmax(model.ranks()) 62 | exp = softmax(np.hstack((np.zeros((d1, 1)), U_ @ V_))) 63 | s_r, s_p = spearmanr(np.ravel(res), np.ravel(exp)) 64 | 65 | self.assertGreater(u_r, 0.5) 66 | self.assertGreater(v_r, 0.5) 67 | self.assertGreater(s_r, 0.5) 68 | self.assertLess(u_p, 5e-2) 69 | 
self.assertLess(v_p, 5e-2) 70 | self.assertLess(s_p, 5e-2) 71 | 72 | # sanity check cross validation 73 | self.assertLess(model.cv.eval(), 500) 74 | 75 | 76 | class TestMMvecSoilsBenchmark(unittest.TestCase): 77 | def setUp(self): 78 | self.microbes = load_table(get_data_path('soil_microbes.biom')) 79 | self.metabolites = load_table(get_data_path('soil_metabolites.biom')) 80 | X = self.microbes.to_dataframe().T 81 | Y = self.metabolites.to_dataframe().T 82 | X = X.loc[Y.index] 83 | self.trainX = X.iloc[:-2] 84 | self.trainY = Y.iloc[:-2] 85 | self.testX = X.iloc[-2:] 86 | self.testY = Y.iloc[-2:] 87 | 88 | def tearDown(self): 89 | # remove all log directories 90 | for r in glob.glob("logdir*"): 91 | shutil.rmtree(r) 92 | 93 | def test_soils(self): 94 | np.random.seed(1) 95 | tf.reset_default_graph() 96 | n, d1 = self.trainX.shape 97 | n, d2 = self.trainY.shape 98 | 99 | with tf.Graph().as_default(), tf.Session() as session: 100 | set_random_seed(0) 101 | model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=1, 102 | learning_rate=1e-3) 103 | model(session, 104 | coo_matrix(self.trainX.values), self.trainY.values, 105 | coo_matrix(self.testX.values), self.testY.values) 106 | model.fit(epoch=1000) 107 | 108 | ranks = pd.DataFrame( 109 | model.ranks(), 110 | index=self.microbes.ids(axis='observation'), 111 | columns=self.metabolites.ids(axis='observation')) 112 | 113 | microcoleus_metabolites = [ 114 | '(3-methyladenine)', '7-methyladenine', '4-guanidinobutanoate', 115 | 'uracil', 'xanthine', 'hypoxanthine', '(N6-acetyl-lysine)', 116 | 'cytosine', 'N-acetylornithine', 'N-acetylornithine', 117 | 'succinate', 'adenosine', 'guanine', 'adenine'] 118 | mprobs = ranks.loc['rplo 1 (Cyanobacteria)'] 119 | self.assertEqual(np.sum(mprobs.loc[microcoleus_metabolites] > 0), 120 | len(microcoleus_metabolites)) 121 | 122 | 123 | class TestMMvecBenchmark(unittest.TestCase): 124 | def setUp(self): 125 | # build small simulation 126 | res = random_multimodal( 127 | num_microbes=100, 
num_metabolites=1000, num_samples=300, 128 | latent_dim=2, sigmaQ=2, 129 | microbe_total=5000, metabolite_total=10000, seed=1 130 | ) 131 | (self.microbes, self.metabolites, self.X, self.B, 132 | self.U, self.Ubias, self.V, self.Vbias) = res 133 | num_train = 10 134 | self.trainX = self.microbes.iloc[:-num_train] 135 | self.testX = self.microbes.iloc[-num_train:] 136 | self.trainY = self.metabolites.iloc[:-num_train] 137 | self.testY = self.metabolites.iloc[-num_train:] 138 | 139 | @unittest.skip("Only for benchmarking") 140 | def test_gpu(self): 141 | np.random.seed(1) 142 | tf.reset_default_graph() 143 | n, d1 = self.trainX.shape 144 | n, d2 = self.trainY.shape 145 | 146 | with tf.Graph().as_default(), tf.Session() as session: 147 | set_random_seed(0) 148 | model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2, 149 | batch_size=2000, 150 | device_name="/device:GPU:0") 151 | model(session, 152 | coo_matrix(self.trainX.values), self.trainY.values, 153 | coo_matrix(self.testX.values), self.testY.values) 154 | model.fit(epoch=10000) 155 | 156 | @unittest.skip("Only for benchmarking") 157 | def test_cpu(self): 158 | print('CPU run') 159 | np.random.seed(1) 160 | tf.reset_default_graph() 161 | n, d1 = self.trainX.shape 162 | n, d2 = self.trainY.shape 163 | 164 | with tf.Graph().as_default(), tf.Session() as session: 165 | set_random_seed(0) 166 | model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2, 167 | batch_size=2000) 168 | model(session, 169 | coo_matrix(self.trainX.values), self.trainY.values, 170 | coo_matrix(self.testX.values), self.testY.values) 171 | model.fit(epoch=10000) 172 | 173 | 174 | if __name__ == "__main__": 175 | unittest.main() 176 | -------------------------------------------------------------------------------- /mmvec/tests/test_util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import pandas as pd 4 | from biom import Table 5 | from mmvec.util import rank_hits, 
split_tables 6 | import numpy.testing as npt 7 | import pandas.util.testing as pdt 8 | 9 | 10 | class TestRankHits(unittest.TestCase): 11 | 12 | def test_rank_hits(self): 13 | ranks = pd.DataFrame( 14 | [ 15 | [1., 4., 1., 5., 7.], 16 | [2., 6., 9., 2., 8.], 17 | [2., 2., 6., 8., 4.] 18 | ], 19 | index=['OTU_1', 'OTU_2', 'OTU_3'], 20 | columns=['MS_1', 'MS_2', 'MS_3', 'MS_4', 'MS_5'] 21 | ) 22 | res = rank_hits(ranks, k=2) 23 | exp = pd.DataFrame( 24 | [ 25 | ['OTU_1', 5., 'MS_4'], 26 | ['OTU_2', 8., 'MS_5'], 27 | ['OTU_3', 6., 'MS_3'], 28 | ['OTU_1', 7., 'MS_5'], 29 | ['OTU_2', 9., 'MS_3'], 30 | ['OTU_3', 8., 'MS_4'] 31 | ], columns=['src', 'rank', 'dest'], 32 | ) 33 | 34 | pdt.assert_frame_equal(res, exp) 35 | 36 | 37 | class TestSplitTables(unittest.TestCase): 38 | 39 | def setUp(self): 40 | 41 | omat = np.array([ 42 | [104, 10, 2, 0, 0], 43 | [4, 100, 20, 0, 0], 44 | [0, 1, 0, 0, 4], 45 | [4, 0, 21, 0, 2], 46 | [40, 0, 2, 1, 39], 47 | [0, 0, 32, 10, 3], 48 | [59, 1, 0, 0, 3] 49 | ]) 50 | mmat = np.array([ 51 | [104, 1, 31, 0, 8], 52 | [4, 100, 20, 0, 0], 53 | [0, 8, 0, 0, 4], 54 | [0, 0, 2, 1, 2], 55 | [0, 0, 20, 10, 3], 56 | [0, 8, 0, 0, 4], 57 | [0, 0, 2, 10, 3], 58 | [0, 0, 320, 139, 3], 59 | [59, 9, 0, 0, 33] 60 | ]) * 10e6 61 | 62 | oids = list(map(lambda x: 'o'+str(x), np.arange(omat.shape[0]))) 63 | mids = list(map(lambda x: 'm'+str(x), np.arange(mmat.shape[0]))) 64 | sids = list(map(lambda x: 'm'+str(x), np.arange(mmat.shape[1]))) 65 | 66 | self.otu_table = Table(omat, oids, sids) 67 | self.metabolite_table = Table(mmat, mids, sids) 68 | 69 | self.metadata = pd.DataFrame( 70 | { 71 | 'testing': ['Train', 'Test', 'Train', 'Test', 'Train'], 72 | 'bad': [True, False, True, False, True] 73 | }, index=sids 74 | ) 75 | 76 | def test_split_tables_train_column(self): 77 | 78 | res = split_tables(self.otu_table, self.metabolite_table, 79 | metadata=self.metadata, training_column='testing', 80 | num_test=10, min_samples=0) 81 | 82 | (train_microbes, 
test_microbes, 83 | train_metabolites, test_metabolites) = res 84 | 85 | npt.assert_allclose(train_microbes.shape, np.array([3, 7])) 86 | npt.assert_allclose(test_microbes.shape, np.array([2, 7])) 87 | npt.assert_allclose(train_metabolites.shape, np.array([3, 9])) 88 | npt.assert_allclose(test_metabolites.shape, np.array([2, 9])) 89 | 90 | def test_split_tables_bad_column(self): 91 | with self.assertRaises(Exception): 92 | split_tables(self.otu_table, self.metabolite_table, 93 | metadata=self.metadata, training_column='bad', 94 | num_test=10, min_samples=0) 95 | 96 | def test_split_tables_random(self): 97 | res = split_tables(self.otu_table, self.metabolite_table, 98 | num_test=2, min_samples=0) 99 | 100 | (train_microbes, test_microbes, 101 | train_metabolites, test_metabolites) = res 102 | npt.assert_allclose(train_microbes.shape, np.array([3, 7])) 103 | npt.assert_allclose(test_microbes.shape, np.array([2, 7])) 104 | npt.assert_allclose(train_metabolites.shape, np.array([3, 9])) 105 | npt.assert_allclose(test_metabolites.shape, np.array([2, 9])) 106 | 107 | def test_split_tables_random_filter(self): 108 | res = split_tables(self.otu_table, self.metabolite_table, 109 | num_test=2, min_samples=2) 110 | 111 | (train_microbes, test_microbes, 112 | train_metabolites, test_metabolites) = res 113 | npt.assert_allclose(train_microbes.shape, np.array([3, 7])) 114 | npt.assert_allclose(test_microbes.shape, np.array([2, 7])) 115 | npt.assert_allclose(train_metabolites.shape, np.array([3, 9])) 116 | npt.assert_allclose(test_metabolites.shape, np.array([2, 9])) 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /mmvec/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.utils import check_random_state 4 | from skbio.stats.composition import ilr_inv 5 | from 
skbio.stats.composition import clr_inv as softmax 6 | 7 | 8 | def random_multimodal(num_microbes=20, num_metabolites=100, num_samples=100, 9 | latent_dim=3, low=-1, high=1, 10 | microbe_total=10, metabolite_total=100, 11 | uB=0, sigmaB=2, sigmaQ=0.1, 12 | uU=0, sigmaU=1, uV=0, sigmaV=1, 13 | seed=0): 14 | """ 15 | Parameters 16 | ---------- 17 | num_microbes : int 18 | Number of microbial species to simulate 19 | num_metabolites : int 20 | Number of molecules to simulate 21 | num_samples : int 22 | Number of samples to generate 23 | latent_dim : 24 | Number of latent dimensions 25 | low : float 26 | Lower bound of gradient 27 | high : float 28 | Upper bound of gradient 29 | microbe_total : int 30 | Total number of microbial species 31 | metabolite_total : int 32 | Total number of metabolite species 33 | uB : float 34 | Mean of regression coefficient distribution 35 | sigmaB : float 36 | Standard deviation of regression coefficient distribution 37 | sigmaQ : float 38 | Standard deviation of error distribution 39 | uU : float 40 | Mean of microbial input projection coefficient distribution 41 | sigmaU : float 42 | Standard deviation of microbial input projection 43 | coefficient distribution 44 | uV : float 45 | Mean of metabolite output projection coefficient distribution 46 | sigmaV : float 47 | Standard deviation of metabolite output projection 48 | coefficient distribution 49 | seed : float 50 | Random seed 51 | 52 | Returns 53 | ------- 54 | microbe_counts : pd.DataFrame 55 | Count table of microbial counts 56 | metabolite_counts : pd.DataFrame 57 | Count table of metabolite counts 58 | """ 59 | state = check_random_state(seed) 60 | # only have two coefficients 61 | beta = state.normal(uB, sigmaB, size=(2, num_microbes)) 62 | 63 | X = np.vstack((np.ones(num_samples), 64 | np.linspace(low, high, num_samples))).T 65 | microbes = ilr_inv(state.multivariate_normal( 66 | mean=np.zeros(num_microbes-1), cov=np.diag([sigmaQ]*(num_microbes-1)), 67 | size=num_samples) 68 
| ) 69 | Umain = state.normal( 70 | uU, sigmaU, size=(num_microbes, latent_dim)) 71 | Vmain = state.normal( 72 | uV, sigmaV, size=(latent_dim, num_metabolites-1)) 73 | 74 | Ubias = state.normal( 75 | uU, sigmaU, size=(num_microbes, 1)) 76 | Vbias = state.normal( 77 | uV, sigmaV, size=(1, num_metabolites-1)) 78 | 79 | U_ = np.hstack( 80 | (np.ones((num_microbes, 1)), Ubias, Umain)) 81 | V_ = np.vstack( 82 | (Vbias, np.ones((1, num_metabolites-1)), Vmain)) 83 | 84 | phi = np.hstack((np.zeros((num_microbes, 1)), U_ @ V_)) 85 | probs = softmax(phi) 86 | microbe_counts = np.zeros((num_samples, num_microbes)) 87 | metabolite_counts = np.zeros((num_samples, num_metabolites)) 88 | n1 = microbe_total 89 | n2 = metabolite_total // microbe_total 90 | for n in range(num_samples): 91 | otu = state.multinomial(n1, microbes[n, :]) 92 | for i in range(num_microbes): 93 | ms = state.multinomial(otu[i] * n2, probs[i, :]) 94 | metabolite_counts[n, :] += ms 95 | microbe_counts[n, :] += otu 96 | 97 | otu_ids = ['OTU_%d' % d for d in range(microbe_counts.shape[1])] 98 | ms_ids = ['metabolite_%d' % d for d in range(metabolite_counts.shape[1])] 99 | sample_ids = ['sample_%d' % d for d in range(metabolite_counts.shape[0])] 100 | 101 | microbe_counts = pd.DataFrame( 102 | microbe_counts, index=sample_ids, columns=otu_ids) 103 | metabolite_counts = pd.DataFrame( 104 | metabolite_counts, index=sample_ids, columns=ms_ids) 105 | 106 | return (microbe_counts, metabolite_counts, X, beta, 107 | Umain, Ubias, Vmain, Vbias) 108 | 109 | 110 | def split_tables(otu_table, metabolite_table, 111 | metadata=None, training_column=None, num_test=10, 112 | min_samples=10): 113 | """ Splits otu and metabolite tables into training and testing datasets. 114 | 115 | Parameters 116 | ---------- 117 | otu_table : biom.Table 118 | Table of microbe abundances 119 | metabolite_table : biom.Table 120 | Table of metabolite intensities 121 | metadata : pd.DataFrame 122 | DataFrame of sample metadata information. 
This is primarily used 123 | to indicated training and testing samples 124 | training_column : str 125 | The column used to indicate training and testing samples. 126 | Samples labeled 'Train' are allocated to the training set. 127 | All other samples are placed in the testing dataset. 128 | num_test : int 129 | If metadata or training_column is not specified, then `num_test` 130 | indicates how many testing samples will be allocated for 131 | cross validation. 132 | min_samples : int 133 | The minimum number of samples a microbe needs to be observed in 134 | in order to not get filtered out 135 | Returns 136 | ------- 137 | train_microbes : pd.DataFrame 138 | Training set of microbes 139 | test_microbes : pd.DataFrame 140 | Testing set of microbes 141 | train_metabolites : pd.DataFrame 142 | Training set of metabolites 143 | test_metabolites : pd.DataFrame 144 | Testing set of metabolites 145 | Notes 146 | ----- 147 | There is an inefficient conversion from a sparse matrix to a 148 | dense matrix. This may become a bottleneck later. 
149 | """ 150 | microbes_df = otu_table.to_dataframe().T 151 | metabolites_df = metabolite_table.to_dataframe().T 152 | 153 | microbes_df, metabolites_df = microbes_df.align( 154 | metabolites_df, axis=0, join='inner' 155 | ) 156 | 157 | # filter out microbes that don't appear in many samples 158 | idx = (microbes_df > 0).sum(axis=0) >= min_samples 159 | microbes_df = microbes_df.loc[:, idx] 160 | if metadata is None or training_column is None: 161 | sample_ids = set(np.random.choice(microbes_df.index, size=num_test)) 162 | sample_ids = np.array([(x in sample_ids) for x in microbes_df.index]) 163 | else: 164 | if len(set(metadata[training_column]) & {'Train', 'Test'}) == 0: 165 | raise ValueError( 166 | "Training column must only specify `Train` and `Test` values" 167 | ) 168 | idx = metadata.loc[metadata[training_column] != 'Train'].index 169 | sample_ids = set(idx) 170 | sample_ids = np.array([(x in sample_ids) for x in microbes_df.index]) 171 | 172 | train_microbes = microbes_df.loc[~sample_ids] 173 | test_microbes = microbes_df.loc[sample_ids] 174 | train_metabolites = metabolites_df.loc[~sample_ids] 175 | test_metabolites = metabolites_df.loc[sample_ids] 176 | if len(train_microbes) == 0 or len(train_microbes.columns) == 0: 177 | raise ValueError('All of the training data has been filtered out. ' 178 | 'Adjust the `--min-feature-count` accordingly.') 179 | return train_microbes, test_microbes, train_metabolites, test_metabolites 180 | 181 | 182 | def rank_hits(ranks, k, pos=True): 183 | """ Creates an edge list based on rank matrix. 184 | 185 | Parameters 186 | ---------- 187 | ranks : pd.DataFrame 188 | Matrix of ranks (aka conditional probabilities) 189 | k : int 190 | Number of nearest neighbors 191 | pos : bool 192 | Specifies either most associated or least associated. 193 | This is a proxy to positively correlated or negatively correlated. 194 | 195 | Returns 196 | ------- 197 | edges : pd.DataFrame 198 | List of edges along with corresponding ranks. 
199 | """ 200 | axis = 1 201 | 202 | def sort_f(x): 203 | if pos: 204 | return [ 205 | ranks.columns[i] for i in np.argsort(x)[-k:] 206 | ] 207 | else: 208 | return [ 209 | ranks.columns[i] for i in np.argsort(x)[:k] 210 | ] 211 | 212 | idx = ranks.index 213 | topk = ranks.apply(sort_f, axis=axis).values 214 | topk = pd.DataFrame([x for x in topk], index=idx) 215 | top_hits = topk.reset_index() 216 | top_hits = top_hits.rename(columns={'index': 'src'}) 217 | edges = pd.melt( 218 | top_hits, id_vars=['src'], 219 | var_name='rank', 220 | value_vars=list(range(k)), 221 | value_name='dest') 222 | 223 | # fill in actual ranks 224 | for i in edges.index: 225 | src = edges.loc[i, 'src'] 226 | dest = edges.loc[i, 'dest'] 227 | edges.loc[i, 'rank'] = ranks.loc[src, dest] 228 | edges['rank'] = edges['rank'].astype(np.float64) 229 | return edges 230 | 231 | 232 | def format_params(vals, colnames, rownames, 233 | embed_name, index_name='feature_id'): 234 | """ Reformats the model parameters in a readable format 235 | 236 | Parameters 237 | ---------- 238 | vals : np.array 239 | Values of the model parameters 240 | colnames : array_like of str 241 | Column names corresponding to the features. 242 | These typically correspond to PC axis names. 243 | rownames : array_like of str 244 | Row names corresponding to the features. 245 | These typically correspond to microbe/metabolite names. 
246 | embed_name: str 247 | Specifies which embedding is being formatted 248 | index_name : str 249 | Specifies the index name, since it'll be formatted 250 | into a qiime2 Metadata format 251 | 252 | Returns 253 | ------- 254 | pd.DataFrame 255 | feature_id : str 256 | Feature names 257 | axis : str 258 | PC axis names 259 | embed_type: str 260 | Specifies which embedding is being formatted 261 | values : float 262 | Corresponding model parameters 263 | """ 264 | df = pd.DataFrame(vals, columns=colnames, index=rownames) 265 | df = df.reset_index() 266 | df = df.rename(columns={'index': 'feature_id'}) 267 | df = pd.melt(df, id_vars=['feature_id'], 268 | var_name='axis', value_name='values') 269 | 270 | df['embed_type'] = embed_name 271 | 272 | return df[['feature_id', 'axis', 'embed_type', 'values']] 273 | 274 | 275 | def embeddings2ranks(embeddings): 276 | """ Converts embeddings to ranks""" 277 | microbes = embeddings.loc[embeddings.embed_type == 'microbe'] 278 | metabolites = embeddings.loc[embeddings.embed_type == 'metabolite'] 279 | 280 | U = microbes.pivot(index='feature_id', columns='axis', values='values') 281 | V = metabolites.pivot(index='feature_id', columns='axis', values='values') 282 | pc_ids = sorted(list(set(U.columns) - {'bias'})) 283 | U['ones'] = 1 284 | V['ones'] = 1 285 | ranks = U[pc_ids + ['ones', 'bias']] @ V[pc_ids + ['bias', 'ones']].T 286 | # center each row 287 | ranks = ranks - ranks.mean(axis=1).values.reshape(-1, 1) 288 | return ranks 289 | 290 | 291 | def alr2clr(x): 292 | if x.ndim > 1: 293 | y = np.hstack((np.zeros((x.shape[1], 1)), x)) 294 | y = y - y.mean(axis=1).reshape(-1, 1) 295 | else: 296 | y = np.hstack((np.zeros(1), x)) 297 | y = y - y.mean() 298 | 299 | return y 300 | -------------------------------------------------------------------------------- /scripts/mmvec: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import time 4 | import click 5 | 
import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
from biom import load_table, Table
from biom.util import biom_open
from skbio import OrdinationResults
from skbio.stats.composition import clr, centralize, closure
from skbio.stats.composition import clr_inv as softmax
from scipy.stats import entropy, spearmanr
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import svds
import tensorflow as tf
from tensorflow.contrib.distributions import Multinomial, Normal
from mmvec.multimodal import MMvec
from mmvec.util import split_tables, format_params
import matplotlib.pyplot as plt


@click.group()
def mmvec():
    """MMvec command line interface."""
    pass


@mmvec.command()
@click.option('--microbe-file',
              help='Input microbial abundances')
@click.option('--metabolite-file',
              help='Input metabolite abundances')
@click.option('--metadata-file', default=None,
              help='Input sample metadata file')
@click.option('--training-column',
              help=('Column in the sample metadata specifying which '
                    'samples are for training and testing.'),
              default=None)
@click.option('--num-testing-examples',
              help=('Number of samples to randomly select for testing'),
              default=5)
@click.option('--min-feature-count',
              help=('Minimum number of samples a microbe needs to be observed '
                    'in order to not filter out'),
              default=10)
@click.option('--epochs',
              help='Number of epochs to train', default=10)
@click.option('--batch-size',
              help='Size of mini-batch', default=50)
@click.option('--latent-dim',
              help=('Dimensionality of shared latent space. '
                    'This is analogous to the number of PC axes.'),
              default=3)
@click.option('--input-prior',
              help=('Width of normal prior for input embedding. '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--output-prior',
              # Bugfix: the help text previously said "input embedding"
              # (copy-paste from --input-prior).
              help=('Width of normal prior for output embedding. '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--arm-the-gpu', is_flag=True,
              help=('Enables GPU support'),
              default=False)
@click.option('--learning-rate',
              help=('Gradient descent decay rate.'),
              default=1e-3)
@click.option('--beta1',
              help=('Gradient decay rate for first Adam momentum estimates'),
              default=0.9)
@click.option('--beta2',
              help=('Gradient decay rate for second Adam momentum estimates'),
              default=0.95)
@click.option('--clipnorm',
              help=('Gradient clipping size.'),
              default=10.)
@click.option('--checkpoint-interval',
              # Bugfix: the help text previously described storing a summary.
              help=('Number of seconds before storing a checkpoint.'),
              default=1000)
@click.option('--summary-interval',
              help=('Number of seconds before storing a summary.'),
              default=1)
@click.option('--summary-dir', default='summarydir',
              help='Summary directory to save cross validation results.')
@click.option('--embeddings-file', default=None,
              help=('Path to save the embeddings learned from the model. '
                    'If this is not specified, then this will be saved under '
                    '`--summary-dir`.'))
@click.option('--ranks-file', default=None,
              help=('Path to save the ranks learned from the model. '
                    'If this is not specified, then this will be saved under '
                    '`--summary-dir`.'))
@click.option('--ordination-file', default=None,
              help=('Path to save the ordination learned from the model. '
                    'If this is not specified, then this will be saved under '
                    '`--summary-dir`.'))
@click.option("--equalize-biplot", default=False, required=False, is_flag=True,
              help=('Equalize the norms of the singular '
                    'vectors of the conditional probability matrix.'))
def paired_omics(microbe_file, metabolite_file,
                 metadata_file, training_column,
                 num_testing_examples, min_feature_count,
                 epochs, batch_size, latent_dim,
                 input_prior, output_prior, arm_the_gpu,
                 learning_rate, beta1, beta2, clipnorm,
                 checkpoint_interval, summary_interval,
                 summary_dir, embeddings_file, ranks_file, ordination_file,
                 equalize_biplot):
    """Train MMvec on a paired microbe/metabolite dataset.

    Loads the two biom tables, fits the MMvec model, then writes three
    artifacts to disk: the learned embeddings, the conditional ranks
    matrix, and an ordination (biplot) derived from an SVD of the ranks.
    """
    microbes = load_table(microbe_file)
    metabolites = load_table(metabolite_file)

    if metadata_file is not None:
        metadata = pd.read_table(metadata_file, index_col=0)
    else:
        metadata = None

    # filter out low abundance microbes and split table
    res = split_tables(
        microbes, metabolites,
        metadata=metadata, training_column=training_column,
        num_test=num_testing_examples,
        min_samples=min_feature_count)

    (train_microbes_df, test_microbes_df,
     train_metabolites_df, test_metabolites_df) = res

    # Encode hyperparameters in the summary-file prefix so several runs
    # can share a single summary directory without clobbering each other.
    sname = 'latent_dim_' + str(latent_dim) + \
        '_input_prior_%.2f' % input_prior + \
        '_output_prior_%.2f' % output_prior + \
        '_beta1_%.2f' % beta1 + \
        '_beta2_%.2f' % beta2

    sname = os.path.join(summary_dir, sname)
    if embeddings_file is None:
        embeddings_file = sname + "_embedding.txt"
    if ranks_file is None:
        ranks_file = sname + "_ranks.txt"
    if ordination_file is None:
        ordination_file = sname + "_ordination.txt"

    n, d1 = microbes.shape
    n, d2 = metabolites.shape

    # Sparse representation of the microbe counts for minibatching.
    train_microbes_coo = coo_matrix(train_microbes_df.values)
    test_microbes_coo = coo_matrix(test_microbes_df.values)

    if arm_the_gpu:
        # pick out the first GPU
        device_name = '/device:GPU:0'
    else:
        device_name = '/cpu:0'

    config = tf.ConfigProto()
    with tf.Graph().as_default(), tf.Session(config=config) as session:
        model = MMvec(
            latent_dim=latent_dim,
            u_scale=input_prior, v_scale=output_prior,
            learning_rate=learning_rate,
            beta_1=beta1, beta_2=beta2,
            device_name=device_name,
            batch_size=batch_size,
            clipnorm=clipnorm, save_path=sname)

        model(session,
              train_microbes_coo, train_metabolites_df.values,
              test_microbes_coo, test_metabolites_df.values)

        loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval,
                             checkpoint_interval=checkpoint_interval)

        pc_ids = list(range(latent_dim))
        # Prepend a zero reference row/entry to V and its bias (the ALR
        # reference) so its dimensions match the microbe embedding U.
        vdim = model.V.shape[0]
        V = np.hstack((np.zeros((vdim, 1)), model.V))
        V = V.T
        Vbias = np.hstack((np.zeros(1), model.Vbias.ravel()))

        # Save to an embeddings file
        Uparam = format_params(
            model.U, pc_ids, list(train_microbes_df.columns), 'microbe')
        Vparam = format_params(
            V, pc_ids, list(train_metabolites_df.columns), 'metabolite')
        df = pd.concat(
            (
                Uparam, Vparam,
                format_params(model.Ubias, ['bias'],
                              train_microbes_df.columns, 'microbe'),
                format_params(Vbias, ['bias'],
                              train_metabolites_df.columns, 'metabolite')
            ), axis=0)

        df.to_csv(embeddings_file, sep='\t')

        # Save to a ranks file
        ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns,
                             columns=train_metabolites_df.columns)

        u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim)
        ranks = ranks.T
        ranks.index.name = 'featureid'
        ranks.to_csv(ranks_file, sep='\t')
        # Save to an ordination file
        # svds returns singular values in ascending order; reverse so the
        # leading axes come first.
        s = s[::-1]
        u = u[:, ::-1]
        v = v[::-1, :]
        if equalize_biplot:
            # Split the singular values evenly between the two sets of
            # vectors so microbes and metabolites plot on the same scale.
            microbe_embed = u @ np.sqrt(np.diag(s))
            metabolite_embed = v.T @ np.sqrt(np.diag(s))
        else:
            microbe_embed = u @ np.diag(s)
            metabolite_embed = v.T
        pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])]
        features = pd.DataFrame(
            microbe_embed, columns=pc_ids,
            index=train_microbes_df.columns)
        samples = pd.DataFrame(
            metabolite_embed, columns=pc_ids,
            index=train_metabolites_df.columns)
        short_method_name = 'mmvec biplot'
        long_method_name = 'Multiomics mmvec biplot'
        eigvals = pd.Series(s, index=pc_ids)
        proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids)
        biplot = OrdinationResults(
            short_method_name, long_method_name, eigvals,
            samples=samples, features=features,
            proportion_explained=proportion_explained)
        biplot.write(ordination_file)


if __name__ == '__main__':
    mmvec()
# ----------------------------------------------------------------------------
import re
import ast
from glob import glob
from setuptools import find_packages, setup

# PyPI trove classifiers; stripped and split into a list below.
classes = """
Development Status :: 5 - Production/Stable
License :: OSI Approved :: BSD License
Topic :: Software Development :: Libraries
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Bio-Informatics
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
Operating System :: Unix
Operating System :: POSIX
Operating System :: MacOS :: MacOS X
"""
classifiers = [s.strip() for s in classes.split('\n') if s]

description = ('Microbe-metabolite interactions through neural networks')

# The README doubles as the PyPI long description.
with open('README.md') as f:
    long_description = f.read()

# version parsing from __init__ pulled from Flask's setup.py
# https://github.com/mitsuhiko/flask/blob/master/setup.py
_version_re = re.compile(r'__version__\s+=\s+(.*)')
with open('mmvec/__init__.py', 'rb') as f:
    hit = _version_re.search(f.read().decode('utf-8')).group(1)
    # literal_eval unquotes the version string safely (no exec/eval).
    version = str(ast.literal_eval(hit))


setup(name='mmvec',
      version=version,
      license='BSD-3-Clause',
      description=description,
      long_description=long_description,
      long_description_content_type='text/markdown',
      author="gneiss development team",
      author_email="jamietmorton@gmail.com",
      maintainer="gneiss development team",
      maintainer_email="jamietmorton@gmail.com",
      packages=find_packages(),
      # Install the plain-script CLI entry point (scripts/mmvec).
      scripts=glob('scripts/mmvec'),
      install_requires=[
          'biom-format',
          'numpy >= 1.9.2',
          'pandas <= 0.25.3',
          'scipy >= 0.15.1',
          'nose >= 1.3.7',
          'scikit-bio >= 0.5.1',
          'seaborn',
          'tqdm',
          'tensorflow>=1.15,<2'
      ],
      classifiers=classifiers,
      # Register the QIIME 2 plugin so `qiime` discovers mmvec.
      entry_points={
          'qiime2.plugins':
          ['q2-mmvec=mmvec.q2.plugin_setup:plugin']
      },
      package_data={'mmvec': ['q2/assets/*.html']},
      zip_safe=False)