├── .gitmodules ├── benchmarks ├── sim5g.sh ├── simHC_edges.sh ├── simHC.sh ├── aale_loss_vae.sh ├── aale_loss_gcn3.sh ├── losscoef_vae.sh ├── noise_experiments.sh ├── test_graphmb.sh ├── strong100.sh ├── aale_gnn.sh ├── run_wwtp.sh ├── aale.sh ├── run_strong100.sh ├── aale_lr.sh ├── run_test.sh ├── runs.sh ├── aale_vae.sh └── aale_bs.sh ├── src └── graphmb │ ├── version.py │ ├── data │ ├── kernel.npz │ └── Bacteria.ms │ ├── __init__.py │ ├── arg_options.py │ ├── dgl_dataset.py │ ├── utils.py │ ├── train_gnn.py │ ├── visualize.py │ ├── unused │ ├── train_vae.py │ └── train_gnn_decode.py │ └── amber_eval.py ├── MANIFEST.in ├── .dockerignore ├── docs ├── source │ ├── modules.rst │ ├── _build │ │ └── html │ │ │ ├── objects.inv │ │ │ ├── _static │ │ │ ├── file.png │ │ │ ├── plus.png │ │ │ ├── minus.png │ │ │ ├── css │ │ │ │ ├── fonts │ │ │ │ │ ├── lato-bold.woff │ │ │ │ │ ├── lato-bold.woff2 │ │ │ │ │ ├── lato-normal.woff │ │ │ │ │ ├── lato-normal.woff2 │ │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ │ └── lato-normal-italic.woff2 │ │ │ │ └── badge_only.css │ │ │ ├── documentation_options.js │ │ │ ├── js │ │ │ │ ├── badge_only.js │ │ │ │ ├── html5shiv.min.js │ │ │ │ ├── html5shiv-printshiv.min.js │ │ │ │ └── theme.js │ │ │ ├── pygments.css │ │ │ ├── doctools.js │ │ │ └── language_data.js │ │ │ ├── .doctrees │ │ │ ├── index.doctree │ │ │ ├── intro.doctree │ │ │ ├── examples.doctree │ │ │ ├── graphmb.doctree │ │ │ ├── modules.doctree │ │ │ └── environment.pickle │ │ │ ├── .buildinfo │ │ │ └── searchindex.js │ ├── generated │ │ └── graphmb.rst │ ├── index.rst │ ├── intro.rst │ ├── graphmb.rst │ ├── conf.py │ ├── development.rst │ └── examples.rst ├── Makefile └── make.bat ├── results ├── graphmb 0.2 results.ods └── graphbemb experiments.ods ├── pyproject.toml ├── .gitignore ├── Dockerfile ├── .github └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── setup.py ├── LICENSE └── CHANGELOG.md /.gitmodules: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /benchmarks/sim5g.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | -------------------------------------------------------------------------------- /src/graphmb/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.6" 2 | 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/graphmb/data/Bacteria.ms 2 | include src/graphmb/data/kernel.npz -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | results/ 4 | venv/ 5 | graphmb.egg-info/ 6 | docs/ 7 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | graphmb 2 
| ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | graphmb 8 | -------------------------------------------------------------------------------- /src/graphmb/data/kernel.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/src/graphmb/data/kernel.npz -------------------------------------------------------------------------------- /results/graphmb 0.2 results.ods: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/results/graphmb 0.2 results.ods -------------------------------------------------------------------------------- /docs/source/_build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/objects.inv -------------------------------------------------------------------------------- /results/graphbemb experiments.ods: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/results/graphbemb experiments.ods -------------------------------------------------------------------------------- /docs/source/_build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/file.png -------------------------------------------------------------------------------- /docs/source/_build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/source/_build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/minus.png -------------------------------------------------------------------------------- /src/graphmb/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from . import contigsdataset 3 | from . import utils 4 | from . import evaluate 5 | from . 
import version 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools~=58.0", "pip>=19,!=20.0,!=20.0.1,<21", "wheel"] 4 | -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/.doctrees/index.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/intro.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/.doctrees/intro.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/examples.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/.doctrees/examples.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/graphmb.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/.doctrees/graphmb.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/modules.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/.doctrees/modules.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/.doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-normal.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- 
/docs/source/_build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicrobialDarkMatter/GraphMB/HEAD/docs/source/_build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/source/generated/graphmb.rst: -------------------------------------------------------------------------------- 1 | graphmb 2 | ======= 3 | 4 | .. automodule:: graphmb 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/source/_build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 1ea206695607a08625ce58412a77a94b 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /benchmarks/simHC_edges.sh: -------------------------------------------------------------------------------- 1 | python src/graphmb/main.py --cuda --assembly ../data/simHC/ --outdir results/simHC_edges/ --assembly_name assembly.fasta --depth abundance.tsv.edges_jgi --evalskip 0 --epoch 100 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 --outname ae_lr1e-3 --nruns 1 --embsize_gnn 32 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.edges 2 | *_out/ 3 | bins/ 4 | data/ 5 | *_logs/ 6 | embeddings/ 7 | graphs/ 8 | *.ckpt 9 | .vscode/ 10 | venv/ 11 | *.tsv 12 | *.csv 13 | *.html 14 | *.txt 15 | *.gfa 16 | *.pkl 17 | dist/ 18 | build/ 19 | .eggs/ 20 | results/ 21 | *.out 22 | *.png 23 | *.sbatch 24 | *.pyc 25 | PKG-INFO 26 | mlruns/ 27 | src/graphmb/unused/* 28 | src/graphmb/__pycache__/ 29 | *.log 30 | 31 | -------------------------------------------------------------------------------- /benchmarks/simHC.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # simHC experiments 3 | 4 | python src/graphmb/main.py --cuda --assembly ../data/simHC/ --outdir results/simHC/ --assembly_name contigs.fasta --depth abundance.tsv --contignodes --evalskip 0 --epoch 100 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 --outname ae_lr1e-3 --nruns 1 --labels amber_ground_truth.tsv --embsize_gnn 32 
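5 | 6 | # Editor's note, a hedged reading of the flags above, inferred from the option names and from how they vary across these benchmark scripts (not from the argument parser itself): 7 | # --gnn_alpha 0 --ae_alpha 1 --scg_alpha 0 with --layers_gnn 0 keeps only the autoencoder term of the gcn_ae loss active, which matches the "ae_" prefix of --outname. 8 | # simHC_edges.sh above appears to run the same setup on edge-level depths (abundance.tsv.edges_jgi) instead of --contignodes with abundance.tsv.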
-------------------------------------------------------------------------------- /docs/source/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: 'v0.1.2', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false 12 | }; -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | ARG DEBIAN_FRONTEND=noninteractive 3 | RUN apt-get update -y && apt-get install wget unzip vim -y 4 | RUN apt-get update -y && apt-get install -y python3 python3-pip python3-dev git && apt-get autoclean -y 5 | #RUN apt-get update && apt-get install sqlite3 libsqlite3-dev -y 6 | #RUN ln -s $(which pip3) /usr/bin/pip 7 | RUN pip install --upgrade pip 8 | 9 | #RUN make /app 10 | COPY ./ /graphmb/ 11 | #COPY ./data/strong100/ /graphmb/data/strong100/ 12 | WORKDIR /graphmb 13 | RUN python3 -m pip install -e . 14 | #CMD python /app/app.py 15 | -------------------------------------------------------------------------------- /benchmarks/aale_loss_vae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | for ga in 0 0.1 0.2 0.3 0.5 1 5 | do 6 | for sa in 0 0.1 0.2 0.3 0.5 1 7 | do 8 | 9 | #### VAE+GNN0 10 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 11 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha $ga \ 12 | --ae_alpha 1 --scg_alpha $sa --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 13 | --outname vaegcn_lr1e-2_e512_negs10_ga${ga}_sa${sa} --nruns 3 \ 14 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures 15 | 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /benchmarks/aale_loss_gcn3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | for ga in 0.1 5 | do 6 | for sa in 0.1 7 | do 8 | 9 | #### VAE+GNN0 10 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 11 | --epoch 1000 --model gcn_ae --batchsize 0 --gnn_alpha $ga \ 12 | --ae_alpha 1 --scg_alpha $sa --lr_gnn 1e-2 --layers_gnn 3 --negatives 10 \ 13 | --outname vaegcn3_lr1e-2_ga${ga}_sa${sa}_pv50_fv_gd_bs0 --nruns 3 \ 14 | --embsize_gnn 32 --skip_preclustering --quick --rawfeatures --concatfeatures \ 15 | --vaepretrain 50 --decoder_input gnn 16 | 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. GraphMB documentation master file, created by 2 | sphinx-quickstart on Tue Dec 14 10:18:04 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to GraphMB's documentation! 7 | =================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | intro 14 | examples 15 | development 16 | graphmb 17 | 18 | 19 | 20 | Indices and tables 21 | ================== 22 | 23 | * :ref:`genindex` 24 | * :ref:`modindex` 25 | * :ref:`search` 26 | -------------------------------------------------------------------------------- /benchmarks/losscoef_vae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wwtp=$1 4 | bs=128 5 | emb=32 6 | lr=1e-3 7 | quick= 8 | #quick="--quick" 9 | for ga in 0 0.1 0.2 0.3 0.5 1 10 | do 11 | for sa in 0 0.1 0.2 0.3 0.5 1 12 | do 13 | 14 | #### VAE+GNN0 15 | python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 16 | --epoch 1000 --model gcn_ae --batchsize ${bs} --gnn_alpha $ga \ 17 | --ae_alpha 1 --scg_alpha $sa --lr_gnn ${lr} --layers_gnn 0 --negatives 10 \ 18 | --outname vaegcn_lr${lr}_e${bs}_negs10_ga${ga}_sa${sa} --nruns 3 \ 19 | --embsize_gnn ${emb} --skip_preclustering --rawfeatures ${quick} 20 | 21 | done 22 | done 23 | -------------------------------------------------------------------------------- /benchmarks/noise_experiments.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | python src/graphmb/main.py --assembly ../data/strong100/ --outdir results/simdata/ --evalskip 200 --epoch 1000 \ 5 | --model gcn_ae --rawfeatures --gnn_alpha 1 --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-3 \ 6 | --layers_gnn 1 --read_cache --markers "" --scg_alpha 0 --noise 7 | 8 | python src/graphmb/main.py --assembly ../data/strong100/ --outdir results/strong100/ --evalskip 200 --epoch 1000 \ 9 | --model gcn_ae --rawfeatures --gnn_alpha 1 --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-3 \ 10 | --layers_gnn 1 --scg_alpha 0 --noise -------------------------------------------------------------------------------- /benchmarks/test_graphmb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATADIR=data/ 3 | 4 | #check data 5 | if [ -d $DATADIR/strong100 ] 6 | then 7 | echo "Strong100 dataset found" 8 | else 9 | echo "Error: dataset not found, downloading" 10 | cd $DATADIR; wget https://zenodo.org/record/6122610/files/strong100.zip; unzip strong100.zip 11 | fi 12 | 13 | # check venv 14 | if [ -d "./venv/" ] 15 | then 16 | echo "venv found" 17 | source venv/bin/activate 18 | else 19 | echo "venv not found" 20 | python -m venv venv 21 | source venv/bin/activate 22 | pip install -e . 23 | fi 24 | 25 | python src/graphmb/main.py --assembly $DATADIR/strong100/ --outdir results/strong100/ --markers marker_gene_stats.tsv -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is, the command you used and error you obtained. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Environment used 16 | 2. Type of data 17 | 3. How GraphMB was installed 18 | 4. Options used 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Output messages** 24 | If applicable, add your output messages to help explain your problem. 25 | 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 29 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from setuptools.command.install import install 3 | import os 4 | import subprocess 5 | from distutils.util import convert_path 6 | 7 | main_ns = {} 8 | ver_path = convert_path("src/graphmb/version.py") 9 | with open(ver_path) as ver_file: 10 | exec(ver_file.read(), main_ns) 11 | 12 | setup( 13 | name="graphmb", 14 | version=main_ns["__version__"], 15 | packages=["graphmb"], 16 | python_requires=">=3.8", 17 | package_dir={"": "src"}, 18 | setup_requires=["setuptools~=58.0", "wheel", "sphinx-rtd-theme", "twine"], 19 | install_requires=[ 20 | "wheel", 21 | "requests", 22 | "networkx==2.6.2", 23 | "torch==1.13.1", 24 | "tensorflow==2.11.1", 25 | "tqdm==4.61.2", 26 | "mlflow==2.6.0", 27 | "importlib_resources" 28 | 29 | ], 30 | entry_points={ 31 | "console_scripts": ["graphmb=graphmb.main:main"], 32 | }, 33 | include_package_data=True, 34 | ) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
-------------------------------------------------------------------------------- /src/graphmb/data/Bacteria.ms: -------------------------------------------------------------------------------- 1 | # [Taxon Marker File] 2 | Bacteria 1 0 Bacteria 5449 [{'PF01193.19', 'PF00411.14', 'PF00416.17', 'PF01000.21', 'PF01196.14'}, {'PF00281.14', 'PF00861.17', 'PF03719.10', 'PF00828.14', 'PF00673.16', 'PF00347.18', 'PF00238.14', 'PF00333.15', 'TIGR01079', 'TIGR00967', 'PF00410.14'}, {'PF00573.17', 'PF00181.18', 'PF00366.15', 'PF00831.18', 'PF00276.15', 'PF00203.16', 'PF03947.13', 'PF00189.15', 'PF00297.17', 'PF00252.13', 'PF00237.14'}, {'PF05000.12', 'PF04998.12', 'PF00562.23', 'PF04560.15', 'PF10385.4', 'PF04997.7', 'PF04983.13', 'PF04565.11', 'PF04563.10', 'PF04561.9', 'PF00623.15'}, {'PF00572.13', 'PF00380.14'}, {'PF00298.14', 'PF00687.16', 'PF03946.9'}, {'PF01281.14', 'PF03948.9'}, {'PF08529.6', 'PF13184.1'}, {'PF00453.13', 'PF01632.14'}, {'PF00164.20', 'PF00177.16'}, {'TIGR00855', 'PF00466.15'}, {'PF02912.13', 'PF01409.15'}, {'PF00889.14', 'PF00318.15'}, {'PF00829.16', 'PF01016.14'}, {'TIGR03723', 'TIGR00329'}, {'PF01668.13'}, {'PF01250.12'}, {'PF00312.17'}, {'PF01121.15'}, {'TIGR00459'}, {'PF01245.15'}, {'TIGR00755'}, {'PF02130.12'}, {'PF02367.12'}, {'TIGR03594'}, {'PF02033.13'}, {'TIGR00615'}, {'TIGR00084'}, {'PF01018.17'}, {'PF01195.14'}, {'TIGR00019'}, {'PF01649.13'}, {'PF01795.14'}, {'TIGR00250'}, {'PF00886.14'}, {'PF06421.7'}, {'PF11987.3'}, {'PF00338.17'}, {'TIGR00392'}, {'PF01509.13'}, {'PF01746.16'}, {'PF06071.8'}, {'PF05697.8'}, {'TIGR00922'}, {'PF02978.14'}, {'PF03484.10'}, {'TIGR02075'}, {'TIGR00810'}, {'PF13603.1'}, {'PF01765.14'}, {'PF00162.14'}, {'PF12344.3'}, {'TIGR02432'}, {'TIGR00460'}, {'PF05491.8'}, {'TIGR03263'}, {'PF08459.6'}, {'TIGR00344'}] 3 | -------------------------------------------------------------------------------- /benchmarks/strong100.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | python src/graphmb/main.py --assembly data/strong100/ --outdir results/strong100/ --evalskip 100 \ 5 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 6 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 0 \ 7 | --outname vae_lr1e-3_bs256 --nruns 3 --labels amber_ground_truth.tsv \ 8 | --embsize_gnn 32 --quick 9 | 10 | 11 | python src/graphmb/main.py --assembly data/strong100/ --outdir results/strong100/ --evalskip 100 \ 12 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 13 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-4 --layers_gnn 0 \ 14 | --outname vae_lr1e-4_bs256 --nruns 3 --labels amber_ground_truth.tsv \ 15 | --embsize_gnn 32 --skip_preclustering --quick 16 | 17 | 18 | python src/graphmb/main.py --assembly data/strong100/ --outdir results/strong100/ --evalskip 100 \ 19 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0.1 \ 20 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-4 --layers_gnn 0 --negatives 5 \ 21 | --outname vaegcn_lr1e-4_edgesbatch256_negs5 --nruns 3 --labels amber_ground_truth.tsv \ 22 | --embsize_gnn 32 --batchtype edges --skip_preclustering --quick 23 | 24 | 25 | python src/graphmb/main.py --assembly data/strong100/ --outdir results/strong100/ --evalskip 100 \ 26 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0.1 \ 27 | --ae_alpha 1 --scg_alpha 0.3 --lr_gnn 1e-4 --layers_gnn 0 --negatives 5 \ 28 | --outname vaegcn_lr1e-4_bs256_negs5_scg --nruns 3 --labels amber_ground_truth.tsv \ 29 | 
--embsize_gnn 32 --skip_preclustering --quick -------------------------------------------------------------------------------- /benchmarks/aale_gnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### baseline VAE lr 1e-3 4 | #python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 5 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 6 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 0 \ 7 | # --outname vae_lr1e-3_nodesbatch256 --nruns 3 --labels amber_ground_truth_species.tsv \ 8 | # --embsize_gnn 32 --batchtype nodes 9 | 10 | 11 | 12 | #### VAE+GNN (1-2 layers) 13 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 14 | --epoch 1000 --model gcn_ae --batchsize 256 --gnn_alpha 0.1 \ 15 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 1 --negatives 10 \ 16 | --outname vaegcn1_lr1e-3_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 17 | --embsize_gnn 64 --skip_preclustering --quick 18 | 19 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 20 | --epoch 1000 --model gcn_ae --batchsize 256 --gnn_alpha 0.1 \ 21 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 2 --negatives 10 \ 22 | --outname vaegcn2_lr1e-3_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 23 | --embsize_gnn 64 --skip_preclustering --quick 24 | 25 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 26 | --epoch 1000 --model gcn_ae --batchsize 256 --gnn_alpha 0.1 \ 27 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 2 --negatives 10 \ 28 | --outname vaegcn2_lr1e-3_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 29 | --embsize_gnn 64 --skip_preclustering --quick -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | # Change Log 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [this sample changelog](https://gist.github.com/juampynr/4c18214a8eb554084e21d6e288a18a2c). 6 | 7 | ## [0.2.5] - 2023-07-18 8 | - Update MLflow version 9 | - Fix data dir missing from src/graphmb 10 | - Add option to write bins to fasta files (--writebins) 11 | 12 | ## [0.2.4] - 2023-03-31 13 | - Update TensorFlow and MLflow versions 14 | 15 | 16 | ## [0.2.3] - 2023-02-03 17 | 18 | ### Changed 19 | - vaepretrain parameter controls the number of epochs of VAE pre-training (default 500) 20 | 21 | ## [0.2.2] - 2023-02-02 22 | 23 | ### Fixed 24 | - Correct wheel file 25 | 26 | ## [0.2.0] - 2023-02-01 27 | 28 | ### Added 29 | - VAE, GCN, SAGE and GAT models based on TensorFlow (VAEG code) 30 | - SCG-based loss to train VAE and GNNs 31 | - Output assembly stats while starting 32 | - Eliminate VAMB and DGL dependencies 33 | - PyPI installation 34 | 35 | ### Changed 36 | - Code structure changed to load data outside of DGL and use DGL only for the GraphSAGE-LSTM model 37 | - Log dataloading steps 38 | - Write cache to numpy files 39 | 40 | ### Fixed 41 | - Feature files are written to specific directories (fixes #17) 42 | 43 | ## [0.1.3] - 2022-02-25 44 | 45 | bioRxiv version 46 | 47 | `pip install . 
--upgrade` 48 | 49 | ### Added 50 | - Dockerfile and docker image link 51 | - Set seed option 52 | - Eval interval option 53 | 54 | ### Changed 55 | 56 | - Change default file name 57 | 58 | 59 | ### Fixed 60 | 61 | - Assembly dir option is no longer mandatory, so files can be in different directories 62 | - Logging also includes errors 63 | - DGL should no longer write a file to ~/ 64 | 65 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | GraphMB is a metagenomic binner developed for long-read assemblies that takes advantage of graph machine learning 5 | algorithms and the assembly graph generated during assembly. 6 | It has been tested on (meta)Flye assemblies. 7 | 8 | Installation 9 | ************ 10 | 11 | Option 1 - From wheel:: 12 | 13 | pip install https://github.com/AndreLamurias/GraphMB/releases/download/v0.1.2/graphmb-0.1.2-py3-none-any.whl 14 | 15 | Option 2 - From source:: 16 | 17 | git clone https://github.com/AndreLamurias/GraphMB 18 | cd GraphMB 19 | python -m venv venv; source venv/bin/activate # optional 20 | pip install . 21 | 22 | Option 3 - From anaconda:: 23 | 24 | conda install -c andrelamurias graphmb 25 | 26 | Option 4 - From pip:: 27 | 28 | pip install graphmb 29 | 30 | 31 | Input files 32 | *********** 33 | The only files required are the contigs in FASTA format and the assembly graph in GFA format. For optimal performance, 34 | the assembly graph should be generated with Flye 2.9, since it includes the number of reads mapping to each pair of 35 | contigs. For better results, CheckM is also run on each contig using the general Bacteria marker sets. This step is 36 | optional: without it, you can run the model for a fixed number of epochs and pick the last model. 37 | By default, GraphMB runs with early stopping. 38 | 39 | In summary, you need a directory with these files: 40 | - edges.fasta 41 | - assembly_graph.gfa 42 | - edges_depth.txt (output of `jgi_summarize_bam_contig_depths`) 43 | - marker_gene_stats.tsv (optional) 44 | 45 | You can get an example of these files at https://drive.google.com/drive/folders/1m6uTgTPUghk_q9GxfX1UNEOfn8jnIdt5?usp=sharing 46 | Download the archive from that link and extract it to data/strong100. 47 | 48 | -------------------------------------------------------------------------------- /docs/source/graphmb.rst: -------------------------------------------------------------------------------- 1 | graphmb package 2 | =============== 3 | 4 | Submodules 5 | ---------- 6 | 7 | graphmb.contigsdataset module 8 | ----------------------------- 9 | 10 | .. automodule:: graphmb.contigsdataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | graphmb.evaluate module 16 | ----------------------- 17 | 18 | .. automodule:: graphmb.evaluate 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | graphmb.utils module 24 | ------------------------------- 25 | 26 | .. automodule:: graphmb.utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | graphmb.models module 32 | ------------------------------- 33 | 34 | .. automodule:: graphmb.models 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | graphmb.gnn_models module 40 | ------------------------------- 41 | 42 | .. 
automodule:: graphmb.gnn_models 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | graphmb.train_ccvae module 48 | ------------------------------- 49 | 50 | .. automodule:: graphmb.train_ccvae 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | graphmb.train_gnn module 56 | ------------------------------- 57 | 58 | .. automodule:: graphmb.train_gnn 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | graphmb.graphsage\_unsupervised module 64 | -------------------------------------- 65 | 66 | .. automodule:: graphmb.graphsage_unsupervised 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | graphmb.main module 72 | ------------------- 73 | 74 | .. automodule:: graphmb.main 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: graphmb 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /benchmarks/run_wwtp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wwtp=$1 3 | 4 | #### baseline VAE lr 1e-2 5 | #python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 6 | # --epoch 500 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 7 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 8 | # --outname vae_lr1e-2_nb256 --nruns 3 \ 9 | # --embsize_gnn 64 --quick --batchtype nodes 10 | 11 | 12 | #### baseline VAE lr 1e-2 13 | #python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 14 | # --epoch 500 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 15 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 0 \ 16 | # --outname vae_lr1e-3_nb256 --nruns 3 \ 17 | # --embsize_gnn 64 --quick --batchtype nodes 18 | 19 | #### baseline VAE lr 1e-4 20 | #python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 21 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 22 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-4 --layers_gnn 0 \ 23 | # --outname vae_lr1e-4_nb256 --nruns 3 \ 24 | # --skip_preclustering --embsize_gnn 64 --quick --batchtype nodes 25 | 26 | 27 | #### VAE+GNN0 28 | python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 29 | --epoch 2000 --evalepoch 20 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0.1 \ 30 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 \ 31 | --outname vaegcn_lr1e-3_eb256_negs10_gnn0.1 --nruns 3 \ 32 | --skip_preclustering --embsize_gnn 64 --quick 33 | 34 | #### VAE+GNN0+SCG 35 | python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 10 \ 36 | --epoch 2000 --evalepoch 20 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0.1 \ 37 | --ae_alpha 1 --scg_alpha 0.3 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 \ 38 | --outname vaegcn_lr1e-3_eb256_negs10_scg0.1_gnn0.3 --nruns 3 \ 39 | --skip_preclustering --embsize_gnn 64 --quick 40 | -------------------------------------------------------------------------------- /benchmarks/aale.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### baseline VAE lr 1e-3 4 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ 
--evalskip 100 \ 5 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 6 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 0 \ 7 | --outname vae_lr1e-3_nodesbatch256 --nruns 3 --labels amber_ground_truth_species.tsv \ 8 | --embsize_gnn 64 --batchtype nodes 9 | 10 | #### baseline VAE lr 1e-4 11 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 12 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 13 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-4 --layers_gnn 0 \ 14 | --outname vae_lr1e-4_nodesbatch256 --nruns 3 --labels amber_ground_truth_species.tsv \ 15 | --embsize_gnn 32 --batchtype nodes --skip_preclustering 16 | 17 | 18 | #### VAE+GNN0 19 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 20 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 1 \ 21 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-4 --layers_gnn 0 --negatives 10 \ 22 | --outname vaegcn_lr1e-4_edgesbatch256_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 23 | --embsize_gnn 32 --batchtype edges --skip_preclustering 24 | 25 | #### VAE+GNN0+SCG 26 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 27 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 1 \ 28 | --ae_alpha 1 --scg_alpha 1 --lr_gnn 1e-4 --layers_gnn 0 --negatives 10 \ 29 | --outname vaegcn_lr1e-4_edgesbatch256_negs10_scg1 --nruns 3 --labels amber_ground_truth_species.tsv \ 30 | --skip_preclustering 31 | 32 | #### VAE+GNN3+SCG 55+1/175+3 33 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 34 | --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 1 \ 35 | --ae_alpha 1 --scg_alpha 1 --lr_gnn 1e-4 --layers_gnn 3 --negatives 10 \ 36 | --outname vaegcn_lr1e-4_edgesbatch256_negs10_noise --nruns 3 --labels amber_ground_truth_species.tsv \ 37 | --skip_preclustering -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../../src/")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "GraphMB" 22 | copyright = "2022, Andre Lamurias" 23 | author = "Andre Lamurias" 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = "v0.2.0" 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 
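# (Editor's note, hedged: of the extensions enabled below, napoleon lets autodoc parse
# Google/NumPy-style docstrings, autodoc imports the package from the src/ path inserted
# above to pull in docstrings, and autosummary appears to be what generates the stub page
# under docs/source/generated/ seen elsewhere in this repository.)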
34 | # extensions = [] 35 | extensions = ["sphinx.ext.napoleon", 'sphinx.ext.autodoc', 'sphinx.ext.autosummary'] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ["_templates"] 39 | 40 | # List of patterns, relative to source directory, that match files and 41 | # directories to ignore when looking for source files. 42 | # This pattern also affects html_static_path and html_extra_path. 43 | exclude_patterns = [] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 50 | # 51 | html_theme = "sphinx_rtd_theme" 52 | 53 | # Add any paths that contain custom static files (such as style sheets) here, 54 | # relative to this directory. They are copied after the builtin static files, 55 | # so a file named "default.css" will overwrite the builtin "default.css". 56 | html_static_path = ["_static"] 57 | -------------------------------------------------------------------------------- /benchmarks/run_strong100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wwtp=$1 3 | 4 | #### baseline VAE lr 1e-2 5 | #python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 6 | # --epoch 500 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 7 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 8 | # --outname vae_lr1e-2_nb256 --nruns 3 \ 9 | # --embsize_gnn 64 --quick --batchtype nodes 10 | 11 | 12 | #### baseline VAE lr 1e-2 13 | #python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 14 | # --epoch 500 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 15 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 0 \ 16 | # --outname vae_lr1e-3_nb256 --nruns 3 \ 17 | # --embsize_gnn 64 --quick --batchtype nodes 18 | 19 | #### baseline VAE lr 1e-3 20 | #python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 21 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0 \ 22 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 \ 23 | # --outname vae_lr1e-3_nb256 --nruns 3 --labels amber_ground_truth.tsv \ 24 | # --skip_preclustering --embsize_gnn 64 --quick --batchtype nodes 25 | 26 | 27 | #### VAE+GNN0 28 | python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 100 \ 29 | --epoch 2000 --evalepoch 20 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0.1 \ 30 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 \ 31 | --outname vaegcn_lr1e-3_nb256_negs10_gnn0.1 --nruns 3 --labels amber_ground_truth.tsv \ 32 | --skip_preclustering --embsize_gnn 64 --batchtype nodes 33 | 34 | #### VAE+GNN0+SCG 35 | python src/graphmb/main.py --cuda --assembly ../data/$wwtp/ --outdir results/$wwtp/ --evalskip 10 \ 36 | --epoch 2000 --evalepoch 20 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 0.1 \ 37 | --ae_alpha 1 --scg_alpha 0.3 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 \ 38 | --outname vaegcn_lr1e-3_nb256_negs10_scg0.1_gnn0.3 --nruns 3 --labels amber_ground_truth.tsv \ 39 | --skip_preclustering --embsize_gnn 64 --batchtype nodes 40 | -------------------------------------------------------------------------------- /docs/source/development.rst: 
-------------------------------------------------------------------------------- 1 | Development 2 | =========== 3 | 4 | Code structure 5 | **************** 6 | 7 | GraphMB contains options to experiment with the model architecture and training, 8 | as well as with pre-processing of the data and post-processing of the results. 9 | The core of GraphMB is a set of deep learning models that map contigs into an 10 | embedding space. 11 | 12 | The files **models.py**, **gnn_models.py**, and **layers.py** contain the 13 | TensorFlow models used in version 0.2. 14 | These files also contain the trainer helper functions and the loss functions. 15 | The files **train_ccvae.py** and **train_gnn.py** contain the training loops of those models. 16 | **graphsage_unsupervised.py** contains the model used by v0.1, while **graphmb1.py** 17 | contains helper functions used in the initial GraphMB release (no longer used). 18 | 19 | The file **evaluate.py** contains several evaluation metrics, and a function to 20 | run a clustering algorithm on the embeddings and evaluate the output. 21 | The main clustering algorithm is in **vamb_clustering.py**, as it was originally developed 22 | for VAMB. The file **amber_eval.py** is adapted from the AMBER evaluation tool to compute 23 | the same metrics as that tool. 24 | 25 | The file **contigsdataset.py** contains the code to read, pre-process and save a 26 | set of contigs, along with their assembly graph, depths, single-copy marker genes 27 | and embeddings. It also computes several stats on a dataset. 28 | The file **dgl_dataset.py** contains code to convert the v0.2 AssemblyDataset class 29 | to the one used by v0.1. 30 | The file **utils.py** contains some additional helper functions that did not fit 31 | elsewhere. 32 | 33 | Finally, all the running parameters are stored in **arg_options.py**. 34 | The main file **main.py** reads these parameters and executes the experiments 35 | accordingly. **version.py** is used only to store the current version number. 36 | **setup.py** defines the dependencies and other parameters needed to build a new version. 37 | 38 | 39 | 40 | Typical workflow for new versions 41 | ********************************** 42 | Useful commands to build a new version: 43 | 44 | .. code-block:: bash 45 | 46 | python setup.py sdist bdist_wheel 47 | python -m twine upload dist/graphmb-X.X.X* 48 | cd docs; make html 49 | sudo docker build . -t andrelamurias/graphmb:X.X.X 50 | sudo docker push andrelamurias/graphmb:X.X.X 51 | 52 | 53 | Documentation 54 | **************** 55 | The documentation is stored in docs/ and uses Sphinx to generate HTML pages. 56 | The docstrings of each function and class are added automatically. If new source 57 | code files are added, these should also be added to docs/source/graphmb.rst. 
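For example, registering a hypothetical new module ``graphmb.new_module`` (a placeholder name,
not an actual module) means appending a stanza following the same pattern as the existing
entries in that file:

.. code-block:: rst

    graphmb.new_module module
    -------------------------

    .. automodule:: graphmb.new_module
       :members:
       :undoc-members:
       :show-inheritance: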
58 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) 
format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /benchmarks/aale_lr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wwtp=$1 3 | 4 | 5 | #### baseline VAE lr 1e-3 6 | #python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 7 | # --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 0 \ 8 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 0 \ 9 | # --outname vae_lr1e-3_b512 --nruns 3 \ 10 | # --embsize_gnn 32 --batchtype edges --rawfeatures 11 | 12 | 13 | 14 | #### VAE+GNN0 15 | python src/graphmb/main.py 
--cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 16 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 0.1 \ 17 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-1 --layers_gnn 0 --negatives 10 \ 18 | --outname vaegcn_lr1e-1_eb512_negs10 --nruns 3 \ 19 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures 20 | 21 | #### VAE+GNN0+SCG 22 | #python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 23 | # --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 1 \ 24 | # --ae_alpha 1 --scg_alpha 1 --lr_gnn 1e-4 --layers_gnn 0 --negatives 10 \ 25 | # --outname vaegcn_lr1e-4_edgesbatch512_negs10_scg1 --nruns 3 \ 26 | # --skip_preclustering 27 | 28 | #### VAE+GNN0 29 | python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 30 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 0.1 \ 31 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 32 | --outname vaegcn_lr1e-2_eb512_negs10 --nruns 3 \ 33 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures 34 | 35 | #### VAE+GNN0 36 | python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 37 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 0.1 \ 38 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-3 --layers_gnn 0 --negatives 10 \ 39 | --outname vaegcn_lr1e-3_eb512_negs10 --nruns 3 \ 40 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures 41 | 42 | #### VAE+GNN0 43 | python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 44 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 1 \ 45 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 5e-4 --layers_gnn 0 --negatives 10 \ 46 | --outname vaegcn_lr5e-4_eb512_negs10 --nruns 3 \ 47 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures 48 | 49 | #### VAE+GNN0 50 | python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 51 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 1 \ 52 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-4 --layers_gnn 0 --negatives 10 \ 53 | --outname vaegcn_lr1e-4_eb512_negs10 --nruns 3 \ 54 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures 55 | 56 | #### VAE+GNN0 57 | python src/graphmb/main.py --cuda --assembly ../data/${wwtp}/ --outdir results/${wwtp}/ --evalskip 100 \ 58 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 1 \ 59 | --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-5 --layers_gnn 0 --negatives 10 \ 60 | --outname vaegcn_lr1e-5_eb512_negs10 --nruns 3 \ 61 | --embsize_gnn 64 --skip_preclustering --quick --rawfeatures -------------------------------------------------------------------------------- /benchmarks/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | set -e 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | dataset=$1 7 | #source venv/bin/activate 8 | #python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 9 | # --model_name vae --markers marker_gene_stats.tsv --epoch 500 \ 10 | # --nruns 1 --evalepochs 20 --outname vae --batchsize 256 \ 11 | # --evalskip 200 --labels amber_ground_truth_species.tsv 12 | #mv results/$dataset/vae_best_embs.pickle ../data/$dataset/ 13 | 14 | #python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 15 | # --model_name sage_lstm --markers marker_gene_stats.tsv --epoch 500 \ 
16 | # --nruns 1 --evalepochs 20 --outname sagelstm --skip_preclustering \ 17 | # --features vae_best_embs.pickle --concat_features --evalskip 10 --labels amber_ground_truth_species.tsv 18 | 19 | #python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 20 | # --model_name gcn --markers marker_gene_stats.tsv --epoch 500 \ 21 | # --nruns 1 --evalepochs 20 --outname gcn \ 22 | # --features vae_best_embs.pickle --concat_features --evalskip 10 --labels amber_ground_truth_species.tsv 23 | 24 | 25 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 26 | --model_name gcn_ae --markers marker_gene_stats.tsv --epoch 1000 \ 27 | --nruns 1 --evalepochs 20 --outname gcnae_nognn \ 28 | --layers_gnn 0 --evalskip 200 --batchsize 256 --labels amber_ground_truth_species.tsv 29 | 30 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 31 | --model_name gcn_ae --markers marker_gene_stats.tsv --epoch 1000 \ 32 | --nruns 1 --evalepochs 20 --outname gcnae \ 33 | --concat_features --evalskip 200 --batchsize 256 --labels amber_ground_truth_species.tsv 34 | 35 | ## with GTDB 36 | 37 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 38 | --model_name vae --markers gtdb --epoch 500 \ 39 | --nruns 1 --evalepochs 20 --outname vae_gtdb --batchsize 256 \ 40 | --evalskip 200 --labels amber_ground_truth_species.tsv 41 | mv results/$dataset/vae_best_embs.pickle ../data/$dataset/ 42 | 43 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 44 | --model_name sage_lstm --markers gtdb --epoch 500 \ 45 | --nruns 1 --evalepochs 20 --outname sagelstm_gtdb --skip_preclustering \ 46 | --features vae_best_embs.pickle --concat_features --evalskip 10 --labels amber_ground_truth_species.tsv 47 | 48 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 49 | --model_name gcn --markers gtdb --epoch 500 \ 50 | --nruns 1 --evalepochs 20 --outname gcn_gtdb \ 51 | --features vae_best_embs.pickle --concat_features --evalskip 10 --labels amber_ground_truth_species.tsv 52 | 53 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 54 | --model_name gcn_ae --markers marker_gene_stats.tsv --epoch 1000 \ 55 | --nruns 1 --evalepochs 20 --outname gcnae_gtdb_nognn \ 56 | --layers_gnn 0 --evalskip 200 --batchsize 256 --labels amber_ground_truth_species.tsv 57 | 58 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 59 | --model_name gcn_ae --markers gtdb --epoch 1000 \ 60 | --nruns 1 --evalepochs 20 --outname gcnae_gtdb \ 61 | --concat_features --evalskip 200 --batchsize 256 --labels amber_ground_truth_species.tsv -------------------------------------------------------------------------------- /docs/source/_build/html/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var 
b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof 
b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /benchmarks/runs.sh: -------------------------------------------------------------------------------- 1 | # VAE 2 | 3 | # multiple runs with filtered nodes (SCG and connected only) 4 | ############################################################ 5 | 6 | 7 | dataset=$2 8 | # GCN on pre-trained AE features 9 | export CUDA_VISIBLE_DEVICES=$1 10 | quick=" --quick --nruns 5" 11 | #quick="" 12 | addname="_fixloss" 13 | 14 | #python src/graphmb/main.py --cuda --assembly ../data/$dataset --outdir results/$dataset --model_name vae \ 15 | # --markers marker_gene_stats.tsv --batchsize 256 --epoch 500 --lr_vae 1e-3 \ 16 | # --nruns 5 --evalepochs 20 --outname vae_baseline 17 | 18 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 19 | --model_name gcn --markers marker_gene_stats.tsv --epoch 500 \ 20 | --evalepochs 20 --outname gcn_lr1e-4$addname --lr_gnn 1e-4 \ 21 | --features vae_baseline_best_embs.pickle --concat_features --evalskip 200 $quick 22 | 23 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 24 | --model_name gcn --markers marker_gene_stats.tsv --epoch 500 \ 25 | --evalepochs 20 --outname gcn_lr1e-3$addname --lr_gnn 1e-3 \ 26 | --features vae_best_embs.pickle --concat_features --evalskip 200 $quick 27 | 28 | # VAE+GCN model (separate losses) 29 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 30 | --model_name gcn_ae --markers marker_gene_stats.tsv --epoch 500 \ 31 | --evalepochs 20 --outname gcnae_lr1e-4$addname --lr_gnn 1e-4 \ 32 | --batchsize 256 --rawfeatures --gnn_alpha 0.5 --scg_alpha 100 --concat_features \ 33 | --evalskip 100 --skip_preclustering $quick 34 | 35 | 36 | # GVAE model, reconloss 37 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 38 | --model_name gcn_decode --markers marker_gene_stats.tsv --epoch 500 \ 39 | --evalepochs 20 --outname gcndecode_lr1e-4$addname --lr_gnn 1e-4 --batchsize 256 \ 40 | --rawfeatures --gnn_alpha 0.5 --scg_alpha 0 --evalskip 100 --skip_preclustering --layers_gnn 3 $quick 41 | 42 | 43 | 44 | # Using only top 10% of edges (separate losses) 45 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 46 | --model_name gcn_ae --markers marker_gene_stats.tsv --epoch 500 \ 47 | --evalepochs 20 --outname gcnae_lr1e-4_binarize$addname --lr_gnn 1e-4 \ 48 | --batchsize 256 --rawfeatures --gnn_alpha 0.5 --scg_alpha 100 --concat_features \ 49 | --evalskip 100 --skip_preclustering --binarize $quick 50 | 51 | 52 | # VAE+GCN augmented graph 53 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 54 | --model_name gcn_aug --markers marker_gene_stats.tsv --epoch 500 \ 55 | --evalepochs 20 --outname gcnaug_lr1e-4$addname --concat_features \ 56 | --lr_gnn 1e-4 --rawfeatures --evalskip 100 $quick 57 | 58 | 59 | ### extra experiments 60 | 61 | ### no edges 62 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 63 | --model_name gcn --markers marker_gene_stats.tsv --epoch 500 \ 64 | --evalepochs 20 
--outname gcn_lr1e-3_noedges$addname --noedges --lr_gnn 1e-3 \ 65 | --features vae_best_embs.pickle --concat_features --evalskip 200 $quick 66 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 67 | --model_name gcn_ae --markers marker_gene_stats.tsv --epoch 500 \ 68 | --evalepochs 20 --outname gcnae_lr1e-4_noedges$addname --noedges --lr_gnn 1e-4 \ 69 | --batchsize 256 --rawfeatures --gnn_alpha 0.5 --scg_alpha 100 --concat_features \ 70 | --evalskip 100 --skip_preclustering $quick 71 | # GVAE model, reconloss 72 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 73 | --model_name gcn_decode --markers marker_gene_stats.tsv --epoch 500 \ 74 | --evalepochs 20 --outname gcndecode_lr1e-4_noedges$addname --noedges --lr_gnn 1e-4 --batchsize 256 \ 75 | --rawfeatures --gnn_alpha 0.5 --scg_alpha 0 --evalskip 100 --skip_preclustering --layers_gnn 3 $quick 76 | python src/graphmb/main.py --cuda --assembly ../data/$dataset/ --outdir results/$dataset/ \ 77 | --model_name gcn_aug --markers marker_gene_stats.tsv --epoch 500 \ 78 | --evalepochs 20 --outname gcnaug_lr1e-4_noedges$addname --noedges --concat_features \ 79 | --lr_gnn 1e-4 --rawfeatures --evalskip 100 $quick -------------------------------------------------------------------------------- /docs/source/_build/html/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var 
t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap("
"),n("table.docutils.footnote").wrap("
"),n("table.docutils.citation").wrap("
"),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}if(t.length>0){$(".wy-menu-vertical .current").removeClass("current").attr("aria-expanded","false"),t.addClass("current").attr("aria-expanded","true"),t.closest("li.toctree-l1").parent().addClass("current").attr("aria-expanded","true");for(let n=1;n<=10;n++)t.closest("li.toctree-l"+n).addClass("current").attr("aria-expanded","true");t[0].scrollIntoView()}}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current").attr("aria-expanded","false"),e.siblings().find("li.current").removeClass("current").attr("aria-expanded","false");var t=e.find("> ul li");t.length&&(t.removeClass("current").attr("aria-expanded","false"),e.toggleClass("current").attr("aria-expanded",(function(n,e){return"true"==e?"false":"true"})))}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t -o --meta``` 58 | 2. Filter and polish assembly if necessary (or extract edge sequences and polish edge sequences instead) 59 | 3. Convert assembly graph to contig-based graph if you want to use full contig instead of edges 60 | 4. Run CheckM on sequences with Bacteria markers:: 61 | 62 | mkdir edges 63 | cd edges; cat ../assembly.fasta | awk '{ if (substr($0, 1, 1)==">") {filename=(substr($0,2) ".fa")} print $0 > filename }'; cd .. 64 | find edges/ -name "* *" -type f | rename 's/ /_/g' 65 | # evaluate edges 66 | checkm taxonomy_wf -x fa domain Bacteria edges/ checkm_edges/ 67 | checkm qa checkm_edges/Bacteria.ms checkm_edges/ -f checkm_edges_polished_results.txt --tab_table -o 2 68 | 69 | 5. Get abundances with `jgi_summarize_bam_contig_depths`:: 70 | 71 | minimap2 -I 64GB -d assembly.mmi assembly.fasta # make index 72 | minimap2 -I 64GB -ax map-ont assembly.mmi > assembly.sam 73 | samtools sort assembly.sam > assembly.bam 74 | jgi_summarize_bam_contig_depths --outputDepth assembly_depth.txt assembly.bam 75 | 76 | 6. Now you should have all the files to run GraphMB 77 | 78 | We have only tested GraphMB on flye assemblies. Flye generates a repeat graph where the nodes do not correspond to full contigs. 79 | Depending on your setup, you need to either use the edges as contigs. 80 | 81 | Parameters 82 | **************** 83 | GraphMB contains many parameters to run various experiments, from changing the data inputs, to architecture of the 84 | model, training loss, data preprocessing, and output type. 
81 | Parameters 82 | **************** 83 | GraphMB contains many parameters to run various experiments, covering the data inputs, the architecture of the 84 | model, the training loss, data preprocessing, and the output type. 85 | The defaults were chosen to obtain the published results on the WWTP datasets, but may require some tuning for different 86 | datasets and scenarios. 87 | The full list of parameters is below, but this section focuses on the most relevant ones. 88 | 89 | assembly, assembly_name, graph_file, features, labels, markers, depth 90 | """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" 91 | If --assembly is given, it is used as the base data directory, 92 | and every other path is interpreted relative to that directory. 93 | Otherwise, every data file path must be given relative to the current directory. 94 | 95 | Note about markers: this file is not mandatory, but it is assumed to exist by default, 96 | since we have run all of our experiments with it. 97 | Without it, the number of HQ bins will probably be lower. 98 | 99 | outdir, outname 100 | """""""""""""""""" 101 | Where to write the output files, including caches, and what prefix to use for them. 102 | If not given, GraphMB writes to the data directory given by --assembly. 103 | 104 | reload 105 | """""""""""""""""" 106 | 107 | Ignore the cache and reprocess the data files. 108 | 109 | nruns, seed 110 | """""""""""""""""" 111 | Repeat the experiment --nruns times. Use --seed to specify the initial seed, which is changed 112 | with every run so that each run gives different results. 113 | 114 | cuda 115 | """""""""""""""""" 116 | Run model training and clustering on the GPU. 117 | 118 | contignodes 119 | """""""""""""""""" 120 | If the sequences given by the --assembly_name parameter are actual contigs and not assembly graph edges, 121 | use this parameter, which will transform the assembly graph to use full contigs as nodes. 122 | 123 | model 124 | """""""""""""""""" 125 | Model to be used by GraphMB. By default it uses a Graph Convolutional Network (GCN), and first trains a Variational Autoencoder 126 | to generate node features. The VAE embeddings are saved to make it faster to rerun the GCN.
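As an illustration, both the default pipeline and the end-to-end variant are selected through --model_name (a sketch; the data directory is a placeholder), with the remaining model options listed below::

    # default: pretrain a VAE for node features, then train a GCN on the graph
    python src/graphmb/main.py --assembly ../data/<dataset>/ --model_name gcn
    # combined VAE+GNN model, trained end-to-end
    python src/graphmb/main.py --assembly ../data/<dataset>/ --model_name gcn_ccvae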
127 | 128 | Other models: 129 | - sage_lstm: original GraphMB GraphSAGE model, requires a DGL installation 130 | - gat and sage: alternative GNN models to the GCN (VAEGBin) 131 | - gcn_ccvae, gat_ccvae, sage_ccvae: combined VAE and GNN models, trained end-to-end, or without a GNN if layers_gnn=0 -------------------------------------------------------------------------------- /benchmarks/aale_bs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### baseline VAE, batch size sweep (lr 1e-2) 4 | # python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 5 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 0 --gnn_alpha 0 \ 6 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 7 | # --outname vae_lr1e-2_bs0 --nruns 3 \ 8 | # --embsize_gnn 64 --quick 9 | 10 | # python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 11 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 128 --gnn_alpha 0 \ 12 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 13 | # --outname vae_lr1e-2_bs128 --nruns 3 \ 14 | # --embsize_gnn 64 --quick 15 | 16 | # python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 17 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 512 --gnn_alpha 0 \ 18 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 19 | # --outname vae_lr1e-2_bs512 --nruns 3 \ 20 | # --embsize_gnn 64 --quick 21 | 22 | # python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 23 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 1024 --gnn_alpha 0 \ 24 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 25 | # --outname vae_lr1e-2_bs1024 --nruns 3 \ 26 | # --embsize_gnn 64 --quick 27 | 28 | # python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 29 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 10000 --gnn_alpha 0 \ 30 | # --ae_alpha 1 --scg_alpha 0 --lr_gnn 1e-2 --layers_gnn 0 --negatives 0 \ 31 | # --outname vae_lr1e-2_bs10k --nruns 3 \ 32 | # --embsize_gnn 64 --quick 33 | 34 | 35 | # ################################################ 36 | 37 | #### VAE+GNN0 38 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 39 | --epoch 1000 --model gcn_ae --batchsize 0 --gnn_alpha 0.1 \ 40 | --ae_alpha 1 --scg_alpha 0.1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 41 | --outname vaegcn_lr1e-2_bs0_negs10 --nruns 3 \ 42 | --embsize_gnn 64 --skip_preclustering --quick 43 | 44 | #### VAE+GNN0 45 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 46 | --epoch 1000 --model gcn_ae --batchsize 128 --gnn_alpha 0.1 \ 47 | --ae_alpha 1 --scg_alpha 0.1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 48 | --outname vaegcn_lr1e-2_eb128_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 49 | --embsize_gnn 64 --skip_preclustering --quick 50 | 51 | #### VAE+GNN0+SCG 52 | #python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 53 | # --epoch 1000 --model gcn_ae --rawfeatures --batchsize 256 --gnn_alpha 1 \ 54 | # --ae_alpha 1 --scg_alpha 1 --lr_gnn 1e-4 --layers_gnn 0 --negatives 10 \ 55 | # --outname vaegcn_lr1e-4_edgesbatch256_negs10_scg1 --nruns 3 --labels amber_ground_truth_species.tsv \ 56 | # --skip_preclustering 57
| 58 | #### VAE+GNN0 59 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 60 | --epoch 1000 --model gcn_ae --batchsize 256 --gnn_alpha 0.1 \ 61 | --ae_alpha 1 --scg_alpha 0.1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 62 | --outname vaegcn_lr1e-2_eb256_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 63 | --embsize_gnn 64 --skip_preclustering --quick 64 | 65 | #### VAE+GNN0 66 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 67 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 0.1 \ 68 | --ae_alpha 1 --scg_alpha 0.1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 69 | --outname vaegcn_lr1e-2_eb512_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 70 | --embsize_gnn 64 --skip_preclustering --quick 71 | 72 | ### VAE+GNN0 73 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 74 | --epoch 1000 --model gcn_ae --batchsize 1024 --gnn_alpha 0.1 \ 75 | --ae_alpha 1 --scg_alpha 0.1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 76 | --outname vaegcn_lr1e-2_eb1024_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 77 | --embsize_gnn 64 --skip_preclustering --quick 78 | 79 | ### VAE+GNN0 80 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 81 | --epoch 1000 --model gcn_ae --batchsize 10000 --gnn_alpha 0.1 \ 82 | --ae_alpha 1 --scg_alpha 0.1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 83 | --outname vaegcn_lr1e-2_eb10k_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 84 | --embsize_gnn 64 --skip_preclustering --quick 85 | 86 | 87 | ######################## 88 | #### VAE+GNN0 89 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 90 | --epoch 1000 --model gcn_ae --batchsize 0 --gnn_alpha 1 \ 91 | --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 92 | --outname vae0gcn_lr1e-2_bs0_negs10 --nruns 3 \ 93 | --embsize_gnn 64 --skip_preclustering --quick 94 | 95 | #### VAE+GNN0 96 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 97 | --epoch 1000 --model gcn_ae --batchsize 128 --gnn_alpha 1 \ 98 | --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 99 | --outname vae0gcn_lr1e-2_eb128_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 100 | --embsize_gnn 64 --skip_preclustering --quick 101 | 102 | 103 | 104 | #### VAE+GNN0 105 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 106 | --epoch 1000 --model gcn_ae --batchsize 256 --gnn_alpha 1 \ 107 | --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 108 | --outname vae0gcn_lr1e-2_eb256_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 109 | --embsize_gnn 64 --skip_preclustering --quick 110 | 111 | #### VAE+GNN0 112 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 113 | --epoch 1000 --model gcn_ae --batchsize 512 --gnn_alpha 1 \ 114 | --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 115 | --outname vae0gcn_lr1e-2_eb512_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 116 | --embsize_gnn 64 --skip_preclustering --quick 117 | 118 | #### VAE+GNN0 119 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 120 | --epoch 1000 --model gcn_ae --batchsize 1024 
--gnn_alpha 1 \ 121 | --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 122 | --outname vae0gcn_lr1e-2_eb1024_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 123 | --embsize_gnn 64 --skip_preclustering --quick 124 | 125 | #### VAE+GNN0 126 | python src/graphmb/main.py --cuda --assembly ../data/aale/ --outdir results/aale/ --evalskip 100 \ 127 | --epoch 1000 --model gcn_ae --batchsize 10000 --gnn_alpha 1 \ 128 | --ae_alpha 0 --scg_alpha 1 --lr_gnn 1e-2 --layers_gnn 0 --negatives 10 \ 129 | --outname vae0gcn_lr1e-2_eb10k_negs10 --nruns 3 --labels amber_ground_truth_species.tsv \ 130 | --embsize_gnn 64 --skip_preclustering --quick 131 | -------------------------------------------------------------------------------- /docs/source/_build/html/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({docnames:["examples","graphmb","index","intro","modules"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":4,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":3,"sphinx.domains.rst":2,"sphinx.domains.std":2,sphinx:56},filenames:["examples.rst","graphmb.rst","index.rst","intro.rst","modules.rst"],objects:{"":[[1,0,0,"-","graphmb"]],"graphmb.contigsdataset":[[1,1,1,"","ContigsDataset"]],"graphmb.contigsdataset.ContigsDataset":[[1,2,1,"","filter_contigs"],[1,2,1,"","filter_edges"],[1,2,1,"","get_labels_from_reads"],[1,2,1,"","has_cache"],[1,2,1,"","load"],[1,2,1,"","process"],[1,2,1,"","read_depths"],[1,2,1,"","read_gfa"],[1,2,1,"","read_seqs"],[1,2,1,"","remove_nodes"],[1,2,1,"","rename_nodes_to_index"],[1,2,1,"","save"],[1,2,1,"","set_node_mask"]],"graphmb.evaluate":[[1,3,1,"","completeness"],[1,3,1,"","contamination"],[1,3,1,"","evaluate_contig_sets"],[1,3,1,"","get_markers_to_contigs"],[1,3,1,"","main"],[1,3,1,"","read_contig_genes"],[1,3,1,"","read_marker_gene_sets"]],"graphmb.graph_functions":[[1,1,1,"","Read"],[1,1,1,"","ReadMapping"],[1,3,1,"","augment_graph"],[1,3,1,"","calculate_bin_metrics"],[1,3,1,"","cluster_embs"],[1,3,1,"","cluster_eval"],[1,3,1,"","compute_loss_para"],[1,3,1,"","connected_components"],[1,3,1,"","count_kmers"],[1,3,1,"","draw_nx_graph"],[1,3,1,"","evaluate_binning"],[1,3,1,"","get_kmer_to_id"],[1,3,1,"","open_gfa_file"],[1,3,1,"","plot_embs"],[1,3,1,"","read_contigs_scg"],[1,3,1,"","read_reads_mapping_sam"],[1,3,1,"","set_seed"],[1,3,1,"","write_components_file"]],"graphmb.graphsage_unsupervised":[[1,1,1,"","CrossEntropyLoss"],[1,1,1,"","MultiLayerNeighborWeightedSampler"],[1,1,1,"","NegativeSampler"],[1,1,1,"","NegativeSamplerWeight"],[1,1,1,"","SAGE"],[1,3,1,"","train_graphsage"]],"graphmb.graphsage_unsupervised.CrossEntropyLoss":[[1,2,1,"","forward"],[1,4,1,"","training"]],"graphmb.graphsage_unsupervised.MultiLayerNeighborWeightedSampler":[[1,2,1,"","sample_frontier"]],"graphmb.graphsage_unsupervised.SAGE":[[1,2,1,"","forward"],[1,2,1,"","inference"],[1,2,1,"","init"],[1,4,1,"","training"]],"graphmb.main":[[1,3,1,"","main"]],graphmb:[[1,0,0,"-","contigsdataset"],[1,0,0,"-","evaluate"],[1,0,0,"-","graph_functions"],[1,0,0,"-","graphsage_unsupervised"],[1,0,0,"-","main"]]},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","method","Python method"],"3":["py","function","Python function"],"4":["py","attribute","Python 
attribute"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method","3":"py:function","4":"py:attribute"},terms:{"0":[0,1,3],"1":[0,1,3],"1000":1,"1m6utgtpughk_q9gxfx1uneofn8jnidt5":3,"2":[0,1,3],"3":[1,3],"4":[0,1],"5":[],"6":1,"64gb":0,"9":3,"class":1,"default":[1,3],"do":0,"function":1,"int":1,"long":3,"new":1,"return":1,"true":1,"while":1,By:[1,3],For:[1,3],If:[0,1],In:3,It:[1,3],On:0,The:[1,3],To:0,_:0,_http:[],abov:0,abund:0,accord:1,activ:[1,3],ad:1,add:1,add_new:1,add_read:1,adj:1,advantag:3,after:0,afterward:1,agg:1,algorithm:3,align:1,all:[0,1],also:[0,3],although:1,an:[0,3],anaconda:3,andrelamuria:[0,3],ani:[1,3],ar:[0,1,3],asseembly_depth:[],assembl:[0,1,3],assembly_depth:0,assembly_graph:[0,1,3],assembly_nam:[0,1],augment_graph:1,avail:0,awk:0,ax:0,bacteria:[0,1,3],balanc:1,bam:0,base:[0,1],basenam:1,bash:0,batch_siz:1,been:3,best:1,best_hq:1,best_hq_epoch:1,better:[1,3],between:1,bin:[1,3],bin_to_contig:1,binner:3,binning_workflow:0,bitflag:1,block:[0,1],block_id:1,block_output:1,bool:1,both:0,c:3,cach:1,calcul:1,calculate_bin_metr:1,call:1,can:[0,1,3],canonical_k:1,care:1,cat:0,cd:[0,3],centroid:1,checkm:[0,1,3],checkm_edg:0,checkm_edges_polished_result:0,checkm_ev:0,clone:3,cluster:[0,1],cluster_emb:1,cluster_ev:1,cluster_featur:1,cluster_to_contig:1,clusteringalgo:1,clusteringloss:1,code:[0,1],column:1,com:[0,3],complet:1,compon:1,comput:1,compute_loss_para:1,concept:1,conda:3,connect:1,connected_compon:1,contamin:1,content:4,contig:[0,1,3],contig_mark:1,contig_marker_count:1,contig_s:1,contignam:1,contigsdataset:[2,4],convert:0,correspond:0,cotnig:1,could:1,count:1,count_kmer:1,cpu:[0,1],crossentropyloss:1,csv:3,cuda:[0,1],current:1,d:0,data:[0,1,3],dataload:1,dataset:1,decid:1,defin:1,depend:0,depth:[0,1],descript:1,destin:1,develop:3,devic:1,dgl:1,dgl_dataset:1,dgldataset:1,dglgraph:1,dict:1,diff:1,dir:1,directori:[0,3],domain:0,done:1,download:3,draw_nx_graph:1,drive:3,dropout:1,dure:3,e:1,each:[0,1,3],earli:3,edg:[0,1,3],edge_weight:1,edges_depth:[0,3],either:0,element:1,emb:1,empti:1,entir:1,epoch:[0,1,3],epsilon:1,eval:1,evalu:[0,2,4],evaluate_bin:1,evaluate_contig_set:1,everi:1,exampl:[2,3],exist:1,extra:1,extra_metr:1,extract:[0,3],f:0,fa:0,fals:1,fan_out:1,fanout:1,fashion:1,fasta:[0,1,3],faster:0,file:[0,1,2],filenam:[0,1],filter:[0,1],filter_contig:1,filter_edg:1,find:0,flye:[0,3],folder:3,force_reload:1,format:[1,3],former:1,forward:1,found:1,from:[1,3],frontier:1,full:[0,1],g:[0,1],gene:1,gener:[0,1,3],get:[0,1,3],get_kmer_to_id:1,get_labels_from_read:1,get_markers_to_contig:1,gfa:[0,1,3],gfapath:1,git:3,github:[0,3],given:1,gnn:1,googl:3,gpu:0,graph:[0,1,3],graph_fil:[0,1],graph_funct:[2,4],graphmb:[0,3],graphsag:1,graphsage_unsupervis:[2,4],guid:1,ha:[1,3],handl:1,has_cach:1,have:[0,3],here:0,hook:1,how:0,hq:1,hq_centroid:1,http:[0,3],i:[0,1],id:1,ignor:1,in_feat:1,includ:3,index:[0,2],infer:1,inform:1,init:1,input:[1,2],instal:[0,2],instanc:1,instead:[0,1],introduct:2,invers:1,jgi_summarize_bam_contig_depth:[0,3],just:[1,3],k:1,kcluster:1,kmean:1,kmer:1,kmer_to_id:1,label:1,label_to_nod:1,labels_to_nod:1,last:3,latter:1,layer:1,learn:3,limit:0,lineage_fil:1,link:3,list:1,load:1,load_graph:1,load_info:1,log:1,logger:1,logic:1,logit:1,loss:1,loss_weight:1,lr:1,m:3,machin:3,main:[0,2,4],make:0,map:[0,1,3],mapq:1,marker:[0,1,3],marker_count_bin:1,marker_fil:1,marker_gene_stat:[0,3],marker_genes_bin:1,marker_set:1,matrix:1,mean:1,member:1,mention:0,meta:[0,3],metafly:0,metagenom:3,mfg:1,might:0,min:1,min_contig:1,min_elem:1,min_map:1,minibatch:1,minimap2:0,min
imum:1,minsiz:1,mkdir:0,mmi:0,model:[1,3],modul:[2,4],most:1,ms:[0,1],multilayerneighborsampl:1,multilayerneighborweightedsampl:1,n_class:1,n_hidden:1,n_layer:1,name:[0,1],nano:0,necessari:0,need:[0,1,3],neg:1,neg_graph:1,neg_shar:1,negativesampl:1,negativesamplerweight:1,neighbor:1,networkx:1,nn:1,node:[0,1],node_embed:1,node_embeddings_2dim:1,node_id:1,node_len:1,node_nam:1,node_s:1,node_titl:1,node_to_label:1,none:[1,3],note:1,now:0,ntype:1,num_epoch:1,num_neg:1,num_work:1,number:[0,1,3],numcor:0,o:0,object:1,obtain:1,one:1,onli:[0,1,3],ont:0,open:1,open_gfa_fil:1,optim:[0,3],option:[1,3],origin:1,otherwis:1,our:0,outpath:1,output:[0,1,3],outputclust:1,outputdepth:0,outputnam:1,overrid:1,overridden:1,overview:0,overwit:1,overwrit:1,own:1,packag:[2,4],page:2,pair:3,param:[0,1],paramet:1,pars:1,pass:1,path:1,perform:[1,3],pick:3,pip:[0,3],pleas:1,plot:1,plot_emb:1,po:1,point:1,polish:0,pos_graph:1,present:0,prevent:0,primari:1,print:[0,1],print_interv:1,process:1,provid:[0,1],py3:3,py:0,python:[0,1,3],qa:0,raw:0,read:[0,1,3],read_contig_gen:1,read_contigs_scg:1,read_depth:1,read_gfa:1,read_marker_gene_set:1,read_reads_mapping_sam:1,read_seq:1,readi:0,readid:1,readmap:1,reads_dict:1,reads_fil:0,readspath:1,realiz:1,recip:1,recommend:1,ref_fil:1,refer:1,regist:1,releas:3,remove_list:1,remove_nod:1,renam:0,rename_nodes_to_index:1,repeat:0,replac:1,repres:1,requir:3,result:[0,1,3],return_eid:1,root:1,row:1,run:[0,1,3],s:0,sage:1,sam:[0,1],sampl:1,sample_fronti:1,sample_weight:1,samtool:0,save:1,save_dir:1,save_graph:1,save_info:1,scenario:1,score:1,search:2,section:[0,1],seed:1,seed_nod:1,seq:1,sequenc:[0,1],set:[1,3],set_node_mask:1,set_se:1,setup:0,share:3,should:[0,1,3],silent:1,sinc:[1,3],singl:1,size:1,some:0,sort:0,sourc:3,speci:1,specifi:[0,1],src:0,stat:1,stop:3,str:1,strong100:[0,3],subclass:1,submodul:[2,4],substr:0,summari:3,supervis:1,support:1,tab_tabl:0,tabl:1,take:[1,3],taxon:1,taxonomy_wf:0,tensor:1,test:[0,3],than:1,them:1,thi:[0,1,3],though:3,thread:0,torch:1,train:[0,1],train_graphsag:1,tsv:0,tutori:1,txt:[0,3],type:[0,1],typic:2,us:[0,1,3],use_weight:1,user:1,usp:3,util:1,v0:3,valu:1,venv:3,wa:1,want:0,we:0,weight:1,wheel:3,where:[0,1],whether:1,which:1,whl:3,within:1,without:1,workflow:2,write:1,write_components_fil:1,written:1,x:[0,1],you:[0,3],your:[0,1]},titles:["Examples","graphmb package","Welcome to GraphMB\u2019s documentation!","Introduction","graphmb"],titleterms:{content:[1,2],contigsdataset:1,document:2,evalu:1,exampl:0,file:3,graph_funct:1,graphmb:[1,2,4],graphsage_unsupervis:1,indic:2,input:3,instal:3,introduct:3,main:1,modul:1,packag:1,s:2,submodul:1,tabl:2,typic:0,welcom:2,workflow:0}}) -------------------------------------------------------------------------------- /src/graphmb/arg_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class LoadFromFile (argparse.Action): 4 | def __call__ (self, parser, namespace, values, option_string = None): 5 | with values as f: 6 | contents = f.read() 7 | # parse arguments in the file and store them in a blank namespace 8 | data = parser.parse_args(contents.split(), namespace=None) 9 | for k, v in vars(data).items(): 10 | # set arguments in the target namespace if they haven’t been set yet 11 | #if getattr(namespace, k, None) is None: 12 | #print(f"using args {k}={v}") 13 | setattr(namespace, k, v) 14 | 15 | def create_parser(): 16 | parser = argparse.ArgumentParser(description="Train graph embedding model") 17 | # input files 18 | 
parser.add_argument("--assembly", type=str, help="Assembly base path", required=False) 19 | parser.add_argument("--assembly_name", type=str, help="File name with contigs", default="assembly.fasta") 20 | parser.add_argument("--graph_file", type=str, help="File name with graph", default="assembly_graph.gfa") 21 | parser.add_argument("--edge_threshold", type=float, help="Remove edges with weight lower than this (keep only >=)", 22 | default=None) 23 | parser.add_argument("--depth", type=str, help="Depth file from jgi", default="assembly_depth.txt") 24 | parser.add_argument("--features", type=str, help="Features file mapping contig name to features", 25 | default="features.tsv") 26 | parser.add_argument("--labels", type=str, help="File mapping contig to label", default=None) 27 | parser.add_argument("--embs", type=str, help="No train, load embs", default=None) 28 | 29 | # model specification 30 | parser.add_argument("--model_name", type=str, 31 | help="One of the implemented models: gcn, gat, sage, sage_lstm, _ccvae variation", 32 | default="gcn") 33 | parser.add_argument("--activation", type=str, help="Activation function to use(relu, prelu, sigmoid, tanh)", 34 | default="relu") 35 | parser.add_argument("--layers_vae", type=int, help="Number of layers of the VAE", default=2) 36 | parser.add_argument("--layers_gnn", type=int, help="Number of layers of the GNN", default=3) 37 | parser.add_argument("--hidden_gnn", type=int, help="Dimension of hidden layers of GNN", default=128) 38 | parser.add_argument("--hidden_vae", type=int, help="Dimension of hidden layers of VAE", default=512) 39 | parser.add_argument("--embsize_gnn", "--zg", type=int, help="Output embedding dimension of GNN", default=32) 40 | parser.add_argument("--embsize_vae", "--zl", type=int, help="Output embedding dimension of VAE", default=64) 41 | parser.add_argument("--batchsize", type=int, help="batchsize to train the VAE", default=256) 42 | parser.add_argument("--batchtype", type=str, help="Batch type, nodes or edges", default="auto") 43 | parser.add_argument("--dropout_gnn", type=float, help="dropout of the GNN", default=0.1) 44 | parser.add_argument("--dropout_vae", type=float, help="dropout of the VAE", default=0.2) 45 | parser.add_argument("--lr_gnn", type=float, help="learning rate", default=1e-2) 46 | parser.add_argument("--lr_vae", type=float, help="learning rate", default=1e-3) 47 | parser.add_argument("--graph_alpha", type=float, help="Coeficient for graph loss", default=1) 48 | parser.add_argument("--kld_alpha", type=float, help="Coeficient for KLD loss", default=200) 49 | parser.add_argument("--ae_alpha", type=float, help="Coeficient for AE loss", default=1) 50 | parser.add_argument("--scg_alpha", type=float, help="Coeficient for SCG loss", default=1) 51 | parser.add_argument("--clusteringalgo", help="clustering algorithm: vamb, kmeans", default="vamb") 52 | parser.add_argument("--kclusters", help="Number of clusters (only for some clustering methods)", default=None) 53 | # GraphSAGE params 54 | parser.add_argument("--aggtype", help="Aggregation type for GraphSAGE (mean, pool, lstm, gcn)", default="lstm") 55 | parser.add_argument("--decoder_input", help="What to use for input to the decoder", default="vae") 56 | parser.add_argument("--vaepretrain", help="How many epochs to pretrain VAE", default=500, type=int) 57 | parser.add_argument("--ae_only", help="Do not use GNN (ae model must be used and decoder input must be ae", action="store_true") 58 | parser.add_argument("--negatives", help="Number of negatives to train 
GraphSAGE", default=10, type=int) 59 | parser.add_argument("--quick", help="Reduce number of nodes to run quicker", action="store_true") 60 | parser.add_argument("--classify", help="Run classification instead of clustering", action="store_true") 61 | parser.add_argument( 62 | "--fanout", help="Fan out, number of positive neighbors sampled at each level", default="10,25" 63 | ) 64 | # other training params 65 | parser.add_argument("--epoch", type=int, help="Number of epochs to train model", default=500) 66 | parser.add_argument("--print", type=int, help="Print interval during training", default=10) 67 | parser.add_argument("--evalepochs", type=int, help="Epoch interval to run eval", default=20) 68 | parser.add_argument("--evalskip", type=int, help="Skip eval of these epochs", default=50) 69 | parser.add_argument("--eval_split", type=float, help="Percentage of dataset to use for eval", default=0.0) 70 | parser.add_argument("--kmer", default=4) 71 | parser.add_argument("--rawfeatures", help="Use raw features", action="store_true") 72 | parser.add_argument("--clusteringloss", help="Train with clustering loss", action="store_true") 73 | parser.add_argument("--targetmetric", help="Metric to pick best epoch", default="hq") 74 | parser.add_argument("--concatfeatures", help="Concat learned and original features before clustering", 75 | action="store_true") 76 | parser.add_argument("--no_loss_weights", action="store_false", help="Using edge weights for loss (positive only)") 77 | parser.add_argument("--no_sample_weights", action="store_false", help="Using edge weights to sample negatives") 78 | parser.add_argument( 79 | "--early_stopping", 80 | type=float, 81 | help="Stop training if delta between last two losses is less than this", 82 | default="0.1", 83 | ) 84 | parser.add_argument("--nruns", type=int, help="Number of runs", default=1) 85 | # data processing 86 | parser.add_argument("--mincontig", type=int, help="Minimum size of input contigs", default=1000) 87 | parser.add_argument("--minbin", type=int, help="Minimum size of clusters in bp", default=200000) 88 | parser.add_argument("--mincomp", type=int, help="Minimum size of connected components", default=1) 89 | parser.add_argument("--randomize", help="Randomize graph", action="store_true") 90 | parser.add_argument("--labelgraph", help="Create graph based on labels (ignore assembly graph)", action="store_true") 91 | parser.add_argument("--binarize", help="Binarize adj matrix", action="store_true") 92 | parser.add_argument("--noedges", help="Remove all but self edges from adj matrix", action="store_true") 93 | parser.add_argument("--read_embs", help="Read embeddings from file", action="store_true") 94 | parser.add_argument("--reload", help="Reload data", action="store_true") 95 | 96 | parser.add_argument("--markers", type=str, help="""File with precomputed checkm results to eval. 
97 | If the file is not found, GraphMB assumes no marker file is available.""", 98 | default="marker_gene_stats.tsv") 99 | parser.add_argument("--post", help="Output options", default="writeembs_contig2bin") 100 | parser.add_argument("--writebins", help="Write bins to fasta files", action="store_true") 101 | parser.add_argument("--skip_preclustering", help="Use precomputed checkm results to eval", action="store_true") 102 | parser.add_argument("--outname", "--outputname", help="Output (experiment) name", default="graphmb") 103 | parser.add_argument("--cuda", help="Use GPU", action="store_true") 104 | parser.add_argument("--noise", help="Use noise generator", action="store_true") 105 | parser.add_argument("--savemodel", help="Save best model to disk", action="store_true") 106 | parser.add_argument("--tsne", help="Plot t-SNE at checkpoints", action="store_true") 107 | parser.add_argument("--numcores", help="Number of cores to use", default=1, type=int) 108 | parser.add_argument( 109 | "--outdir", "--outputdir", help="Output dir (same as input assembly dir if not defined)", default=None 110 | ) 111 | parser.add_argument("--assembly_type", help="flye or spades", default="flye") 112 | parser.add_argument("--contignodes", help="Use contigs as nodes instead of edges", action="store_true") 113 | parser.add_argument("--seed", help="Set seed", default=1, type=int) 114 | parser.add_argument("--quiet", "-q", help="Do not output epoch progress", action="store_true") 115 | parser.add_argument("--read_cache", help="Do not check assembly files, read cached files only", action="store_true") 116 | parser.add_argument("--version", "-v", help="Print version and exit", action="store_true") 117 | parser.add_argument("--loglevel", "-l", help="Log level", default="info") 118 | return parser 119 | -------------------------------------------------------------------------------- /src/graphmb/dgl_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | os.environ["DGLBACKEND"] = "pytorch" 4 | import dgl 5 | from dgl.data import DGLDataset 6 | import networkx as nx 7 | 8 | class DGLAssemblyDataset(DGLDataset): 9 | def __init__(self, assembly): 10 | self.assembly = assembly 11 | self.logger = assembly.logger 12 | super().__init__(name=assembly.name + "assembly_graph", save_dir=assembly.cache_dir, force_reload=False) 13 | 14 | def __getitem__(self, i): 15 | return self.graph 16 | 17 | def __len__(self): 18 | return 1 19 | 20 | def process(self, root=False): 21 | """Build a DGL graph from the assembly object""" 22 | # TODO: this should receive an assembly dataset object and initialize self.graph 23 | # if root: 24 | # root_node = G.add_node("root", length=0) 25 | # TODO: skip unconnected and too short/big 26 | self.logger.info("creating DGL graph") 27 | self.graph = dgl.graph( 28 | (self.assembly.edges_src, self.assembly.edges_dst), 29 | num_nodes=len(self.assembly.node_names), 30 | ) 31 | self.graph.edata["weight"] = torch.tensor(self.assembly.edge_weights) 32 | 33 | self.logger.info("done") 34 | self.graph.ndata["label"] = torch.LongTensor( 35 | [self.assembly.labels.index(self.assembly.node_to_label[n]) if self.assembly.node_to_label[n] in self.assembly.labels else 0 for n in self.assembly.node_names] 36 | ) 37 | 38 | nx_graph = self.graph.to_networkx().to_undirected() 39 | self.logger.info("connected components...") 40 | # self.connected = [c for c in sorted(nx.connected_components(nx_graph), key=len, reverse=True) if len(c) > 1] 41 | # breakpoint() 42 | self.connected = [c for c in
sorted(nx.connected_components(nx_graph), key=len, reverse=True) if len(c) > 0] 43 | self.logger.info((len(self.connected), "connected")) 44 | # for group in self.connected: 45 | # self.graphs.append( 46 | # dgl.node_subgraph(self.graph, [self.node_names.index(c) for c in group if c in self.node_names]) 47 | # ) 48 | 49 | assert len([c for comp in self.connected for c in comp]) <= len(self.assembly.node_names) 50 | 51 | # self.set_node_mask() 52 | 53 | def save(self): 54 | pass 55 | # save graphs and labels 56 | # save other information in python dict 57 | """info_path = os.path.join(self.save_path, "cache.pkl") 58 | print("saving graph", info_path) 59 | if not os.path.exists(self.save_path): 60 | os.makedirs(self.save_path) 61 | save_info(info_path, vars(self))""" 62 | 63 | def load(self): 64 | pass 65 | # load processed data from directory `self.save_path` 66 | """ info_path = os.path.join(self.save_path, "cache.pkl") 67 | print("loading from", info_path) 68 | loaded_info = load_info(info_path) 69 | for key in loaded_info: 70 | setattr(self, key, loaded_info[key])""" 71 | 72 | def has_cache(self): 73 | """# check whether there are processed data in `self.save_path` 74 | info_path = os.path.join(self.save_path, "cache.pkl") 75 | return os.path.exists(info_path)""" 76 | return False 77 | 78 | 79 | class ContigsDataset(DGLDataset): 80 | def __init__( 81 | self, 82 | name, 83 | assembly_path=None, 84 | assembly_name="assembly.fasta", 85 | graph_file="assembly_graph.gfa", 86 | labels=None, 87 | save_dir=None, 88 | force_reload=False, 89 | min_contig=1000, 90 | kmer=4, 91 | depth=None, 92 | markers=None, 93 | load_kmer=False, 94 | assembly_type="flye", 95 | ): 96 | self.mode = "train" 97 | # self.save_dir = save_dir 98 | self.assembly = assembly_path 99 | if self.assembly is None: 100 | self.assembly = "" 101 | self.readmapping = assembly_path 102 | self.assembly_name = assembly_name 103 | self.graph_file = graph_file 104 | self.depth = depth 105 | self.markers = markers 106 | self.contig_names = [] 107 | self.contig_seqs = {} 108 | self.read_names = [] 109 | self.node_names = [] # contig_names + read_names 110 | self.nodes_len = [] 111 | self.nodes_depths = [] 112 | self.nodes_markers = [] 113 | self.nodes_kmer = [] 114 | self.graphs = [] 115 | self.edges_src = [] 116 | self.edges_dst = [] 117 | self.edges_weight = [] 118 | self.nodes_data = [] 119 | self.node_to_label = {} 120 | self.node_labels = [] 121 | self.kmer = kmer 122 | self.load_kmer = load_kmer 123 | self.assembly_type = assembly_type 124 | if self.load_kmer: 125 | self.kmer_to_ids, self.canonical_k = get_kmer_to_id(self.kmer) 126 | 127 | self.connected = [] 128 | self.min_contig_len = min_contig 129 | if labels is None: 130 | self.species = ["NA"] 131 | self.add_new_species = True 132 | else: 133 | self.species = labels 134 | self.add_new_species = False 135 | super().__init__(name=name, save_dir=save_dir, force_reload=force_reload) 136 | 137 | def filter_edges(self, weight=0): 138 | """Filter edges based on weight""" 139 | # print(max(self.edges_weight), min(self.edges_weight)) 140 | if weight < 0: 141 | weight = sum(self.edges_weight) / len(self.edges_weight) 142 | # for i in range(max(self.edges_weight)): 143 | # print(i, self.edges_weight.count(i)) 144 | keep_idx = self.graph.edata["weight"] >= weight 145 | idx_to_remove = [i for i, x in enumerate(keep_idx) if not x] 146 | self.edges_src = [self.edges_src[i] for i in keep_idx] 147 | self.edges_dst = [self.edges_dst[i] for i in keep_idx] 148 | # self.graph.edata["weight"] = 
self.graph.edata["weight"][keep_idx] 149 | self.graph.remove_edges(idx_to_remove) 150 | 151 | def filter_contigs(self): 152 | # remove disconnected 153 | keep_idx = [i for i, c in enumerate(self.contig_names) if c in self.edges_src or c in self.edges_dst] 154 | # keep_idx = [i for i, c in enumerate(self.contig_names)] 155 | self.contig_names = [self.contig_names[i] for i in keep_idx] 156 | self.nodes_kmer = [self.nodes_kmer[i] for i in keep_idx] 157 | self.nodes_depths = [self.nodes_depths[i] for i in keep_idx] 158 | self.nodes_len = [self.nodes_len[i] for i in keep_idx] 159 | 160 | def remove_nodes(self, remove_list): 161 | self.graph.remove_nodes(torch.tensor(remove_list)) 162 | # self.contig_seqs = {} 163 | self.node_names = [self.node_names[i] for i in range(len(self.node_names)) if i not in remove_list] 164 | self.nodes_len = [self.nodes_len[i] for i in range(len(self.nodes_len)) if i not in remove_list] 165 | self.nodes_depths = [self.nodes_depths[i] for i in range(len(self.nodes_depths)) if i not in remove_list] 166 | self.nodes_markers = [self.nodes_markers[i] for i in range(len(self.nodes_markers)) if i not in remove_list] 167 | self.nodes_kmer = [self.nodes_kmer[i] for i in range(len(self.nodes_kmer)) if i not in remove_list] 168 | # self.edges_src = [] 169 | # self.edges_dst = [] 170 | # self.edges_weight = [] 171 | self.nodes_data = [self.nodes_data[i] for i in range(len(self.nodes_data)) if i not in remove_list] 172 | # self.node_to_label = {} 173 | self.node_labels = [self.node_labels[i] for i in range(len(self.node_labels)) if i not in remove_list] 174 | 175 | def rename_nodes_to_index(self): 176 | # self.edges_src = [self.contig_names.index(i) for i in self.edges_src if i in self.contig_names] 177 | # self.edges_dst = [self.contig_names.index(i) for i in self.edges_dst if i in self.contig_names] 178 | edge_name_to_index = {n: i for i, n in enumerate(self.contig_names)} 179 | self.edges_src = [edge_name_to_index[n] for n in self.edges_src] 180 | self.edges_dst = [edge_name_to_index[n] for n in self.edges_dst] 181 | 182 | def set_node_mask(self): 183 | """Set contig nodes""" 184 | self.graph.ndata["contigs"] = torch.zeros(len(self.node_names), dtype=torch.bool) 185 | self.graph.ndata["contigs"][: len(self.contig_names)] = True 186 | 187 | def get_labels_from_reads(self, reads, add_new=True): 188 | contig_to_species = {} 189 | read_to_species = {} 190 | # contig_lens = {} 191 | for r in reads: 192 | # print(r) 193 | speciesname = r.split("_reads")[0] 194 | 195 | if speciesname not in self.species: 196 | if add_new: 197 | self.species.append(speciesname) 198 | else: 199 | continue 200 | for m in reads[r].mappings: 201 | if m.contigname not in contig_to_species: # and m.contigname in contig_lens: 202 | contig_to_species[m.contigname] = {} 203 | if speciesname not in contig_to_species[m.contigname]: 204 | contig_to_species[m.contigname][speciesname] = 0 205 | # contig_to_species[m.contigname][speciesname] += m.mapq # weight mapping by quality 206 | contig_to_species[m.contigname][speciesname] += 1 # weight mapping by len 207 | read_to_species[r] = speciesname 208 | # reads_count = int(values[0]) 209 | # contig_to_species[values[1]][speciesname] = reads_count 210 | return contig_to_species, read_to_species 211 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Sphinx 
JavaScript utilities for all documentation. 6 | * 7 | * :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | /** 18 | * make the code below compatible with browsers without 19 | * an installed firebug like debugger 20 | if (!window.console || !console.firebug) { 21 | var names = ["log", "debug", "info", "warn", "error", "assert", "dir", 22 | "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", 23 | "profile", "profileEnd"]; 24 | window.console = {}; 25 | for (var i = 0; i < names.length; ++i) 26 | window.console[names[i]] = function() {}; 27 | } 28 | */ 29 | 30 | /** 31 | * small helper function to urldecode strings 32 | * 33 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL 34 | */ 35 | jQuery.urldecode = function(x) { 36 | if (!x) { 37 | return x 38 | } 39 | return decodeURIComponent(x.replace(/\+/g, ' ')); 40 | }; 41 | 42 | /** 43 | * small helper function to urlencode strings 44 | */ 45 | jQuery.urlencode = encodeURIComponent; 46 | 47 | /** 48 | * This function returns the parsed url parameters of the 49 | * current request. Multiple values per key are supported, 50 | * it will always return arrays of strings for the value parts. 51 | */ 52 | jQuery.getQueryParameters = function(s) { 53 | if (typeof s === 'undefined') 54 | s = document.location.search; 55 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 56 | var result = {}; 57 | for (var i = 0; i < parts.length; i++) { 58 | var tmp = parts[i].split('=', 2); 59 | var key = jQuery.urldecode(tmp[0]); 60 | var value = jQuery.urldecode(tmp[1]); 61 | if (key in result) 62 | result[key].push(value); 63 | else 64 | result[key] = [value]; 65 | } 66 | return result; 67 | }; 68 | 69 | /** 70 | * highlight a given string on a jquery object by wrapping it in 71 | * span elements with the given class name. 
72 | */ 73 | jQuery.fn.highlightText = function(text, className) { 74 | function highlight(node, addItems) { 75 | if (node.nodeType === 3) { 76 | var val = node.nodeValue; 77 | var pos = val.toLowerCase().indexOf(text); 78 | if (pos >= 0 && 79 | !jQuery(node.parentNode).hasClass(className) && 80 | !jQuery(node.parentNode).hasClass("nohighlight")) { 81 | var span; 82 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 83 | if (isInSVG) { 84 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 85 | } else { 86 | span = document.createElement("span"); 87 | span.className = className; 88 | } 89 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 90 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 91 | document.createTextNode(val.substr(pos + text.length)), 92 | node.nextSibling)); 93 | node.nodeValue = val.substr(0, pos); 94 | if (isInSVG) { 95 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 96 | var bbox = node.parentElement.getBBox(); 97 | rect.x.baseVal.value = bbox.x; 98 | rect.y.baseVal.value = bbox.y; 99 | rect.width.baseVal.value = bbox.width; 100 | rect.height.baseVal.value = bbox.height; 101 | rect.setAttribute('class', className); 102 | addItems.push({ 103 | "parent": node.parentNode, 104 | "target": rect}); 105 | } 106 | } 107 | } 108 | else if (!jQuery(node).is("button, select, textarea")) { 109 | jQuery.each(node.childNodes, function() { 110 | highlight(this, addItems); 111 | }); 112 | } 113 | } 114 | var addItems = []; 115 | var result = this.each(function() { 116 | highlight(this, addItems); 117 | }); 118 | for (var i = 0; i < addItems.length; ++i) { 119 | jQuery(addItems[i].parent).before(addItems[i].target); 120 | } 121 | return result; 122 | }; 123 | 124 | /* 125 | * backward compatibility for jQuery.browser 126 | * This will be supported until firefox bug is fixed. 127 | */ 128 | if (!jQuery.browser) { 129 | jQuery.uaMatch = function(ua) { 130 | ua = ua.toLowerCase(); 131 | 132 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 133 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 134 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 135 | /(msie) ([\w.]+)/.exec(ua) || 136 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 137 | []; 138 | 139 | return { 140 | browser: match[ 1 ] || "", 141 | version: match[ 2 ] || "0" 142 | }; 143 | }; 144 | jQuery.browser = {}; 145 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 146 | } 147 | 148 | /** 149 | * Small JavaScript module for the documentation. 150 | */ 151 | var Documentation = { 152 | 153 | init : function() { 154 | this.fixFirefoxAnchorBug(); 155 | this.highlightSearchWords(); 156 | this.initIndexTable(); 157 | if (DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) { 158 | this.initOnKeyListeners(); 159 | } 160 | }, 161 | 162 | /** 163 | * i18n support 164 | */ 165 | TRANSLATIONS : {}, 166 | PLURAL_EXPR : function(n) { return n === 1 ? 0 : 1; }, 167 | LOCALE : 'unknown', 168 | 169 | // gettext and ngettext don't access this so that the functions 170 | // can safely bound to a different name (_ = Documentation.gettext) 171 | gettext : function(string) { 172 | var translated = Documentation.TRANSLATIONS[string]; 173 | if (typeof translated === 'undefined') 174 | return string; 175 | return (typeof translated === 'string') ? 
translated : translated[0]; 176 | }, 177 | 178 | ngettext : function(singular, plural, n) { 179 | var translated = Documentation.TRANSLATIONS[singular]; 180 | if (typeof translated === 'undefined') 181 | return (n == 1) ? singular : plural; 182 | return translated[Documentation.PLURALEXPR(n)]; 183 | }, 184 | 185 | addTranslations : function(catalog) { 186 | for (var key in catalog.messages) 187 | this.TRANSLATIONS[key] = catalog.messages[key]; 188 | this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')'); 189 | this.LOCALE = catalog.locale; 190 | }, 191 | 192 | /** 193 | * add context elements like header anchor links 194 | */ 195 | addContextElements : function() { 196 | $('div[id] > :header:first').each(function() { 197 | $('\u00B6'). 198 | attr('href', '#' + this.id). 199 | attr('title', _('Permalink to this headline')). 200 | appendTo(this); 201 | }); 202 | $('dt[id]').each(function() { 203 | $('\u00B6'). 204 | attr('href', '#' + this.id). 205 | attr('title', _('Permalink to this definition')). 206 | appendTo(this); 207 | }); 208 | }, 209 | 210 | /** 211 | * workaround a firefox stupidity 212 | * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075 213 | */ 214 | fixFirefoxAnchorBug : function() { 215 | if (document.location.hash && $.browser.mozilla) 216 | window.setTimeout(function() { 217 | document.location.href += ''; 218 | }, 10); 219 | }, 220 | 221 | /** 222 | * highlight the search words provided in the url in the text 223 | */ 224 | highlightSearchWords : function() { 225 | var params = $.getQueryParameters(); 226 | var terms = (params.highlight) ? params.highlight[0].split(/\s+/) : []; 227 | if (terms.length) { 228 | var body = $('div.body'); 229 | if (!body.length) { 230 | body = $('body'); 231 | } 232 | window.setTimeout(function() { 233 | $.each(terms, function() { 234 | body.highlightText(this.toLowerCase(), 'highlighted'); 235 | }); 236 | }, 10); 237 | $('') 239 | .appendTo($('#searchbox')); 240 | } 241 | }, 242 | 243 | /** 244 | * init the domain index toggle buttons 245 | */ 246 | initIndexTable : function() { 247 | var togglers = $('img.toggler').click(function() { 248 | var src = $(this).attr('src'); 249 | var idnum = $(this).attr('id').substr(7); 250 | $('tr.cg-' + idnum).toggle(); 251 | if (src.substr(-9) === 'minus.png') 252 | $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); 253 | else 254 | $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); 255 | }).css('display', ''); 256 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { 257 | togglers.click(); 258 | } 259 | }, 260 | 261 | /** 262 | * helper function to hide the search marks again 263 | */ 264 | hideSearchWords : function() { 265 | $('#searchbox .highlight-link').fadeOut(300); 266 | $('span.highlighted').removeClass('highlighted'); 267 | }, 268 | 269 | /** 270 | * make the url absolute 271 | */ 272 | makeURL : function(relativeURL) { 273 | return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; 274 | }, 275 | 276 | /** 277 | * get the current relative url 278 | */ 279 | getCurrentURL : function() { 280 | var path = document.location.pathname; 281 | var parts = path.split(/\//); 282 | $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { 283 | if (this === '..') 284 | parts.pop(); 285 | }); 286 | var url = parts.join('/'); 287 | return path.substring(url.lastIndexOf('/') + 1, path.length - 1); 288 | }, 289 | 290 | initOnKeyListeners: function() { 291 | $(document).keydown(function(event) { 292 | var activeElementType = document.activeElement.tagName; 
293 |       // don't navigate when in search box, textarea, dropdown or button
294 |       if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT'
295 |           && activeElementType !== 'BUTTON' && !event.altKey && !event.ctrlKey && !event.metaKey
296 |           && !event.shiftKey) {
297 |         switch (event.keyCode) {
298 |           case 37: // left
299 |             var prevHref = $('link[rel="prev"]').prop('href');
300 |             if (prevHref) {
301 |               window.location.href = prevHref;
302 |               return false;
303 |             }
304 |             break;
305 |           case 39: // right
306 |             var nextHref = $('link[rel="next"]').prop('href');
307 |             if (nextHref) {
308 |               window.location.href = nextHref;
309 |               return false;
310 |             }
311 |             break;
312 |         }
313 |       }
314 |     });
315 |   }
316 | };
317 | 
318 | // quick alias for translations
319 | _ = Documentation.gettext;
320 | 
321 | $(document).ready(function() {
322 |   Documentation.init();
323 | });
324 | 
--------------------------------------------------------------------------------
/src/graphmb/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import time
3 | import os
4 | import sys
5 | import math
6 | import pdb
7 | import itertools
8 | from collections import Counter
9 | import networkx as nx
10 | import numpy as np
11 | from tqdm import tqdm
12 | import datetime
13 | import operator
14 | import scipy
15 | import logging
16 | # import dgl
17 | import random
18 | from sklearn.cluster import KMeans, MiniBatchKMeans  # MiniBatchKMeans is used by the "kmeansbatch" backend below
19 | #import tensorflow as tf
20 | 
21 | SEED = 0
22 | 
23 | def set_seed(seed=0):
24 |     if "dgl" in sys.modules:
25 |         import dgl
26 |         print("setting dgl seed")
27 |         dgl.random.seed(seed)
28 |     if "torch" in sys.modules:
29 |         import torch
30 |         print("setting torch seed")
31 |         torch.manual_seed(seed)
32 |     random.seed(seed)
33 |     np.random.seed(seed)
34 |     if "tensorflow" in sys.modules:
35 |         import tensorflow
36 |         print("setting tf seed")
37 |         tensorflow.random.set_seed(seed)
38 | 
39 | 
40 | 
41 | class Read:
42 |     def __init__(self, readid, species=None):
43 |         self.readid = readid
44 |         self.species = species
45 |         self.mappings = set()
46 | 
47 | 
48 | class ReadMapping:
49 |     def __init__(self, readid, bitflag, contigname, pos, mapq, seq):  # seq is accepted but not stored
50 |         self.readid = readid
51 |         self.bitflag = bitflag
52 |         self.contigname = contigname
53 |         self.pos = pos
54 |         self.mapq = mapq
55 | 
56 | 
57 | def get_cluster_mask(quick, dataset):
58 |     if quick and dataset.contig_markers is not None:
59 |         #connected_marker_nodes = filter_disconnected(dataset.adj_matrix, dataset.node_names, dataset.contig_markers)
60 |         nodes_with_markers = [
61 |             i
62 |             for i, n in enumerate(dataset.node_names)
63 |             if n in dataset.contig_markers and len(dataset.contig_markers[n]) > 0
64 |         ]
65 |         print("eval cluster with ", len(nodes_with_markers), "contigs with markers")
66 |         cluster_mask = [n in nodes_with_markers for n in range(len(dataset.node_names))]
67 |     else:
68 |         cluster_mask = [True] * len(dataset.node_names)
69 |     return cluster_mask
70 | 
71 | def save_model(args, epoch, th, th_vae):
72 |     if th_vae is not None:
73 |         # save encoder and decoder
74 |         th_vae.encoder.save(os.path.join(args.outdir, args.outname + "_best_encoder"))
75 |         th_vae.decoder.save(os.path.join(args.outdir, args.outname + "_best_decoder"))
76 |     if th is not None:
77 |         th.gnn_model.save(os.path.join(args.outdir, args.outname + "_best_gnn"))
78 | 
79 | 
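# --- Usage sketch (annotation, not part of the module) ---
# The contract of run_clustering below, shown with the "kmeans" backend on
# toy data. Shapes and return values follow the code; the data is made up.
#
#     import numpy as np
#     X = np.random.rand(10, 32).astype(np.float32)   # 10 contigs, 32-dim embeddings
#     names = [f"contig_{i}" for i in range(10)]
#     c2c, c2b, labels, cents = run_clustering(X, names, "kmeans", cuda=False, k=3)
#     # c2c: cluster id -> list of contig names
#     # c2b: contig name -> cluster id
#     # labels: np.ndarray of cluster ids aligned with `names`
#     # cents: None for this backend (centroids only computed for vamb+tsne)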
80 | def run_clustering(X, node_names, clustering_algo, cuda, k=0, tsne=False):
81 | 
82 |     if clustering_algo == "vamb":
83 |         from graphmb.vamb_clustering import cluster as vamb_cluster
84 |         starttime = datetime.datetime.now()
85 |         X = X.astype(np.float32)
86 |         cluster_to_contig = {
87 |             i: c for (i, (n, c)) in enumerate(vamb_cluster(X, node_names, cuda=cuda))
88 |         }
89 |         clustering_time = datetime.datetime.now()
90 |         #print("clustering time", clustering_time-starttime)
91 |         contig_to_bin = {}
92 |         #for b in cluster_to_contig:
93 |         #    for contig in cluster_to_contig[b]:
94 |         #        contig_to_bin[contig] = b
95 |         for cluster_id, contigs in cluster_to_contig.items():  # avoid shadowing the k parameter
96 |             contig_to_bin.update({n: cluster_id for n in contigs})
97 |         labels = np.array([contig_to_bin[n] for n in node_names])
98 |         # very slow code:
99 |         cluster_centroids = None
100 |         if tsne:
101 |             cluster_to_embs = {
102 |                 c: np.array([X[i] for i, n in enumerate(node_names) if n in cluster_to_contig[c]])
103 |                 for c in cluster_to_contig
104 |             }
105 |             cluster_centroids = np.array([cluster_to_embs[c].mean(0) for c in cluster_to_contig])
106 |         processing_time = datetime.datetime.now()
107 |         #print("processing time", processing_time - clustering_time)
108 |     elif clustering_algo == "kmeansbatch":
109 |         kmeans = MiniBatchKMeans(n_clusters=k, random_state=0, batch_size=2048, verbose=0) #, init=seed_matrix)
110 |         labels = kmeans.fit_predict(X)
111 |         contig_to_bin = {node_names[i]: labels[i] for i in range(len(node_names))}
112 |         cluster_to_contig = {i: [] for i in range(k)}
113 |         for i in range(len(node_names)):
114 |             cluster_to_contig[labels[i]].append(node_names[i])
115 |         #cluster_centroids = kmeans.cluster_centers_
116 |     elif clustering_algo == "kmeansgpu":
117 |         pass  # not implemented
118 |     elif clustering_algo == "kmedoids":
119 |         import kmedoids
120 |         # breakpoint()
121 |         # TODO do this on gpu if avail
122 |         D = np.sum((X[:,None]-X[None])**2, axis=-1)
123 |         # TODO find best k
124 |         km = kmedoids.KMedoids(20, method='fasterpam')
125 |         cluster_labels = km.fit_predict(D).astype(np.int64)  # NOTE: this branch does not yet set labels/contig_to_bin
126 |     elif clustering_algo == "kmeans":
127 |         clf = KMeans(k, random_state=1234)
128 |         labels = clf.fit_predict(X)
129 |         contig_to_bin = {node_names[i]: labels[i] for i in range(len(node_names))}
130 |         cluster_to_contig = {i: [] for i in range(k)}
131 |         for i in range(len(node_names)):
132 |             cluster_to_contig[labels[i]].append(node_names[i])
133 |         cluster_centroids = None
134 |     return cluster_to_contig, contig_to_bin, labels, cluster_centroids
135 | 
136 | def filter_disconnected(adj, node_names, markers):
137 |     # get idx of nodes that are connected or have at least one marker
138 |     graph = nx.convert_matrix.from_scipy_sparse_matrix(adj, edge_attribute="weight")
139 |     # breakpoint()
140 |     nodes_to_remove = set()
141 |     for n1 in graph.nodes:
142 |         if len(list(graph.neighbors(n1))) == 0 or (
143 |             node_names[n1] not in markers or len(markers[node_names[n1]]) == 0
144 |         ):
145 |             nodes_to_remove.add(n1)
146 | 
147 |     graph.remove_nodes_from(list(nodes_to_remove))
148 |     assert len(graph.nodes()) == (len(node_names)-len(nodes_to_remove))
149 |     print(f"{len(nodes_to_remove)} nodes without edges nor markers, keeping {len(graph.nodes())} nodes")
150 |     return set(graph.nodes())
151 | 
152 | def run_model_vgae(dataset, args, logger, nrun):
153 |     node_names = np.array(dataset.node_names)
154 |     RESULT_EVERY = args.evalepochs
155 |     hidden_gnn = args.hidden_gnn
156 |     hidden_vae = args.hidden_vae
157 |     output_dim_gnn = args.embsize_gnn
158 |     output_dim_vae = args.embsize_vae
159 |     epochs = args.epoch
160 |     lr_vae = args.lr_vae
161 |     lr_gnn = args.lr_gnn
162 |     nlayers_gnn = args.layers_gnn
163 |     clustering = args.clusteringalgo
164 |     k = args.kclusters
165 |     use_edge_weights = True
166 | cluster_markers_only = args.quick 167 | decay = 0.5 ** (2.0 / epochs) 168 | concat_features = args.concat_features 169 | 170 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 171 | train_log_dir = os.path.join(args.outdir, 'logs/' + args.outname + current_time + '/train') 172 | summary_writer = tf.summary.create_file_writer(train_log_dir) 173 | print("logging to tensorboard") 174 | tb_handler = TensorboardLogger(summary_writer, runname=args.outname + current_time) 175 | logger.addHandler(tb_handler) 176 | #tf.summary.trace_on(graph=True) 177 | 178 | logger.info("******* Running model: VGAE **********") 179 | logger.info("***** using edge weights: {} ******".format(use_edge_weights)) 180 | logger.info("***** concat features: {} *****".format(concat_features)) 181 | logger.info("***** cluster markers only: {} *****".format(cluster_markers_only)) 182 | logger.info("***** threshold adj matrix: {} *****".format(args.binarize)) 183 | logger.info("***** self edges only: {} *****".format(args.noedges)) 184 | logger.info("***** Using raw kmer+abund features: {}".format(args.rawfeatures)) 185 | tf.config.experimental_run_functions_eagerly(True) 186 | 187 | 188 | X, adj, cluster_mask, neg_pair_idx, pos_pair_idx = prepare_data_for_gnn( 189 | dataset, use_edge_weights, cluster_markers_only, use_raw=True, 190 | binarize=args.binarize, remove_edges=args.noedges) 191 | logger.info("***** SCG neg pairs: {}".format(neg_pair_idx.shape)) 192 | logger.info("***** input features dimension: {}".format(X[cluster_mask].shape)) 193 | # pre train clustering 194 | if not args.skip_preclustering: 195 | cluster_labels, stats, _, hq_bins = compute_clusters_and_stats( 196 | X[cluster_mask], node_names[cluster_mask], 197 | dataset, clustering=clustering, k=k, tsne=args.tsne, 198 | amber=(args.labels is not None and "amber" in args.labels), 199 | unresolved=True, cuda=args.cuda, 200 | ) 201 | logger.info(f">>> Pre train stats: {str(stats)}") 202 | 203 | 204 | model = VGAE(X.shape, hidden_dim1=hidden_gnn, hidden_dim2=output_dim_gnn, dropout=0.1, 205 | l2_reg=1e-5, embeddings=X, freeze_embeddings=True, lr=lr_gnn) 206 | X_train = np.arange(len(X))[:,None].astype(np.int64) 207 | A_train = tf.sparse.to_dense(adj) 208 | labels = dataset.adj_matrix.toarray() 209 | pos_weight = (adj.shape[0] * adj.shape[0] - tf.sparse.reduce_sum(adj)) / tf.sparse.reduce_sum(adj) 210 | 211 | norm = adj.shape[0] * adj.shape[0] / ((adj.shape[0] * adj.shape[0] - tf.sparse.reduce_sum(adj)) * 2) 212 | 213 | pbar_epoch = tqdm(range(epochs), disable=args.quiet, position=0) 214 | decay = 0.5**(2./10000) 215 | scores = [] 216 | best_hq = 0 217 | batch_size = args.batchsize 218 | if batch_size == 0: 219 | batch_size = adj.shape[0] 220 | train_idx = list(range(adj.shape[0])) 221 | for e in pbar_epoch: 222 | np.random.shuffle(train_idx) 223 | n_batches = len(train_idx)//batch_size 224 | pbar_vaebatch = tqdm(range(n_batches), disable=(args.quiet or batch_size == len(train_idx) or n_batches < 100), position=1, ascii=' =') 225 | loss = 0 226 | for b in pbar_vaebatch: 227 | batch_idx = train_idx[b*batch_size:(b+1)*batch_size] 228 | loss += model.train_step(X_train, A_train, labels, pos_weight, norm, batch_idx) 229 | pbar_epoch.set_description(f'{loss:.3f}') 230 | model.optimizer.learning_rate = model.optimizer.learning_rate*decay 231 | gpu_mem_alloc = tf.config.experimental.get_memory_usage('GPU:0') / 1000000 if args.cuda else 0 232 | if (e + 1) % RESULT_EVERY == 0: # and e >= int(epochs/2): 233 | _, embs, _, _, _ = model((X_train, A_train), 
training=False) 234 | node_new_features = embs.numpy() 235 | 236 | best_hq, best_embs, best_epoch, scores, cluster_labels = eval_epoch(logger, summary_writer, node_new_features, 237 | cluster_mask, e, args, dataset, e, scores, 238 | best_hq, best_embs, best_epoch) 239 | 240 | if args.quiet: 241 | logger.info(f"--- EPOCH {e:d} ---") 242 | logger.info(f"[VGAE {nlayers_gnn}l] L={loss:.3f} HQ={stats['hq']} BestHQ={best_hq} Best Epoch={best_epoch} Max GPU MB={gpu_mem_alloc:.1f}") 243 | logger.info(str(stats)) 244 | 245 | 246 | _, embs, _, _, _ = model((X_train, A_train), training=False) 247 | embs = embs.numpy() 248 | 249 | 250 | -------------------------------------------------------------------------------- /src/graphmb/train_gnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datetime 3 | import os 4 | import tensorflow as tf 5 | import random 6 | import logging 7 | from tqdm import tqdm 8 | import mlflow 9 | from graphmb.models import TH 10 | from graphmb.utils import set_seed # , run_tsne, plot_embs, plot_edges_sim 11 | from graphmb.train_ccvae import prepare_data_for_gnn 12 | from graphmb.visualize import plot_edges_sim 13 | from graphmb.evaluate import compute_clusters_and_stats, eval_epoch 14 | from graphmb.utils import get_cluster_mask 15 | from graphmb.gnn_models import name_to_model 16 | 17 | 18 | def run_model_gnn(dataset, args, logger, nrun, target_metric): 19 | set_seed(args.seed) 20 | node_names = np.array(dataset.node_names) 21 | RESULT_EVERY = args.evalepochs 22 | hidden_gnn = args.hidden_gnn 23 | output_dim_gnn = args.embsize_gnn 24 | epochs = args.epoch 25 | lr_gnn = args.lr_gnn 26 | nlayers_gnn = args.layers_gnn 27 | gname = args.model_name 28 | gmodel_type = name_to_model[gname.split("_")[0].upper()] 29 | clustering = args.clusteringalgo 30 | k = args.kclusters 31 | use_disconnected = not args.quick 32 | cluster_markers_only = args.quick 33 | use_edge_weights = True 34 | concat_features = True # bypass args, otherwise results are bad 35 | 36 | with mlflow.start_run(run_name=args.assembly.split("/")[-1] + "-" + args.outname): 37 | mlflow.log_params(vars(args)) 38 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 39 | tf.config.run_functions_eagerly(True) 40 | 41 | X, adj, cluster_mask, neg_pair_idx, pos_pair_idx = prepare_data_for_gnn( 42 | dataset, use_edge_weights, cluster_markers_only, use_raw=args.rawfeatures, 43 | binarize=args.binarize, remove_edges=args.noedges) 44 | if nrun == 0: 45 | print("logging to mlflow") 46 | logger.info("******* Running model: {} **********".format(gname)) 47 | logger.info("***** using edge weights: {} ******".format(use_edge_weights)) 48 | logger.info("***** using disconnected: {} ******".format(use_disconnected)) 49 | logger.info("***** concat features: {} *****".format(concat_features)) 50 | logger.info("***** cluster markers only: {} *****".format(cluster_markers_only)) 51 | logger.info("***** threshold adj matrix: {} *****".format(args.binarize)) 52 | logger.info("***** self edges only: {} *****".format(args.noedges)) 53 | logger.info("***** Using raw kmer+abund features: {}".format(args.rawfeatures)) 54 | 55 | logger.info("***** SCG neg pairs: {}".format(neg_pair_idx.shape)) 56 | logger.info("***** input features dimension: {}".format(X[cluster_mask].shape[1])) 57 | logger.info("***** Nodes used for clustering: {}".format(X[cluster_mask].shape[0])) 58 | 59 | #plot edges vs initial embs 60 | #id_to_scg = {i: set(dataset.contig_markers.get(node_name, 
{}).keys()) for i, node_name in enumerate(dataset.node_names)} 61 | #plot_edges_sim(X, dataset.adj_matrix, id_to_scg, f"{args.outdir}/{args.outname}_pretrain_") 62 | 63 | # pre train clustering 64 | if not args.skip_preclustering and nrun == 0: 65 | cluster_labels, stats, _, hq_bins = compute_clusters_and_stats( 66 | X[cluster_mask], node_names[cluster_mask], 67 | dataset, clustering=clustering, k=k, 68 | amber=(args.labels is not None and "amber" in args.labels), 69 | cuda=args.cuda, 70 | ) 71 | logger.info(f">>> Pre train stats: {str(stats)}") 72 | else: 73 | stats = {"hq": 0, "epoch":0, target_metric: 0} 74 | 75 | scores = [stats] 76 | losses = {"total": [], "ae": [], "gnn": [], "scg": []} 77 | X = X.astype(np.float32) 78 | features = tf.constant(X) 79 | input_dim_gnn = X.shape[1] 80 | 81 | logger.info(f"*** Model input dim {X.shape[1]}, GNN input dim {input_dim_gnn}") 82 | 83 | gnn_model = gmodel_type( 84 | features_shape=features.shape, 85 | input_dim=input_dim_gnn, 86 | labels=None, 87 | adj=adj, 88 | embsize=output_dim_gnn, 89 | hidden_units=hidden_gnn, 90 | layers=nlayers_gnn, 91 | conv_last=False, 92 | ) 93 | logger.info(f"*** output clustering dim {output_dim_gnn}") 94 | 95 | th = TH( 96 | features, 97 | gnn_model=gnn_model, 98 | lr=lr_gnn, 99 | all_different_idx=neg_pair_idx, 100 | all_same_idx=pos_pair_idx, 101 | ae_encoder=None, 102 | ae_decoder=None, 103 | latentdim=output_dim_gnn, 104 | graph_weight=float(args.graph_alpha), 105 | ae_weight=float(args.ae_alpha), 106 | scg_weight=float(args.scg_alpha), 107 | num_negatives=args.negatives, 108 | decoder_input=args.decoder_input, 109 | ) 110 | 111 | #gnn_model.summary() 112 | if args.eval_split == 0: 113 | train_idx = np.arange(len(features)) 114 | eval_idx = [] 115 | else: 116 | train_idx = np.array(random.sample(list(range(len(features))), int(len(features)*(1-args.eval_split)))) 117 | eval_idx = np.array([x for x in np.arange(len(features)) if x not in train_idx]) 118 | logging.info(f"**** using {len(train_idx)} for training and {len(eval_idx)} for eval") 119 | features = np.array(features) 120 | pbar_epoch = tqdm(range(epochs), disable=args.quiet, position=0) 121 | scores = [] 122 | best_embs = None 123 | best_model = th.gnn_model.get_weights() 124 | best_score = 0 125 | best_epoch = 0 126 | batch_size = args.batchsize 127 | scores_string = "" 128 | if batch_size == 0: 129 | batch_size = len(train_idx) 130 | all_cluster_labels = [] 131 | step = 0 132 | for e in pbar_epoch: 133 | np.random.shuffle(train_idx) 134 | step += 1 135 | loss, gnn_losses, ae_losses = th.train_unsupervised(train_idx, scg=True) 136 | epoch_losses = {"Total": loss.numpy(), 137 | "gnn": gnn_losses["gnn_loss"].numpy(), 138 | "SCG": gnn_losses["scg_loss"].numpy(), 139 | #"pred": gnn_losses["pred_loss"], 140 | #'GNN LR': float(trainer.opt._decayed_lr(float)), 141 | "pos": gnn_losses["pos"].numpy(), 142 | "neg": gnn_losses["neg"].numpy()} 143 | mlflow.log_metrics(epoch_losses, step=step) 144 | gnn_loss = epoch_losses["gnn"] 145 | diff_loss = epoch_losses["SCG"] 146 | 147 | if args.eval_split > 0: 148 | eval_total_loss, eval_gnn_loss, eval_diff_loss, eval_pos_loss, \ 149 | eval_neg_loss = th.train_unsupervised(eval_idx, training=False) 150 | eval_epoch_losses = {"Eval gnn loss": float(eval_gnn_loss), "Eval SCG loss": float(eval_diff_loss), 151 | "Eval GNN loss": float(eval_total_loss), 152 | "Eval pos loss": float(eval_pos_loss), 153 | "Eval neg_loss": float(eval_neg_loss)} 154 | mlflow.log_metrics(eval_epoch_losses, step=step) 155 | #else: 156 | # eval_loss, 
eval_mse1, eval_mse2, eval_kld = 0, 0, 0, 0 157 | 158 | gpu_mem_alloc = tf.config.experimental.get_memory_info('GPU:0')["peak"] / 1000000 if args.cuda else 0 159 | #gpu_mem_alloc = tf.config.experimental.get_memory_usage('GPU:0') / 1000000 if args.cuda else 0 160 | if (e + 1) % RESULT_EVERY == 0 and e > args.evalskip and target_metric != "noeval": 161 | gnn_input_features = features 162 | node_new_features = th.gnn_model(gnn_input_features, None, training=False) 163 | node_new_features = node_new_features.numpy() 164 | if concat_features: 165 | node_new_features = tf.concat([gnn_input_features, node_new_features], axis=1).numpy() 166 | eval_output = eval_epoch(node_new_features, cluster_mask, th.gnn_model.get_weights(), 167 | args, dataset, e, scores, best_score, best_embs, best_epoch, 168 | best_model, target_metric=target_metric) 169 | 170 | best_score, best_embs, best_epoch, scores, best_model, cluster_labels = eval_output 171 | stats = scores[-1] 172 | if args.quiet: 173 | logger.info(f"--- EPOCH {e:d} ---") 174 | scores_string = f"HQ={stats['hq']} Best{target_metric}={round(best_score, 3)} Best Epoch={best_epoch}" 175 | logger.info(f"[{gname} {nlayers_gnn}l] L={gnn_loss:.3f} D={diff_loss:.3f} {scores_string} GPU={gpu_mem_alloc:.1f}MB") 176 | logger.info(str(stats)) 177 | mlflow.log_metrics(stats, step=step) 178 | all_cluster_labels.append(cluster_labels) 179 | scores_string = f"HQ={stats['hq']} Best{target_metric}={round(best_score, 3)} Best Epoch={best_epoch}" 180 | pbar_epoch.set_description( 181 | f"[{args.outname} {nlayers_gnn}l] L={gnn_loss:.3f} D={diff_loss:.3f} {scores_string} GPU={gpu_mem_alloc:.1f}MB" 182 | ) 183 | total_loss = gnn_loss + diff_loss 184 | losses["gnn"].append(gnn_loss) 185 | losses["scg"].append(diff_loss) 186 | losses["total"].append(total_loss) 187 | 188 | gnn_model.set_weights(best_model) 189 | node_new_features = th.gnn_model(features, None, training=False) 190 | node_new_features = node_new_features.numpy() 191 | if best_embs is None or target_metric != "noeval": 192 | best_embs = node_new_features 193 | if concat_features: 194 | node_new_features = tf.concat([features, node_new_features], axis=1).numpy() 195 | cluster_labels, stats, _, _ = compute_clusters_and_stats( 196 | node_new_features, node_names, dataset, 197 | clustering=clustering, k=k, amber=(args.labels is not None and "amber" in args.labels), 198 | cuda=args.cuda, 199 | ) 200 | all_cluster_labels.append(cluster_labels) 201 | stats["epoch"] = e 202 | scores.append(stats) 203 | if target_metric != "noeval": 204 | # get best stats: 205 | target_scores = [s[target_metric] for s in scores] 206 | best_idx = np.argmax(target_scores) 207 | else: 208 | best_embs = node_new_features 209 | best_idx = -1 210 | mlflow.log_metrics(scores[best_idx], step=step+1) 211 | logger.info(f">>> best epoch all contigs: {RESULT_EVERY + (best_idx*RESULT_EVERY)} : {stats} <<<") 212 | logger.info(f">>> best epoch: {RESULT_EVERY + (best_idx*RESULT_EVERY)} : {scores[best_idx]} <<<") 213 | with open(f"{args.outdir}/{args.outname}_{nrun}_best_contig2bin.tsv", "w") as f: 214 | f.write("@Version:0.9.0\n@SampleID:SAMPLEID\n@@SEQUENCEID\tBINID\n") 215 | for i in range(len(all_cluster_labels[best_idx])): 216 | f.write(f"{node_names[i]}\t{all_cluster_labels[best_idx][i]}\n") 217 | #plot edges vs final embs 218 | #plot_edges_sim(best_embs, dataset.adj_matrix, id_to_scg, f"{args.outdir}/{args.outname}_posttrain_") 219 | return best_embs, scores[best_idx], all_cluster_labels[best_idx] 220 | 221 | 
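# --- Usage sketch (annotation, not part of the module) ---
# Reading back the AMBER-style contig-to-bin TSV written above. The path
# pattern follows the code; `args.outdir`, `args.outname` and `nrun` are
# placeholders for the values used at write time.
#
#     def read_contig2bin(path):
#         contig_to_bin = {}
#         with open(path) as f:
#             for line in f:
#                 if line.startswith("@"):  # skip @Version/@SampleID/@@SEQUENCEID headers
#                     continue
#                 contig, bin_id = line.rstrip("\n").split("\t")
#                 contig_to_bin[contig] = bin_id
#         return contig_to_bin
#
#     bins = read_contig2bin(f"{args.outdir}/{args.outname}_{nrun}_best_contig2bin.tsv")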
-------------------------------------------------------------------------------- /docs/source/_build/html/_static/language_data.js: -------------------------------------------------------------------------------- 1 | /* 2 | * language_data.js 3 | * ~~~~~~~~~~~~~~~~ 4 | * 5 | * This script contains the language-specific data used by searchtools.js, 6 | * namely the list of stopwords, stemmer, scorer and splitter. 7 | * 8 | * :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS. 9 | * :license: BSD, see LICENSE for details. 10 | * 11 | */ 12 | 13 | var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"]; 14 | 15 | 16 | /* Non-minified version is copied as a separate JS file, is available */ 17 | 18 | /** 19 | * Porter Stemmer 20 | */ 21 | var Stemmer = function() { 22 | 23 | var step2list = { 24 | ational: 'ate', 25 | tional: 'tion', 26 | enci: 'ence', 27 | anci: 'ance', 28 | izer: 'ize', 29 | bli: 'ble', 30 | alli: 'al', 31 | entli: 'ent', 32 | eli: 'e', 33 | ousli: 'ous', 34 | ization: 'ize', 35 | ation: 'ate', 36 | ator: 'ate', 37 | alism: 'al', 38 | iveness: 'ive', 39 | fulness: 'ful', 40 | ousness: 'ous', 41 | aliti: 'al', 42 | iviti: 'ive', 43 | biliti: 'ble', 44 | logi: 'log' 45 | }; 46 | 47 | var step3list = { 48 | icate: 'ic', 49 | ative: '', 50 | alize: 'al', 51 | iciti: 'ic', 52 | ical: 'ic', 53 | ful: '', 54 | ness: '' 55 | }; 56 | 57 | var c = "[^aeiou]"; // consonant 58 | var v = "[aeiouy]"; // vowel 59 | var C = c + "[^aeiouy]*"; // consonant sequence 60 | var V = v + "[aeiou]*"; // vowel sequence 61 | 62 | var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0 63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" 
+ v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp = re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | 201 | 202 | 203 | var splitChars = (function() { 204 | var result = {}; 205 | var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, 206 | 1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702, 207 | 2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971, 208 | 2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 
3295, 3341, 3345, 209 | 3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761, 210 | 3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823, 211 | 4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125, 212 | 8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695, 213 | 11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587, 214 | 43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141]; 215 | var i, j, start, end; 216 | for (i = 0; i < singles.length; i++) { 217 | result[singles[i]] = true; 218 | } 219 | var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709], 220 | [722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161], 221 | [1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568], 222 | [1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807], 223 | [1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047], 224 | [2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383], 225 | [2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450], 226 | [2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547], 227 | [2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673], 228 | [2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820], 229 | [2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946], 230 | [2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023], 231 | [3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173], 232 | [3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332], 233 | [3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481], 234 | [3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718], 235 | [3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791], 236 | [3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095], 237 | [4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205], 238 | [4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687], 239 | [4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968], 240 | [4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869], 241 | [5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102], 242 | [6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271], 243 | [6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592], 244 | [6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822], 245 | [6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167], 246 | [7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959], 247 | [7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143], 248 | [8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318], 249 | [8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483], 250 | [8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101], 251 | [10132, 11263], [11493, 11498], [11503, 11516], [11518, 
11519], [11558, 11567], 252 | [11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292], 253 | [12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444], 254 | [12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783], 255 | [12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311], 256 | [19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511], 257 | [42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774], 258 | [42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071], 259 | [43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263], 260 | [43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519], 261 | [43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647], 262 | [43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967], 263 | [44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295], 264 | [57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274], 265 | [64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007], 266 | [65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381], 267 | [65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]]; 268 | for (i = 0; i < ranges.length; i++) { 269 | start = ranges[i][0]; 270 | end = ranges[i][1]; 271 | for (j = start; j <= end; j++) { 272 | result[j] = true; 273 | } 274 | } 275 | return result; 276 | })(); 277 | 278 | function splitQuery(query) { 279 | var result = []; 280 | var start = -1; 281 | for (var i = 0; i < query.length; i++) { 282 | if (splitChars[query.charCodeAt(i)]) { 283 | if (start !== -1) { 284 | result.push(query.slice(start, i)); 285 | start = -1; 286 | } 287 | } else if (start === -1) { 288 | start = i; 289 | } 290 | } 291 | if (start !== -1) { 292 | result.push(query.slice(start)); 293 | } 294 | return result; 295 | } 296 | 297 | 298 | -------------------------------------------------------------------------------- /src/graphmb/visualize.py: -------------------------------------------------------------------------------- 1 | 2 | import networkx as nx 3 | import numpy as np 4 | #import tensorflow 5 | 6 | colors = [ 7 | "black", 8 | "red", 9 | "blue", 10 | "green", 11 | "orange", 12 | "purple", 13 | "brown", 14 | "pink", 15 | "yellow", 16 | "silver", 17 | "maroon", 18 | "fuchsia", 19 | "lime", 20 | "olive", 21 | "yellow", 22 | "navy", 23 | "teal", 24 | "steelblue", 25 | "darkred", 26 | "darkgreen", 27 | "darkblue", 28 | "darkorange", 29 | "lightpink", 30 | "lightgreen", 31 | "lightblue", 32 | "crimson", 33 | "darkviolet", 34 | "tomato", 35 | "tan", 36 | "tab:blue", 37 | "tab:orange", 38 | "tab:green", 39 | "tab:red", 40 | "tab:purple", 41 | "tab:brown", 42 | "tab:pink", 43 | "tab:gray", 44 | "tab:olive", 45 | "tab:cyan", 46 | ] 47 | 48 | def plot_embs(node_ids, node_embeddings_2dim, labels_to_node, centroids, hq_centroids, node_sizes, outputname=None): 49 | """Plot embs of most labels with most support 50 | 51 | Args: 52 | node_ids ([type]): [description] 53 | node_embeddings_2dim ([type]): [description] 54 | labels_to_node ([type]): [description] 55 | """ 56 | import matplotlib.pyplot as plt 57 | 58 | markers = ["o", "s", "p", "*"] 59 | if "NA" in labels_to_node: 60 | del labels_to_node["NA"] 61 | #breakpoint() 62 | labels_to_node = {label: labels_to_node[label] for label in labels_to_node if 
len(labels_to_node[label]) > 0} 63 | labels_to_plot = sorted(labels_to_node, key=lambda key: len(labels_to_node[key]), reverse=True)[ 64 | : len(colors) * len(markers) 65 | ]#[:20] 66 | # print("ploting these labels", [l, colors[il], len(labels_to_node[l]) for il, l in enumerate(labels_to_plot)]) 67 | x_to_plot = [] 68 | y_to_plot = [] 69 | colors_to_plot = [] 70 | sizes_to_plot = [] 71 | markers_to_plot = [] 72 | #print(labels_to_plot) 73 | plt.figure() 74 | #print(" LABEL COLOR SIZE DOTS") 75 | for i, l in enumerate(labels_to_plot): 76 | valid_nodes = 0 77 | if len(labels_to_node) == 0: 78 | continue 79 | for node in labels_to_node[l]: 80 | if node not in node_ids: 81 | # print("skipping", node) 82 | continue 83 | node_idx = node_ids.index(node) 84 | x_to_plot.append(node_embeddings_2dim[node_idx][0]) 85 | y_to_plot.append(node_embeddings_2dim[node_idx][1]) 86 | if node_sizes is not None: 87 | sizes_to_plot.append(node_sizes[node_idx]) 88 | else: 89 | sizes_to_plot.append(50) 90 | valid_nodes += 1 91 | colors_to_plot.append(colors[i % len(colors)]) 92 | markers_to_plot.append(markers[i // len(colors)]) 93 | # breakpoint() 94 | # print("plotting", l, colors[i % len(colors)], markers[i // len(colors)], len(labels_to_node[l]), valid_nodes) 95 | # plt.scatter(x_to_plot, y_to_plot, s=sizes_to_plot, c=colors[i], label=l) # , alpha=0.5) 96 | sc = plt.scatter( 97 | x_to_plot, 98 | y_to_plot, 99 | s=sizes_to_plot, 100 | c=colors[i % len(colors)], 101 | label=l, 102 | marker=markers[i // len(colors)], 103 | alpha=0.4, 104 | ) # , alpha=0.5) 105 | x_to_plot = [] 106 | y_to_plot = [] 107 | sizes_to_plot = [] 108 | plt.legend('') 109 | if centroids is not None: 110 | hq_centroids_mask = [x in hq_centroids for x in range(len(centroids))] 111 | lq_centroids_mask = [x not in hq_centroids for x in range(len(centroids))] 112 | # lq_centroids = set(range(len(centroids))) - set(hq_centroids) 113 | 114 | # plt.scatter( 115 | # centroids[lq_centroids_mask, 0], centroids[lq_centroids_mask, 1], c="black", label="centroids (LQ)", marker="x" 116 | # ) 117 | plt.scatter( 118 | centroids[hq_centroids_mask, 0], 119 | centroids[hq_centroids_mask, 1], 120 | c="black", 121 | label="centroids (HQ)", 122 | marker="P", 123 | ) 124 | 125 | # for n in node_embeddings: 126 | # plt.scatter(x_to_plot, y_to_plot, c=colors_to_plot) #, alpha=0.5) 127 | #plt.legend() 128 | if outputname is not None: 129 | print("saving embs plot to {}".format(outputname)) 130 | plt.savefig(outputname, bbox_inches="tight", dpi=400) 131 | else: 132 | plt.show() 133 | 134 | def plot_edges_sim(X, adj, scgs, outname="", max_edge_value=150, min_edge_value=2): 135 | """ 136 | X: feature matrix 137 | adj: adjacency matrix in sparse format 138 | """ 139 | # for each pair in adj, calculate sim 140 | x_values = [] 141 | y_values = [] 142 | x_same_scgs = [] 143 | y_same_scgs = [] 144 | plotted_edges = set() 145 | for x, (i, j) in enumerate(zip(adj.row, adj.col)): 146 | if i != j and (i,j) not in plotted_edges and (j,i) not in plotted_edges and adj.data[x] > min_edge_value: 147 | #y_values.append(np.dot(X[i], X[j])) 148 | #y_values.append(scipy.spatial.distance.cosine(X[i], X[j])) 149 | plotted_edges.add((i,j)) 150 | if len(scgs[i] & scgs[j]) > 0: 151 | y_same_scgs.append(np.dot(X[i], X[j])/(np.linalg.norm(X[i])*np.linalg.norm(X[j]))) 152 | x_same_scgs.append(adj.data[x]) 153 | #TODO plot edge weight by overlap 154 | else: 155 | y_values.append(np.dot(X[i], X[j])/(np.linalg.norm(X[i])*np.linalg.norm(X[j]))) 156 | x_values.append(adj.data[x]) 157 | if 
max_edge_value is not None: 158 | x_values = [min(x, max_edge_value) for x in x_values] 159 | x_same_scgs = [min(x, max_edge_value) for x in x_same_scgs] 160 | #x_values = adj.values 161 | #y_values = [] 162 | #for (i, j) in adj.indices: 163 | # y_values.append(np.dot(X[i], X[j] 164 | #)) 165 | assert len(x_values) == len(y_values) 166 | import matplotlib.pyplot as plt 167 | plt.set_loglevel("error") 168 | 169 | plt.figure(0) 170 | plt.scatter( 171 | x_values, 172 | y_values, label=outname, marker=".", alpha=0.5, s=1) 173 | plt.scatter( 174 | x_same_scgs, 175 | y_same_scgs, label=outname+"SCG", marker="o", alpha=0.5, s=3) 176 | 177 | plt.xlabel("edge weight capped at {}".format(max_edge_value)) 178 | 179 | plt.ylabel("cosine sim") 180 | plt.legend(loc='upper right') 181 | plt.savefig(outname + "edges_embs.png", dpi=500) 182 | #plt.show() 183 | plt.close() 184 | 185 | # dist histogram 186 | plt.figure(1) 187 | counts, edges, bars = plt.hist(y_values, bins=50) 188 | plt.bar_label(bars) 189 | plt.savefig(outname + "embs_dists_histogram.png", dpi=500) 190 | plt.close() 191 | 192 | 193 | 194 | 195 | def run_tsne(embs, dataset, cluster_to_contig, hq_bins, centroids=None): 196 | from sklearn.manifold import TSNE 197 | 198 | SEED = 0 199 | print("running tSNE") 200 | # filter only good clusters 201 | tsne = TSNE(n_components=2, random_state=SEED) 202 | if len(dataset.labels) == 1: 203 | label_to_node = {c: cluster_to_contig[c] for c in hq_bins} 204 | label_to_node["mq/lq"] = [] 205 | for c in cluster_to_contig: 206 | if c not in hq_bins: 207 | label_to_node["mq/lq"] += list(cluster_to_contig[c]) 208 | if centroids is not None: 209 | all_embs = tsne.fit_transform(np.concatenate((np.array(embs), np.array(centroids)), axis=0)) 210 | centroids_2dim = all_embs[embs.shape[0] :] 211 | node_embeddings_2dim = all_embs[: embs.shape[0]] 212 | else: 213 | centroids_2dim = None 214 | node_embeddings_2dim = tsne.fit_transform(embs) 215 | return node_embeddings_2dim, centroids_2dim 216 | 217 | def draw_nx_graph(adj, node_to_label, labels_to_node, basename, contig_sizes=None, node_titles=None, cluster_info={}): 218 | # draw graph with pybiz library, creates an HTML file with graph 219 | # del labels_to_node["NA"] 220 | from pyvis.network import Network 221 | 222 | labels_to_color = {l: colors[i % len(colors)] for i, l in enumerate(labels_to_node.keys())} 223 | labels_to_color["NA"] = "white" 224 | sorted(labels_to_node, key=lambda key: len(labels_to_node[key]), reverse=True)[: len(colors)] 225 | # node_labels to plot 226 | node_labels = { 227 | node: { 228 | "label": str(node) + ":" + str(node_to_label[node]), 229 | "color": labels_to_color[node_to_label[node]], 230 | } 231 | for node in node_to_label 232 | } 233 | if contig_sizes is not None: 234 | for n in node_labels: 235 | node_labels[n]["size"] = int(contig_sizes[n]) 236 | 237 | if node_titles is not None: 238 | for n in node_labels: 239 | node_labels[n]["title"] = node_titles[n] 240 | #breakpoint() 241 | graph = nx.from_scipy_sparse_matrix(adj, parallel_edges=False, create_using=None, edge_attribute='weight') 242 | graph.remove_edges_from(nx.selfloop_edges(graph)) 243 | nx.set_node_attributes(graph, node_labels) 244 | nodes_to_plot = [n for n in graph.nodes() if node_to_label[n] != "NA" and len(graph.edges(n)) > 0] 245 | 246 | net = Network(notebook=False, height="750px", width="100%") 247 | net.add_nodes( 248 | [int(n) for n in node_labels.keys() if n in nodes_to_plot], 249 | label=[node_labels[n]["label"] for n in node_labels if n in nodes_to_plot], 250 | 
size=[node_labels[n].get("size", 100000)/100_000 for n in node_labels if n in nodes_to_plot],
251 |         color=[node_labels[n]["color"] for n in node_labels if n in nodes_to_plot],
252 |         title=[node_labels[n].get("title", f"{cluster_info.get(node_to_label[n])}") for n in node_labels if n in nodes_to_plot],
253 |     )
254 |     for u, v, a in graph.edges(data=True):
255 |         if u != v:
256 |             #if u not in net.get_nodes() or v not in net.get_nodes():
257 |             #    breakpoint()
258 |             weight = float(a["weight"].item())
259 |             if weight != 1:
260 |                 net.add_edge(int(u), int(v), color="gray", title="reads weight: {}".format(weight))
261 |             else:
262 |                 net.add_edge(int(u), int(v))
263 | 
264 |     #net.toggle_physics(False)
265 |     net.show_buttons()
266 |     print("saving graph to", basename)
267 |     net.show("{}.html".format(basename))
268 | 
269 | 
270 | def connected_components(graph, node_to_label, basename, min_elems=1):
271 |     # explore connected components
272 |     connected = [c for c in sorted(nx.connected_components(graph), key=len, reverse=True) if len(c) > min_elems]
273 |     print("writing components to", basename + "_node_to_component.csv")
274 |     write_components_file(connected, basename + "_node_to_component.csv")
275 |     multiple_contigs = 0
276 |     mixed_contigs = 0
277 |     for group in connected:
278 |         multiple_contigs += 1
279 |         group_labels = [node_to_label[c] for c in group if "edge" in c]
280 |         group_labels = set(group_labels)
281 |         if len(group_labels) > 1:
282 |             mixed_contigs += 1
283 |             # print(group, group_labels)
284 | 
285 |     disconnected = [c for c in sorted(nx.connected_components(graph), key=len, reverse=True) if len(c) <= min_elems]
286 |     for group in disconnected:
287 |         for node in group:
288 |             graph.remove_node(node)
289 | 
290 |     print("graph density:", nx.density(graph))
291 |     print(">1", multiple_contigs)
292 |     print("mixed groups", mixed_contigs)
293 |     return connected, disconnected
294 | 
295 | 
296 | def write_components_file(components, outpath, minsize=2):
297 |     """Write file mapping each contig/node to a component ID (diff for each assembly)
298 | 
299 |     Args:
300 |         components (list): List of connected components of a graph
301 |         outpath (str): path to write file
302 |         minsize: minimum number of elements of a component
303 |     """
304 |     with open(outpath, "w") as outfile:
305 |         for ic, c in enumerate(components):
306 |             if len(c) < minsize:
307 |                 continue
308 |             for node in c:
309 |                 if "edge" in node:  # ignore read nodes
310 |                     outfile.write(f"{node}\t{ic}\n")
311 | 
--------------------------------------------------------------------------------
/src/graphmb/unused/train_vae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import datetime
3 | import os
4 | import tensorflow as tf
5 | import random
6 | import logging
7 | from tqdm import tqdm
8 | import mlflow
9 | import mlflow.tensorflow
10 | 
11 | from graphmb.models import TH, TrainHelperVAE, VAEDecoder, VAEEncoder
12 | from graph_functions import set_seed, run_tsne, plot_embs, plot_edges_sim
13 | from graphmb.evaluate import calculate_overall_prf
14 | from vaegbin import name_to_model, TensorboardLogger, compute_clusters_and_stats, log_to_tensorboard, eval_epoch
15 | 
16 | def prepare_data_for_vae(dataset):
17 |     # less preparation necessary than for GNN
18 |     node_raw = np.hstack((dataset.node_depths, dataset.node_kmers))
19 |     ab_dim = dataset.node_depths.shape[1]
20 |     kmer_dim = dataset.node_kmers.shape[1]
21 |     X = node_raw
22 |     return X, ab_dim, kmer_dim
23 | 
24 | def

def run_model_vae(dataset, args, logger, nrun):
    set_seed(args.seed)
    mlflow.tensorflow.autolog()
    node_names = np.array(dataset.node_names)
    RESULT_EVERY = args.evalepochs
    hidden_vae = args.hidden_vae
    output_dim_vae = args.embsize_vae
    epochs = args.epoch
    lr_vae = args.lr_vae
    clustering = args.clusteringalgo
    k = args.kclusters
    with mlflow.start_run(run_name=args.assembly.split("/")[-1] + "-" + args.outname):
        mlflow.log_params(vars(args))
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = os.path.join(args.outdir, 'logs/' + args.outname + current_time + '/train')
        summary_writer = tf.summary.create_file_writer(train_log_dir)

        tb_handler = TensorboardLogger(summary_writer, runname=args.outname + current_time)
        logger.addHandler(tb_handler)
        #tf.summary.trace_on(graph=True)
        if nrun == 0:
            print("logging to tensorboard")
            logger.info("******* Running model: VAE **********")
            logger.info("***** Using raw kmer+abund features: {}".format(args.rawfeatures))
        tf.config.run_functions_eagerly(True)
        X, ab_dim, kmer_dim = prepare_data_for_vae(dataset)
        cluster_mask = [True] * len(dataset.node_names)

        if not args.skip_preclustering and nrun == 0:
            cluster_labels, stats, _, hq_bins = compute_clusters_and_stats(
                X[cluster_mask], node_names[cluster_mask], dataset, clustering=clustering, k=k,
                amber=(args.labels is not None and "amber" in args.labels),
                cuda=args.cuda,
            )
        else:
            stats = {"hq": 0, "epoch": 0}
        scores = [stats]
        losses = {"total": [], "ae": [], "gnn": [], "scg": []}
        all_cluster_labels = []
        X = X.astype(np.float32)
        features = tf.constant(X)
        if nrun == 0:
            logger.info(f"*** Model input dim {X.shape[1]}")
            logger.info(f"*** output clustering dim {output_dim_vae}")

        gold_labels = np.array([dataset.labels.index(dataset.node_to_label[n]) for n in dataset.node_names])
        encoder = VAEEncoder(ab_dim, kmer_dim, hidden_vae, zdim=output_dim_vae, dropout=args.dropout_vae)
        decoder = VAEDecoder(ab_dim, kmer_dim, hidden_vae, zdim=output_dim_vae, dropout=args.dropout_vae)
        th_vae = TrainHelperVAE(encoder, decoder, learning_rate=lr_vae,
                                kld_weight=1 / args.kld_alpha,
                                classification=args.classify, n_classes=len(dataset.labels))
        if args.eval_split == 0:
            train_idx = np.arange(len(features))
            eval_idx = []
        else:
            train_idx = np.array(random.sample(list(range(len(features))), int(len(features) * (1 - args.eval_split))))
            eval_idx = np.array([x for x in np.arange(len(features)) if x not in train_idx])
        logger.info(f"**** using {len(train_idx)} for training and {len(eval_idx)} for eval")
        features = np.array(features)
        pbar_epoch = tqdm(range(epochs), disable=args.quiet, position=0)
        scores = []
        best_embs = None
        best_model = None
        best_hq = 0
        best_epoch = 0
        node_new_features = None  # set on the first periodic evaluation
        batch_size = args.batchsize
        if batch_size == 0:
            batch_size = len(train_idx)

        batch_steps = [25, 75, 150, 300]
        batch_steps = [x for i, x in enumerate(batch_steps) if (2 ** (i + 1)) * batch_size < len(train_idx)]
        if nrun == 0:
            logger.info("**** initial batch size: {} ****".format(batch_size))
            logger.info("**** epoch batch size doubles: {} ****".format(str(batch_steps)))
        vae_losses = []
        step = 0
        #mlflow.create_experiment(name=args.outname + current_time,
        #                         tags={'run': nrun, 'dataset': args.dataset, 'model': 'vae'})
        #with mlflow.start_run():
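        # Illustrative note (not in the original file): the two batch_steps lines
        # above implement a batch-size doubling schedule. The step at
        # batch_steps[i] doubles the batch size for the (i+1)-th time, and steps
        # whose doubled size would exceed the training set are dropped, e.g.:
        #
        #   batch_size, n_train = 256, 3000
        #   batch_steps = [25, 75, 150, 300]
        #   batch_steps = [x for i, x in enumerate(batch_steps) if (2 ** (i + 1)) * batch_size < n_train]
        #   # -> [25, 75, 150]: the fourth doubling would need 16 * 256 = 4096 > 3000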
        for e in pbar_epoch:
            vae_epoch_losses = {"kld_loss": [], "Total loss": [], "kmer_loss": [],
                                "ab_loss": [], "pred_loss": []}
            np.random.shuffle(train_idx)
            recon_loss = 0

            # train VAE in batches
            if e in batch_steps:
                #print(f'Increasing batch size from {batch_size:d} to {batch_size*2:d}')
                batch_size = batch_size * 2
                np.random.shuffle(train_idx)
            n_batches = len(train_idx) // batch_size
            pbar_vaebatch = tqdm(range(n_batches), disable=(args.quiet or batch_size == len(train_idx) or n_batches < 100), position=1, ascii=' =')
            for b in pbar_vaebatch:
                batch_idx = train_idx[b * batch_size:(b + 1) * batch_size]
                vae_losses = th_vae.train_step(X[batch_idx], summary_writer, step,
                                               vae=True, gold_labels=gold_labels[batch_idx])
                vae_epoch_losses["Total loss"].append(vae_losses[0])
                vae_epoch_losses["kmer_loss"].append(vae_losses[1])
                vae_epoch_losses["ab_loss"].append(vae_losses[2])
                vae_epoch_losses["kld_loss"].append(vae_losses[3])
                vae_epoch_losses["pred_loss"].append(vae_losses[4])
                vae_epoch_losses["kld_weight"] = th_vae.kld_weight
                vae_epoch_losses["kmer_weight"] = th_vae.kmer_weight
                #vae_epoch_losses["ab_weight"] = th_vae.abundance_weight
                vae_epoch_losses_avg = {k: np.mean(v) for k, v in vae_epoch_losses.items()}
                losses_string = " ".join([f"{k}={v:.3f}" for k, v in vae_epoch_losses_avg.items()])
                pbar_vaebatch.set_description(f'E={e} {losses_string}')
                step += 1
            # average again after the loop so the epoch summary is defined even
            # when the batch progress bar never ran
            vae_epoch_losses_avg = {k: np.mean(v) for k, v in vae_epoch_losses.items()}
            vae_epoch_losses = vae_epoch_losses_avg
            log_to_tensorboard(summary_writer, vae_epoch_losses, step)
            mlflow.log_metrics(vae_epoch_losses, step=step)

            if args.eval_split > 0:
                eval_mu, eval_logsigma = th_vae.encoder(X[eval_idx], training=False)
                eval_mse1, eval_mse2, eval_kld, eval_pred = th_vae.loss(X[eval_idx], eval_mu, eval_logsigma,
                                                                        vae=True, training=False, gold_labels=gold_labels[eval_idx])
                eval_loss = eval_mse1 + eval_mse2 + eval_kld + eval_pred
                eval_losses = {"eval loss": eval_loss, "eval kmer loss": eval_mse2, "eval ab loss": eval_mse1,
                               "eval kld loss": eval_kld, "eval pred loss": eval_pred}
                log_to_tensorboard(summary_writer, eval_losses, step)
            else:
                eval_loss, eval_mse1, eval_mse2, eval_kld = 0, 0, 0, 0
            recon_loss = vae_epoch_losses["Total loss"]

            with summary_writer.as_default():
                tf.summary.scalar('epoch', e, step=step)

            #gpu_mem_alloc = tf.config.experimental.get_memory_info('GPU:0')["peak"] / 1000000 if args.cuda else 0
            gpu_mem_alloc = tf.config.experimental.get_memory_usage('GPU:0') / 1000000 if args.cuda else 0
            if (e + 1) % RESULT_EVERY == 0 and e > args.evalskip:
                latent_features = encoder(features)[0]
                node_new_features = latent_features.numpy()
                if args.classify:
                    labels = th_vae.classifier(latent_features, mask=np.arange(latent_features.shape[0]), training=False)

                with summary_writer.as_default():
                    tf.summary.scalar('Embs average', np.mean(node_new_features), step=step)
                    tf.summary.scalar('Embs std', np.std(node_new_features), step=step)
                mlflow.log_metrics({'Embs average': np.mean(node_new_features),
                                    'Embs std': np.std(node_new_features)}, step=step)
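                # Illustrative note (not in the original file): with
                # RESULT_EVERY = args.evalepochs and args.evalskip, clustering
                # quality is only computed on a sparse schedule. The evaluated
                # epochs can be previewed with, e.g.:
                #
                #   epochs, evalepochs, evalskip = 500, 50, 100
                #   eval_at = [e for e in range(epochs) if (e + 1) % evalepochs == 0 and e > evalskip]
                #   # -> [149, 199, 249, 299, 349, 399, 449, 499]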
                cluster_labels, stats, _, hq_bins = compute_clusters_and_stats(
                    node_new_features[cluster_mask], node_names[cluster_mask],
                    dataset, clustering=clustering, k=k, tsne=args.tsne,
                    use_labels=args.classify, amber=(args.labels is not None and "amber" in args.labels),
                    cuda=args.cuda,
                )
                stats["epoch"] = e
                scores.append(stats)
                #logger.info(str(stats))
                with summary_writer.as_default():
                    tf.summary.scalar('hq_bins', stats["hq"], step=step)
                    tf.summary.scalar('mq_bins', stats["mq"], step=step)
                mlflow.log_metrics(stats, step=step)
                all_cluster_labels.append(cluster_labels)
                if dataset.contig_markers is not None and stats["hq"] > best_hq:
                    best_hq = stats["hq"]
                    best_embs = node_new_features
                    best_epoch = e
                    #save_model(args, e, th, th_vae)
                elif dataset.contig_markers is None and stats["f1"] > best_hq:
                    best_hq = stats["f1"]
                    best_embs = node_new_features
                    best_epoch = e
                    #save_model(args, e, th, th_vae)
                # print('--- END ---')
                if args.quiet:
                    logger.info(f"--- EPOCH {e:d} ---")
                    logger.info(f"[VAE] L={recon_loss:.3f} HQ={stats['hq']} BestHQ={best_hq} BestEp={best_epoch} GPU={gpu_mem_alloc:.1f}MB")
                    logger.info(str(stats))
            scores_string = f"HQ={stats['hq']} BestHQ={best_hq} Best Epoch={best_epoch} F1={round(stats.get('f1_avg_bp', 0), 3)}"
            losses_string = " ".join([f"{k}={v:.3f}" for k, v in vae_epoch_losses_avg.items()])
            pbar_epoch.set_description(
                f"[VAE] {losses_string} {scores_string} Max GPU MB={gpu_mem_alloc:.1f}"
            )
            total_loss = recon_loss
            losses["ae"].append(recon_loss)
            losses["total"].append(total_loss)
        if best_embs is None:
            # assumes at least one periodic evaluation ran; otherwise
            # node_new_features is still None (see its initialization above)
            best_embs = node_new_features

        cluster_labels, stats, _, _ = compute_clusters_and_stats(
            best_embs[cluster_mask], node_names[cluster_mask],
            dataset, clustering=clustering, k=k, #cuda=args.cuda,
        )
        stats["epoch"] = e
        scores.append(stats)
        # get best stats:
        # if concat_features:  # use HQ
        hqs = [s["hq"] for s in scores]
        epoch_hqs = [s["epoch"] for s in scores]
        best_idx = np.argmax(hqs)
        mlflow.log_metrics(scores[best_idx], step=step + 1)
        # else:  # use F1
        #     f1s = [s["f1"] for s in scores]
        #     best_idx = np.argmax(f1s)
        logger.info(f">>> best epoch: {RESULT_EVERY + (best_idx * RESULT_EVERY)} : {scores[best_idx]} <<<")
        with open(f"{dataset.cache_dir}/{dataset.name}_best_contig2bin.tsv", "w") as f:
            f.write("@Version:0.9.0\n@SampleID:SAMPLEID\n@@SEQUENCEID\tBINID\n")
            for i in range(len(all_cluster_labels[best_idx])):
                f.write(f"{node_names[i]}\t{all_cluster_labels[best_idx][i]}\n")
        return best_embs, scores[best_idx]
--------------------------------------------------------------------------------
/src/graphmb/unused/train_gnn_decode.py:
--------------------------------------------------------------------------------
import numpy as np
import datetime
import os
import tensorflow as tf
import random
import logging
from tqdm import tqdm
import mlflow

from graphmb.models import TH, TrainHelperVAE, VAEDecoder, VAEEncoder, GVAE
from graph_functions import set_seed, run_tsne, plot_embs, plot_edges_sim
from graphmb.evaluate import calculate_overall_prf
from vaegbin import name_to_model, TensorboardLogger, prepare_data_for_gnn, compute_clusters_and_stats, log_to_tensorboard, eval_epoch
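# Illustrative sketch (not in the original file): the TH train helper used below
# is configured with three loss weights. Under the assumption that it combines
# its components additively, the per-batch objective would look like:
#
#   total = gnn_weight * L_gnn + ae_weight * (L_kmer + L_ab + L_kld) + scg_weight * L_scg
#
# where L_scg penalizes embedding similarity between contigs that share
# single-copy marker genes (the neg_pair_idx pairs). The exact combination lives
# in graphmb.models.TH and may differ from this sketch.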

def run_model_gnn_recon(dataset, args, logger, nrun):
    set_seed(args.seed)
    node_names = np.array(dataset.node_names)
    RESULT_EVERY = args.evalepochs
    hidden_gnn = args.hidden_gnn
    hidden_vae = args.hidden_vae
    output_dim_gnn = args.embsize_gnn
    output_dim_vae = args.embsize_vae
    epochs = args.epoch
    lr_vae = args.lr_vae
    lr_gnn = args.lr_gnn
    nlayers_gnn = args.layers_gnn
    gname = args.model_name
    if gname == "vae":
        args.ae_only = True
    else:
        gmodel_type = name_to_model[gname.split("_")[0].upper()]
    clustering = args.clusteringalgo
    k = args.kclusters
    use_edge_weights = True
    use_disconnected = not args.quick
    cluster_markers_only = args.quick
    decay = 0.5 ** (2.0 / epochs)
    concat_features = args.concat_features
    use_ae = gname.endswith("_ae") or args.ae_only or gname == "vae"

    with mlflow.start_run(run_name=args.assembly.split("/")[-1] + "-" + args.outname):
        mlflow.log_params(vars(args))
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = os.path.join(args.outdir, 'logs/' + args.outname + current_time + '/train')
        summary_writer = tf.summary.create_file_writer(train_log_dir)
        print("logging to tensorboard")
        tb_handler = TensorboardLogger(summary_writer, runname=args.outname + current_time)
        logger.addHandler(tb_handler)
        #tf.summary.trace_on(graph=True)

        logger.info("******* Running model: {} **********".format(gname))
        logger.info("***** using edge weights: {} ******".format(use_edge_weights))
        logger.info("***** using disconnected: {} ******".format(use_disconnected))
        logger.info("***** concat features: {} *****".format(concat_features))
        logger.info("***** cluster markers only: {} *****".format(cluster_markers_only))
        logger.info("***** threshold adj matrix: {} *****".format(args.binarize))
        logger.info("***** self edges only: {} *****".format(args.noedges))
        args.rawfeatures = True
        logger.info("***** Using raw kmer+abund features: {}".format(args.rawfeatures))
        # the experimental_run_functions_eagerly alias is deprecated;
        # use the stable API, as train_vae.py does
        tf.config.run_functions_eagerly(True)

        X, adj, cluster_mask, neg_pair_idx, pos_pair_idx = prepare_data_for_gnn(
            dataset, use_edge_weights, cluster_markers_only, use_raw=args.rawfeatures,
            binarize=args.binarize, remove_edges=args.noedges)
        logger.info("***** SCG neg pairs: {}".format(neg_pair_idx.shape))
        logger.info("***** input features dimension: {}".format(X[cluster_mask].shape))
        # pre train clustering
        if not args.skip_preclustering:
            cluster_labels, stats, _, hq_bins = compute_clusters_and_stats(
                X[cluster_mask], node_names[cluster_mask],
                dataset, clustering=clustering, k=k, tsne=args.tsne,
                amber=(args.labels is not None and "amber" in args.labels),
                #cuda=args.cuda,
            )
            logger.info(f">>> Pre train stats: {str(stats)}")
        else:
            stats = {"hq": 0, "epoch": 0}

        pname = ""

        # plot edges vs initial embs
        id_to_scg = {i: set(dataset.contig_markers[node_name].keys()) for i, node_name in enumerate(dataset.node_names)}
        plot_edges_sim(X, dataset.adj_matrix, id_to_scg, f"{args.outdir}/{args.outname}_pretrain_")

        scores = [stats]
        losses = {"total": [], "ae": [], "gnn": [], "scg": []}
        all_cluster_labels = []
        X = X.astype(np.float32)
        features = tf.constant(X)
        input_dim_gnn = X.shape[1]

        logger.info(f"*** Model input dim {X.shape[1]}, GNN input dim {input_dim_gnn}, use_ae: {use_ae}, run AE only: {args.ae_only}")
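        # Illustrative sketch (not in the original file): plot_edges_sim above
        # plots assembly-graph edge weights against the embedding cosine
        # similarity of each edge's endpoints. The core computation, with
        # stand-in data and a scipy COO matrix:
        #
        #   import numpy as np
        #   import scipy.sparse as sp
        #   X = np.random.rand(100, 32)
        #   adj = sp.random(100, 100, density=0.05, format="coo")
        #   Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
        #   cos_sim = (Xn[adj.row] * Xn[adj.col]).sum(axis=1)  # one value per edge
        #   # scatter adj.data (edge weights) against cos_sim, as in the pre/post plots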
        S = []
        logger.info(f"*** output clustering dim {output_dim_gnn}")

        model = GVAE(dataset.node_depths.shape[1], dataset.node_kmers.shape[1],
                     X.shape[0], hidden_vae, zdim=output_dim_gnn,
                     dropout=args.dropout_vae, layers=nlayers_gnn)
        model.adj = adj
        th = TH(
            features,
            gnn_model=model,
            lr=lr_gnn,
            all_different_idx=neg_pair_idx,
            all_same_idx=pos_pair_idx,
            ae_encoder=None,
            ae_decoder=None,
            latentdim=output_dim_gnn,
            gnn_weight=float(args.gnn_alpha),
            ae_weight=float(args.ae_alpha),
            scg_weight=float(args.scg_alpha),
            num_negatives=args.negatives,
            decoder_input=args.decoder_input,
            kmers_dim=dataset.node_kmers.shape[1],
            abundance_dim=dataset.node_depths.shape[1],  # was dataset.node_depth, an AttributeError
        )
        th.adj = adj
        #model.summary()

        if args.eval_split == 0:
            train_idx = np.arange(len(features))
            eval_idx = []
        else:
            train_idx = np.array(random.sample(list(range(len(features))), int(len(features) * (1 - args.eval_split))))
            eval_idx = np.array([x for x in np.arange(len(features)) if x not in train_idx])
        logger.info(f"**** using {len(train_idx)} for training and {len(eval_idx)} for eval")
        features = np.array(features)
        pbar_epoch = tqdm(range(epochs), disable=args.quiet, position=0)
        scores = [stats]
        best_embs = None
        best_model = None
        best_hq = 0
        best_epoch = 0
        node_new_features = None  # set on the first periodic evaluation
        batch_size = args.batchsize
        if batch_size == 0:
            batch_size = len(train_idx)
        logger.info("**** initial batch size: {} ****".format(batch_size))
        batch_steps = [25, 75, 150, 300]
        batch_steps = [x for i, x in enumerate(batch_steps) if (2 ** (i + 1)) * batch_size < len(train_idx)]
        logger.info("**** epoch batch size doubles: {} ****".format(str(batch_steps)))
        step = 0
        for e in pbar_epoch:
            vae_epoch_losses = {"kld": [], "total": [], "kmer": [], "abundance": [], "scg": [], "gnn": []}
            np.random.shuffle(train_idx)
            recon_loss = 0

            # train VAE in batches
            if e in batch_steps:
                #print(f'Increasing batch size from {batch_size:d} to {batch_size*2:d}')
                batch_size = batch_size * 2
                np.random.shuffle(train_idx)
            # ceil division: include the final partial batch without ever
            # producing an empty one (the old //batch_size + 1 could)
            n_batches = (len(train_idx) + batch_size - 1) // batch_size
            pbar_vaebatch = tqdm(range(n_batches), disable=(args.quiet or batch_size == len(train_idx) or n_batches < 100), position=1, ascii=' =')
            for b in pbar_vaebatch:
                batch_idx = train_idx[b * batch_size:(b + 1) * batch_size]
                #vae_losses = th_vae.train_step(X[batch_idx], summary_writer, step, vae=True)
                with summary_writer.as_default():
                    tf.summary.scalar('epoch', e, step=step)

                total_loss, gnn_loss, diff_loss, kmer_loss, ab_loss, kld_loss = th.train_unsupervised_decode(batch_idx)
                vae_epoch_losses["total"].append(total_loss)
                vae_epoch_losses["kmer"].append(kmer_loss)
                vae_epoch_losses["abundance"].append(ab_loss)
                vae_epoch_losses["kld"].append(kld_loss)
                vae_epoch_losses["scg"].append(diff_loss)
                vae_epoch_losses["gnn"].append(gnn_loss)
                gnn_loss = gnn_loss.numpy()
                diff_loss = diff_loss.numpy()
                pbar_vaebatch.set_description(f'E={e} L={np.mean(vae_epoch_losses["total"][-10:]):.4f}')
                step += 1

            vae_epoch_losses = {k: np.mean(v) for k, v in vae_epoch_losses.items()}
            log_to_tensorboard(summary_writer, vae_epoch_losses, step)
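            # Illustrative note (not in the original file): ceil division above
            # covers every index while never yielding an empty batch:
            #
            #   n, bs = 1030, 256
            #   n_batches = (n + bs - 1) // bs                            # -> 5
            #   sizes = [min(bs, n - b * bs) for b in range(n_batches)]   # [256, 256, 256, 256, 6]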
            mlflow.log_metrics(vae_epoch_losses, step=step)

            # NOTE: the eval-split branch here referenced th_vae, which is never
            # defined in this module (it comes from train_vae.py); held-out
            # evaluation is disabled to avoid a NameError.
            #if args.eval_split > 0:
            #    eval_mu, eval_logsigma = th_vae.encoder(X[eval_idx], training=False)
            #    eval_mse1, eval_mse2, eval_kld = th_vae.loss(X[eval_idx], eval_mu, eval_logsigma, vae=True, training=False)
            #    eval_loss = eval_mse1 + eval_mse2 - eval_kld
            #    log_to_tensorboard(summary_writer, {"eval_kmer": eval_mse2, "eval_ab": eval_mse1,
            #                                        "eval_kld": eval_kld, "eval loss": eval_loss}, step)
            eval_loss, eval_mse1, eval_mse2, eval_kld = 0, 0, 0, 0

            #gpu_mem_alloc = tf.config.experimental.get_memory_info('GPU:0')["peak"] / 1000000 if args.cuda else 0
            gpu_mem_alloc = tf.config.experimental.get_memory_usage('GPU:0') / 1000000 if args.cuda else 0
            if (e + 1) % RESULT_EVERY == 0 and e > args.evalskip:
                #gnn_input_features = features
                #node_new_features = encoder(th.gnn_model(features, None))[0]
                node_new_features = th.gnn_model.encode(features, adj)
                #node_new_features = th.gnn_model(features, None)
                node_new_features = node_new_features.numpy()
                weights = th.gnn_model.get_weights()
                best_hq, best_embs, best_epoch, scores, best_model, cluster_labels = eval_epoch(logger, summary_writer, node_new_features,
                                                                                                cluster_mask, weights, step, args, dataset, e, scores,
                                                                                                best_hq, best_embs, best_epoch, best_model)
                if args.quiet:
                    logger.info(f"--- EPOCH {e:d} ---")
                    logger.info(f"[{gname} {nlayers_gnn}l {pname}] L={gnn_loss:.3f} D={diff_loss:.3f} R={recon_loss:.3f} HQ={scores[-1]['hq']} BestHQ={best_hq} Best Epoch={best_epoch} Max GPU MB={gpu_mem_alloc:.1f}")
                    logger.info(str(scores[-1]))  # was str(stats), which logged the stale pre-training stats
                mlflow.log_metrics(scores[-1], step=step)

            losses_string = " ".join([f"{k}={v:.3f}" for k, v in vae_epoch_losses.items()])
            pbar_epoch.set_description(
                f"[{args.outname} {nlayers_gnn}l {pname}] {losses_string} HQ={scores[-1]['hq']} BestHQ={best_hq} Best Epoch={best_epoch} Max GPU MB={gpu_mem_alloc:.1f}"
            )
            total_loss = gnn_loss + diff_loss + recon_loss
            losses["gnn"].append(gnn_loss)
            losses["scg"].append(diff_loss)
            losses["ae"].append(recon_loss)
            losses["total"].append(total_loss)

        if best_embs is None:
            # assumes at least one periodic evaluation ran
            best_embs = node_new_features

        cluster_labels, stats, _, _ = compute_clusters_and_stats(
            best_embs, node_names, dataset, clustering=clustering, k=k,
            #cuda=args.cuda,
        )
        stats["epoch"] = e
        scores.append(stats)
        # get best stats:
        # if concat_features:  # use HQ
        hqs = [s["hq"] for s in scores]
        epoch_hqs = [s["epoch"] for s in scores]
        best_idx = np.argmax(hqs)
        mlflow.log_metrics(scores[best_idx], step=step + 1)
        # else:  # use F1
        #     f1s = [s["f1"] for s in scores]
        #     best_idx = np.argmax(f1s)
        # S.append(stats)
        S.append(scores[best_idx])
        logger.info(f">>> best epoch all contigs: {RESULT_EVERY + (best_idx * RESULT_EVERY)} : {stats} <<<")
        logger.info(f">>> best epoch: {RESULT_EVERY + (best_idx * RESULT_EVERY)} : {scores[best_idx]} <<<")
        with open(f"{dataset.name}_{gname}_{clustering}{k}_{nlayers_gnn}l_{pname}_results.tsv", "w") as f:
            f.write("@Version:0.9.0\n@SampleID:SAMPLEID\n@@SEQUENCEID\tBINID\n")
            for i in range(len(cluster_labels)):
                f.write(f"{node_names[i]}\t{cluster_labels[i]}\n")

        # plot edges vs trained embs
        #plot_edges_sim(best_vae_embs, dataset.adj_matrix, id_to_scg, "vae_")
        plot_edges_sim(best_embs, 
dataset.adj_matrix, id_to_scg, f"{args.outdir}/{args.outname}_posttrain_") 248 | return best_embs, scores[best_idx] 249 | -------------------------------------------------------------------------------- /src/graphmb/amber_eval.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from collections import defaultdict 3 | import itertools 4 | import pandas as pd 5 | from graphmb.contigsdataset import process_node_name 6 | 7 | 8 | def write_amber_bins(contig_to_bin, outputfile): 9 | with open(outputfile, "w") as f: 10 | f.write("#\n@Version:0.9.0\n@SampleID:SAMPLEID\n@@SEQUENCEID\tBINID\n") 11 | for c in contig_to_bin: 12 | f.write(f"{str(c)}\t{str(contig_to_bin[c])}\n") 13 | 14 | # adapt eval scripts from amber but simplified 15 | 16 | 17 | def load_binnings(file_path_query, assemblytype, columns=["SEQUENCEID", "BINID"]): 18 | # columns = ["SEQUENCEID", "BINID", "TAXID", "LENGTH", "_LENGTH"] 19 | # sample_id_to_query_df = OrderedDict() 20 | #for metadata in columns: 21 | # logging.getLogger('amber').info('Loading ' + metadata[2]['SAMPLEID']) 22 | # nrows = metadata[1] - metadata[0] + 1 23 | # col_indices = [k for k, v in metadata[3].items() if v in columns] 24 | # amber files start with 4 header lines 25 | df = pd.read_csv(file_path_query, sep="\t", comment="#", skiprows=3, header=0) # , usecols=col_indices) 26 | df = df.rename(columns={df.columns[i]: c for i, c in enumerate(columns)}, inplace=False) 27 | df = df.astype({'SEQUENCEID':'string'}) 28 | df["SEQUENCEID"]= df["SEQUENCEID"].apply(process_node_name, assembly_type=assemblytype) 29 | if "_LENGTH" in columns: 30 | df.rename(columns={"_LENGTH": "LENGTH"}, inplace=True) 31 | df["LENGTH"] = pd.to_numeric(df["LENGTH"]) 32 | return df 33 | 34 | 35 | def amber_eval(gs_path, bin_path, labels=["graphmb"], assemblytype="flye"): 36 | gs_df = load_binnings(gs_path, columns=["SEQUENCEID", "BINID", "_LENGTH"], assemblytype=assemblytype) 37 | bin_df = load_binnings(bin_path, assemblytype=assemblytype) 38 | # load_queries(gs_df, bin_df, labels=labels, options, options_gs) 39 | gs_df = gs_df[["SEQUENCEID", "BINID", "LENGTH"]].rename(columns={"LENGTH": "seq_length", 40 | "BINID": "genome_id", 41 | "SEQUENCEID": "SEQUENCEID"}) 42 | bin_df = bin_df[["SEQUENCEID", "BINID"]].rename(columns={"SEQUENCEID": "SEQUENCEID"}) 43 | gs_df["seq_length"] = pd.to_numeric(gs_df["seq_length"]) 44 | query_df = bin_df[["SEQUENCEID", "BINID"]] 45 | query_w_length = pd.merge(query_df, gs_df.drop_duplicates("SEQUENCEID"), on="SEQUENCEID", sort=False) 46 | query_w_length_no_dups = query_w_length.drop_duplicates("SEQUENCEID") 47 | gs_df_no_dups = gs_df.drop_duplicates("SEQUENCEID") 48 | percentage_of_assigned_bps = query_w_length_no_dups["seq_length"].sum() / gs_df_no_dups["seq_length"].sum() 49 | percentage_of_assigned_seqs = query_w_length_no_dups.shape[0] / gs_df_no_dups["SEQUENCEID"].shape[0] 50 | query_w_length_mult_seqs = query_df.reset_index().merge(gs_df, on="SEQUENCEID", sort=False) 51 | """if query_w_length.shape[0] < query_w_length_mult_seqs.shape[0]: 52 | query_w_length_mult_seqs.drop_duplicates(['index', 'genome_id'], inplace=True) 53 | confusion_df = query_w_length_mult_seqs.groupby(['BINID', 'genome_id'], sort=False).agg({'seq_length': 'sum', 'SEQUENCEID': 'count'}).rename(columns={'seq_length': 'genome_length', 'SEQUENCEID': 'genome_seq_counts'}) 54 | 55 | most_abundant_genome_df = confusion_df.loc[confusion_df.groupby('BINID', sort=False)['genome_length'].idxmax()] 56 | 
most_abundant_genome_df = most_abundant_genome_df.reset_index()[['BINID', 'genome_id']] 57 | 58 | matching_genomes_df = pd.merge(query_w_length_mult_seqs, most_abundant_genome_df, on=['BINID', 'genome_id']).set_index('index') 59 | query_w_length_mult_seqs.set_index('index', inplace=True) 60 | difference_df = query_w_length_mult_seqs.drop(matching_genomes_df.index).groupby(['index'], sort=False).first() 61 | query_w_length = pd.concat([matching_genomes_df, difference_df]) 62 | 63 | # Modify gs such that multiple binnings of the same sequence are not required 64 | matching_genomes_df = pd.merge(gs_df, query_w_length[['SEQUENCEID', 'genome_id']], on=['SEQUENCEID', 'genome_id']) 65 | matching_genomes_df = matching_genomes_df[['SEQUENCEID', 'genome_id', 'seq_length']].drop_duplicates(['SEQUENCEID', 'genome_id']) 66 | condition = gs_df_no_dups['SEQUENCEID'].isin(matching_genomes_df['SEQUENCEID']) 67 | difference_df = gs_df_no_dups[~condition] 68 | gs_df = pd.concat([difference_df, matching_genomes_df])""" 69 | 70 | # query_w_length_mult_seqs.reset_index(inplace=True) 71 | # query_w_length_mult_seqs = pd.merge(query_w_length_mult_seqs, most_abundant_genome_df, on=['BINID']) 72 | # grouped = query_w_length_mult_seqs.groupby(['index'], sort=False, as_index=False) 73 | # query_w_length = grouped.apply(lambda x: x[x['genome_id_x'] == x['genome_id_y'] if any(x['genome_id_x'] == x['genome_id_y']) else len(x) * [True]]) 74 | # query_w_length = query_w_length.groupby(['index'], sort=False).first().drop(columns='genome_id_y').rename(columns={'genome_id_x': 'genome_id'}) 75 | 76 | df = query_w_length 77 | 78 | confusion_df = ( 79 | query_w_length.groupby(["BINID", "genome_id"], sort=False) 80 | .agg({"seq_length": "sum", "SEQUENCEID": "count"}) 81 | .rename(columns={"seq_length": "genome_length", "SEQUENCEID": "genome_seq_counts"}) 82 | ) 83 | # self.confusion_df = confusion_df 84 | 85 | # rand_index_bp, adjusted_rand_index_bp = Metrics.compute_rand_index( 86 | # confusion_df, "BINID", "genome_id", "genome_length" 87 | # ) 88 | # rand_index_seq, adjusted_rand_index_seq = Metrics.compute_rand_index( 89 | # confusion_df, "BINID", "genome_id", "genome_seq_counts" 90 | # ) 91 | 92 | most_abundant_genome_df = ( 93 | confusion_df.loc[confusion_df.groupby("BINID", sort=False)["genome_length"].idxmax()] 94 | .reset_index() 95 | .set_index("BINID") 96 | ) 97 | 98 | query_w_length["seq_length_mean"] = query_w_length["seq_length"] 99 | 100 | precision_df = ( 101 | query_w_length.groupby("BINID", sort=False) 102 | .agg({"seq_length": "sum", "seq_length_mean": "mean", "SEQUENCEID": "count"}) 103 | .rename(columns={"seq_length": "total_length", "SEQUENCEID": "total_seq_counts"}) 104 | ) 105 | precision_df = pd.merge(precision_df, most_abundant_genome_df, on="BINID") 106 | precision_df.rename(columns={"genome_length": "tp_length", "genome_seq_counts": "tp_seq_counts"}, inplace=True) 107 | precision_df["precision_bp"] = precision_df["tp_length"] / precision_df["total_length"] 108 | precision_df["precision_seq"] = precision_df["tp_seq_counts"] / precision_df["total_seq_counts"] 109 | 110 | """if self.options.filter_tail_percentage: 111 | precision_df['total_length_pct'] = precision_df['total_length'] / precision_df['total_length'].sum() 112 | precision_df.sort_values(by='total_length', inplace=True) 113 | precision_df['cumsum_length_pct'] = precision_df['total_length_pct'].cumsum(axis=0) 114 | precision_df['precision_bp'].mask(precision_df['cumsum_length_pct'] <= self.options.filter_tail_percentage / 100, inplace=True) 115 
| precision_df['precision_seq'].mask(precision_df['precision_bp'].isna(), inplace=True) 116 | precision_df.drop(columns=['cumsum_length_pct', 'total_length_pct'], inplace=True) 117 | if self.options.genome_to_unique_common: 118 | precision_df = precision_df[~precision_df['genome_id'].isin(self.options.genome_to_unique_common)]""" 119 | 120 | precision_avg_bp = precision_df["precision_bp"].mean() 121 | precision_avg_bp_sem = precision_df["precision_bp"].sem() 122 | precision_avg_bp_var = precision_df["precision_bp"].var() 123 | precision_avg_seq = precision_df["precision_seq"].mean() 124 | precision_avg_seq_sem = precision_df["precision_seq"].sem() 125 | precision_weighted_bp = precision_df["tp_length"].sum() / precision_df["total_length"].sum() 126 | precision_weighted_seq = precision_df["tp_seq_counts"].sum() / precision_df["total_seq_counts"].sum() 127 | 128 | genome_sizes_df = ( 129 | gs_df.groupby("genome_id", sort=False) 130 | .agg({"seq_length": "sum", "SEQUENCEID": "count"}) 131 | .rename(columns={"seq_length": "length_gs", "SEQUENCEID": "seq_counts_gs"}) 132 | ) 133 | precision_df = ( 134 | precision_df.reset_index().join(genome_sizes_df, on="genome_id", how="left", sort=False).set_index("BINID") 135 | ) 136 | precision_df["recall_bp"] = precision_df["tp_length"] / precision_df["length_gs"] 137 | precision_df["recall_seq"] = precision_df["tp_seq_counts"] / precision_df["seq_counts_gs"] 138 | precision_df["rank"] = "NA" 139 | 140 | recall_df = confusion_df.loc[confusion_df.groupby("genome_id", sort=False)["genome_length"].idxmax()] 141 | recall_df = ( 142 | recall_df.reset_index().join(genome_sizes_df, on="genome_id", how="right", sort=False).set_index("BINID") 143 | ) 144 | recall_df.fillna({"genome_length": 0, "genome_seq_counts": 0}, inplace=True) 145 | recall_df["recall_bp"] = recall_df["genome_length"] / recall_df["length_gs"] 146 | recall_df["recall_seq"] = recall_df["genome_seq_counts"] / recall_df["seq_counts_gs"] 147 | 148 | recall_df = recall_df.join(precision_df[["total_length", "seq_length_mean"]], how="left", sort=False) 149 | 150 | # if self.options.genome_to_unique_common: 151 | # recall_df = recall_df[~recall_df["genome_id"].isin(self.options.genome_to_unique_common)] 152 | 153 | recall_avg_bp = recall_df["recall_bp"].mean() 154 | recall_avg_bp_var = recall_df["recall_bp"].var() 155 | recall_avg_bp_sem = recall_df["recall_bp"].sem() 156 | recall_avg_seq = recall_df["recall_seq"].mean() 157 | recall_avg_seq_sem = recall_df["recall_seq"].sem() 158 | recall_weighted_bp = recall_df["genome_length"].sum() / recall_df["length_gs"].sum() 159 | recall_weighted_seq = recall_df["genome_seq_counts"].sum() / recall_df["seq_counts_gs"].sum() 160 | 161 | # Compute recall as in CAMI 1 162 | """unmapped_genomes = set(gs_df["genome_id"].unique()) - set(precision_df["genome_id"].unique()) 163 | #if self.options.genome_to_unique_common: 164 | # unmapped_genomes -= set(self.options.genome_to_unique_common) 165 | num_unmapped_genomes = len(unmapped_genomes) 166 | prec_copy = precision_df.reset_index() 167 | if num_unmapped_genomes: 168 | prec_copy = prec_copy.reindex( 169 | prec_copy.index.tolist() + list(range(len(prec_copy), len(prec_copy) + num_unmapped_genomes)) 170 | ).fillna(0.0) 171 | self.metrics.recall_avg_bp_cami1 = prec_copy["recall_bp"].mean() 172 | self.metrics.recall_avg_seq_cami1 = prec_copy["recall_seq"].mean() 173 | self.metrics.recall_avg_bp_sem_cami1 = prec_copy["recall_bp"].sem() 174 | self.metrics.recall_avg_seq_sem_cami1 = prec_copy["recall_seq"].sem() 175 | 
self.metrics.recall_avg_bp_var_cami1 = prec_copy["recall_bp"].var()
    self.recall_df_cami1 = prec_copy"""
    # End Compute recall as in CAMI 1

    accuracy_bp = precision_df["tp_length"].sum() / recall_df["length_gs"].sum()
    accuracy_seq = precision_df["tp_seq_counts"].sum() / recall_df["seq_counts_gs"].sum()

    precision_df = precision_df.sort_values(by=["recall_bp"], axis=0, ascending=False)
    metrics = {
        "precision_avg_bp": precision_avg_bp,
        "precision_avg_bp_sem": precision_avg_bp_sem,
        "precision_avg_bp_var": precision_avg_bp_var,
        "precision_avg_seq": precision_avg_seq,
        "precision_avg_seq_sem": precision_avg_seq_sem,
        "precision_weighted_bp": precision_weighted_bp,
        "precision_weighted_seq": precision_weighted_seq,
        "recall_avg_bp": recall_avg_bp,
        "recall_avg_bp_sem": recall_avg_bp_sem,
        "recall_avg_bp_var": recall_avg_bp_var,
        "recall_avg_seq": recall_avg_seq,
        "recall_avg_seq_sem": recall_avg_seq_sem,
        "recall_weighted_bp": recall_weighted_bp,
        "recall_weighted_seq": recall_weighted_seq,
        "accuracy_bp": accuracy_bp,
        "accuracy_seq": accuracy_seq,
        "f1_avg_bp": (2 * precision_avg_bp * recall_avg_bp) / (precision_avg_bp + recall_avg_bp)
        if (precision_avg_bp + recall_avg_bp) > 0 else 0,
    }
    bins_eval = calc_num_recovered_genomes(precision_df, [0.9, 0.5], [0.05, 0.1])
    return metrics, bins_eval


def calc_num_recovered_genomes(bins, min_completeness, max_contamination):
    counts_list = []
    for x in itertools.product(min_completeness, max_contamination):
        count = bins[(bins["recall_bp"] > x[0]) & (bins["precision_bp"] > (1 - x[1]))].shape[0]
        counts_list.append(("> " + str(int(x[0] * 100)) + "% completeness", "< " + str(int(x[1] * 100)) + "%", count))

    pd_counts = pd.DataFrame(counts_list, columns=["Completeness", "Contamination", "count"])
    pd_counts = pd.pivot_table(
        pd_counts,
        values="count",
        index=["Contamination"],
        columns=["Completeness"],
    ).reset_index()
    return pd_counts
--------------------------------------------------------------------------------
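# Illustrative usage sketch (not part of the repository): amber_eval compares a
# produced contig-to-bin assignment against a gold standard, both in the AMBER
# TSV format written by write_amber_bins above. File names here are hypothetical:
#
#   from graphmb.amber_eval import amber_eval
#   metrics, bins_eval = amber_eval("gs_contig2bin.tsv", "my_contig2bin.tsv", assemblytype="flye")
#   print(metrics["f1_avg_bp"])   # F1 of the average bp-weighted precision and recall
#   print(bins_eval)              # counts of bins above completeness/contamination cutoffs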