├── .gitignore ├── .readthedocs.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── api.rst ├── changelog.md ├── cite.md ├── conf.py ├── examples.rst ├── explain.rst ├── index.rst ├── layers.rst ├── make.bat ├── quickstart.md ├── requirements.txt └── tutorial.ipynb ├── fastISM ├── __init__.py ├── change_range.py ├── fast_ism.py ├── fast_ism_utils.py ├── flatten_model.py ├── ism_base.py └── models │ ├── __init__.py │ ├── basset.py │ ├── bpnet.py │ ├── bpnet_dense.py │ └── factorized_basset.py ├── images ├── annotated_basset.pdf ├── logo.jpeg └── logo_1280x640.jpeg ├── notebooks ├── Akita.ipynb ├── BassetFast.ipynb ├── BassetTFKeras.ipynb ├── DeepSHAPBenchmark.ipynb ├── Enformer.ipynb ├── GradxInputBenchmark.ipynb ├── ISMBenchmark.ipynb ├── IntegratedGradientsBenchmark.ipynb ├── MaxBatchSize.ipynb ├── TimeBassetParts.ipynb ├── VaryParams.ipynb ├── colab │ └── DeepSEA.ipynb ├── seq_to_np.ipynb └── test.seq.txt ├── pyproject.toml └── test ├── context.py ├── test_cropping.py ├── test_custom_stop_layer.py ├── test_example_architectures.py ├── test_simple_multi_in_architectures.py ├── test_simple_nested_architectures.py ├── test_simple_single_in_multi_out_architectures.py ├── test_simple_single_in_single_out_architectures.py ├── test_simple_skip_conn_architectures.py └── test_unresolved.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # surag 10 | .vscode 11 | .DS_Store 12 | *.npy 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a 
template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Required 2 | version: 2 3 | 4 | # Build documentation in the docs/ directory with Sphinx 5 | sphinx: 6 | configuration: docs/conf.py 7 | 8 | formats: all 9 | 10 | python: 11 | version: 3.7 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [Unreleased] 4 | 5 | ## [0.5.0] - 2022-02-08 6 | 7 | ### Added 8 | - Cropping1D Support 9 | - User specified stop layers (undocumented) 10 | - Support for MultiHeadAttention layers 11 | 12 | ### Changed 13 | - Refinements to segmenting 14 | - Segment starting with see-through layers followed by Conv1Ds with valid padding are kept in one segment 15 | - Layers are duplicated with `from_config` and `get_config` 16 | - Generalized pooling layers and added ability to add custom pooling layers 17 | 18 | ### Fixed 19 | - Runs for batch size 1 20 | - Multi-input layers that had the same input twice (e.g. 
- Support for layers that depend on exact order of inputs, e.g. Subtract and Concat.
- Tested and working on architectures with skip connections
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![](https://github.com/kundajelab/fastISM/raw/docs/images/logo.jpeg)](https://github.com/kundajelab/fastISM) 2 | 3 | [![](https://img.shields.io/pypi/v/fastism.svg)](https://pypi.org/project/fastism/) [![](https://readthedocs.org/projects/fastism/badge/?version=latest)](https://fastism.readthedocs.io/en/latest/?badge=latest) 4 | 5 | # Quickstart 6 | 7 | A Keras implementation for fast in-silico saturated mutagenesis (ISM) for convolution-based architectures. It speeds up ISM by 10x or more by restricting computation to those regions of each layer that are affected by a mutation in the input. 8 | 9 | ## Installation 10 | 11 | Currently, fastISM is available to download from PyPI. Bioconda support is expected to be added in the future. fastISM requires TensorFlow 2.3.0 or above. 12 | ```bash 13 | pip install fastism 14 | ``` 15 | 16 | ## Usage 17 | 18 | fastISM provides a simple interface that takes as input Keras models. For any Keras ``model`` that takes in sequence as input of dimensions `(B, S, C)`, where 19 | - `B`: batch size 20 | - `S`: sequence length 21 | - `C`: number of characters in vocabulary (e.g. 
4 for DNA/RNA, 20 for proteins) 22 | 23 | Perform ISM as follows: 24 | 25 | ```python 26 | from fastism import FastISM 27 | 28 | fast_ism_model = FastISM(model) 29 | 30 | for seq_batch in sequences: 31 | # seq_batch has dim (B, S, C) 32 | ism_seq_batch = fast_ism_model(seq_batch) 33 | # ism_seq_batch has dim (B, S, num_outputs) 34 | ``` 35 | 36 | fastISM does a check for correctness when the model is initialised, which may take a few seconds depending on the size of your model. This ensures that the outputs of the model match that of an unoptimised implementation. You can turn it off as `FastISM(model, test_correctness=False)`. fastISM also supports introducing specific mutations, mutating different ranges of the input sequence, and models with multiple outputs. Check the [Examples](https://fastism.readthedocs.io/en/latest/examples.html) section of the documentation for more details. An executable tutorial is available on [Colab](https://colab.research.google.com/github/kundajelab/fastISM/blob/master/notebooks/colab/DeepSEA.ipynb). 37 | 38 | ## Benchmark 39 | You can estimate the speedup obtained by comparing with a naive implementation of ISM. 
40 | ```python 41 | # Test this code as is 42 | >>> from fastism import FastISM, NaiveISM 43 | >>> from fastism.models.basset import basset_model 44 | >>> import tensorflow as tf 45 | >>> import numpy as np 46 | >>> from time import time 47 | 48 | >>> model = basset_model(seqlen=1000) 49 | >>> naive_ism_model = NaiveISM(model) 50 | >>> fast_ism_model = FastISM(model) 51 | 52 | >>> def time_ism(m, x): 53 | t = time() 54 | o = m(x) 55 | print(time()-t) 56 | return o 57 | 58 | >>> x = tf.random.uniform((1024, 1000, 4), 59 | dtype=model.input.dtype) 60 | 61 | >>> naive_out = time_ism(naive_ism_model, x) 62 | 144.013728 63 | >>> fast_out = time_ism(fast_ism_model, x) 64 | 13.894407 65 | >>> np.allclose(naive_out, fast_out, atol=1e-6) 66 | True 67 | >>> np.allclose(fast_out, naive_out, atol=1e-6) 68 | True # np.allclose is not symmetric 69 | ``` 70 | 71 | See `notebooks/ISMBenchmark.ipynb` for benchmarking code that accounts for initial warm-up. 72 | 73 | ## Getting Help 74 | fastISM supports the most commonly used subset of Keras for biological sequence-based models. Occasionally, you may find that some of the layers used in your model are not supported by fastISM. Refer to the [Supported Layers](https://fastism.readthedocs.io/en/latest/layers.html) section in Documentation for instructions on how to incorporate custom layers. In a few cases, the fastISM model may fail correctness checks, indicating there are likely some issues in the fastISM code. In such cases or any other bugs, feel free to reach out to the author by posting an [Issue](https://github.com/kundajelab/fastISM/issues) on GitHub along with your architecture, and we'll try to work out a solution! 75 | 76 | ## Citation 77 | fastISM: Performant *in-silico* saturation mutagenesis for convolutional neural networks; Surag Nair, Avanti Shrikumar*, Jacob Schreiber*, Anshul Kundaje (Bioinformatics 2022) 78 | [http://doi.org/10.1093/bioinformatics/btac135](http://doi.org/10.1093/bioinformatics/btac135). 
\*equal contribution
13 | - Create a second “mutation propagation model” (referred to as ``fast_ism_model`` in the code) that largely resembles the original model, but incorporates as additional inputs the necessary flanking regions from outputs of the IntOut model on reference input sequences between segments. This is done in :py:func:`fastISM.fast_ism_utils.generate_fast_ism_model`. 14 | 15 | 16 | 2. For each batch of input sequences: 17 | 18 | - Run the ``intout_model`` on the sequences (unperturbed) and cache the intermediate outputs at the end of each segment. This is done in :py:func:`fastISM.fast_ism.FastISM.pre_change_range_loop_prep`. 19 | - For each positional mutation: 20 | 21 | - Introduce the mutation in the input sequences 22 | - Run the ``fast_ism_model`` feeding as input appropriate slices of the ``intout_model`` outputs. This is done in :py:func:`fastISM.fast_ism.FastISM.get_ith_output`. 23 | 24 | See :ref:`How fastISM Works ` for a more intuitive understanding of the algorithm. 25 | 26 | ism\_base module 27 | ------------------------ 28 | 29 | This module contains a :class:`base ISM ` class, from which the :class:`NaiveISM ` and :class:`FastISM ` classes inherit. It also includes implementation of :class:`NaiveISM `. 30 | 31 | .. automodule:: fastISM.ism_base 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | fast\_ism module 37 | ------------------------ 38 | 39 | This module contains the :class:`FastISM ` class. 40 | 41 | .. automodule:: fastISM.fast_ism 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | fast\_ism\_utils module 47 | ------------------------------- 48 | 49 | .. automodule:: fastISM.fast_ism_utils 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | change\_range module 55 | ---------------------------- 56 | 57 | .. 
automodule:: fastISM.change_range 58 | :members: 59 | :undoc-members: 60 | :show-inheritance: 61 | 62 | 63 | flatten\_model module 64 | ----------------------------- 65 | 66 | This module implements functions required to take an arbitrary Keras model and reduce them to a graph representation that is then manipulated by :mod:`fast_ism_utils `. 67 | 68 | .. automodule:: fastISM.flatten_model 69 | :members: 70 | :undoc-members: 71 | :show-inheritance: 72 | 73 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ../CHANGELOG.md -------------------------------------------------------------------------------- /docs/cite.md: -------------------------------------------------------------------------------- 1 | # Citation 2 | 3 | fastISM: Performant *in-silico* saturation mutagenesis for convolutional neural networks; Surag Nair, Avanti Shrikumar, Anshul Kundaje (Bioinformatics 2022) 4 | [http://doi.org/10.1093/bioinformatics/btac135](http://doi.org/10.1093/bioinformatics/btac135) 5 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | from fastISM.fast_ism_utils import * 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'fastISM' 21 | copyright = '2022, Kundaje Lab' 22 | author = 'Surag Nair' 23 | 24 | # -- General configuration --------------------------------------------------- 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be 27 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 28 | # ones. 29 | extensions = ['sphinx.ext.autodoc', 'recommonmark', 'nbsphinx'] 30 | 31 | 32 | # Add any paths that contain templates here, relative to this directory. 33 | templates_path = ['_templates'] 34 | 35 | # List of patterns, relative to source directory, that match files and 36 | # directories to ignore when looking for source files. 37 | # This pattern also affects html_static_path and html_extra_path. 38 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 39 | 40 | 41 | # -- Added by Surag --------------------------------------------------------- 42 | pygments_style = 'colorful' 43 | master_doc = 'index' 44 | 45 | def format_set(s): 46 | return ', '.join(sorted(s)) 47 | 48 | 49 | # supported layers 50 | rst_epilog = """ 51 | .. |SEETHRU| replace:: {} 52 | .. |AGG| replace:: {} 53 | .. |LOCAL| replace:: {} 54 | .. |STOP| replace:: {} 55 | .. |POOL| replace:: {} 56 | """.format(format_set(SEE_THROUGH_LAYERS), 57 | format_set(AGGREGATE_LAYERS), 58 | format_set(LOCAL_LAYERS), 59 | format_set(STOP_LAYERS), 60 | format_set(POOLING_LAYERS)) 61 | 62 | 63 | # -- Options for HTML output ------------------------------------------------- 64 | 65 | # The theme to use for HTML and HTML Help pages. See the documentation for 66 | # a list of builtin themes. 
fastISM provides a simple interface that takes as input a Keras model. For any Keras ``model`` that takes in sequence as input of dimensions ``(B, S, C)``, where
Where the base is the same as the mutation, the output is the same as for the unperturbed sequence. 33 | 34 | Alternate Ranges 35 | ---------------- 36 | By default, mutations are introduced at every single position in the input. You can also set a list of equal-sized ranges as input instead of single positions. Consider a model that takes as input 1000 length sequences, and we wish to introduce a specific mutation of length 3 in the central 150 positions: 37 | 38 | **TODO**: test this 39 | 40 | .. code-block:: python 41 | 42 | # specific mutation to introduce 43 | mut = [[0,0,0,1], 44 | [0,0,0,1], 45 | [0,0,0,1]] 46 | 47 | # ranges where mutation should be introduced 48 | mut_ranges = [(i,i+3) for i in range(425,575)] 49 | 50 | fast_ism_model = FastISM(model, 51 | change_ranges = mut_ranges) 52 | 53 | for seq_batch in sequences: 54 | ism_seq_batch = fast_ism_model(seq_batch, replace_with=mut) 55 | 56 | Multi-input Models 57 | ------------------ 58 | fastISM supports models which have other inputs in addition to the sequence input that is perturbed. These alternate inputs are assumed to stay constant through different perturbations of the primary sequence input. Consider the model below in which an addition vector is concatenated with the flattened sequence output: 59 | 60 | .. code-block:: python 61 | 62 | def get_model(): 63 | rna = tf.keras.Input((100,)) # non-sequence input 64 | seq = tf.keras.Input((100,4)) 65 | 66 | x = tf.keras.layers.Conv1D(20, 3)(seq) 67 | x = tf.keras.layers.Conv1D(20, 3)(x) 68 | x = tf.keras.layers.Flatten()(x) 69 | 70 | rna_fc = tf.keras.layers.Dense(10)(rna) 71 | 72 | x = tf.keras.layers.Concatenate()([x, rna_fc]) 73 | x = tf.keras.layers.Dense(10)(x) 74 | x = tf.keras.layers.Dense(1)(x) 75 | model = tf.keras.Model(inputs=[rna,seq], outputs=x) 76 | 77 | return model 78 | 79 | To inform fastISM that the second input is the primary sequence input that will be perturbed: 80 | 81 | .. 
res = res_block(input_shape=(108,20))
fastISM is based on the observation that neural networks spend the majority of their computation time in convolutional layers and that point mutations in the input sequence only affect a limited range of intermediate layers. As a result, most of the computation in ISM is redundant and avoiding it can result in significant speedups.
fastISM works by restricting computation in the convolution layers to only those positions that are affected by the mutation in the input. Since the most time is spent in convolution layers, fastISM avoids a major amount of redundant computation and speeds up ISM. See :ref:`API <api>` for more details on how this is achieved.
This section covers the layers that are currently supported by fastISM. fastISM supports a subset of layers in ``tf.keras.layers`` that are most commonly used for sequence-based models.
31 | 32 | **Supported**: 33 | |AGG| 34 | 35 | To add a custom aggregation layer: 36 | ``fastism.fast_ism_utils.AGGREGATE_LAYERS.add("YourLayer")`` 37 | 38 | .. _stop-layers: 39 | 40 | Stop Layers 41 | ----------- 42 | Layers after which output at ``i`` th position depends on inputs at most or all positions in the input. However, this is not strictly true for Flatten/Reshape, but it is assumed these are followed by Dense or similar. 43 | 44 | **Supported**: 45 | |STOP| 46 | 47 | To add a custom stop layer: 48 | ``fastism.fast_ism_utils.STOP_LAYERS.add("YourLayer")`` 49 | 50 | Pooling Layers 51 | -------------- 52 | Pooling layers are also Local Layers but are special since they are typically used to reduce the size of the input. 53 | 54 | **Supported**: 55 | |POOL| 56 | 57 | To add a custom pooling layer: 58 | ``fastism.fast_ism_utils.POOLING_LAYERS.add("YourLayer")`` 59 | 60 | Custom pooling layers must have the class attributes ``pool_size``, ``strides`` (which must be equal to ``pool_size``), ``padding`` (which must be ``valid``), ``data_format`` (which must be ``channels_last``). Here is an example of a custom pooling layer. 61 | 62 | .. 
code-block:: python 63 | 64 | class AttentionPooling1D(tf.keras.layers.Layer): 65 | # don't forget to add **kwargs 66 | def __init__(self, pool_size = 2, **kwargs): 67 | super().__init__() 68 | self.pool_size = pool_size 69 | 70 | # need for pooling layer 71 | self.strides = self.pool_size 72 | self.padding = "valid" # ensure it behaves like MaxPooling1D with valid padding 73 | self.data_format = "channels_last" 74 | 75 | def build(self, input_shape): 76 | _, length, num_features = input_shape 77 | self.w = self.add_weight( 78 | shape=(num_features, num_features), 79 | initializer="random_normal", 80 | trainable=True, 81 | ) 82 | 83 | # implement so that layer can be duplicated 84 | def get_config(self): 85 | config = super().get_config() 86 | config.update({ 87 | "pool_size": self.pool_size, 88 | "data_format": self.data_format, 89 | "strides": self.strides, 90 | "padding": self.padding 91 | }) 92 | return config 93 | 94 | def call(self, inputs): 95 | _, length, num_features = inputs.shape 96 | 97 | if length == None: # this can happen at when creating fast_ism_model 98 | return inputs # don't do anything for now 99 | 100 | inputs = tf.reshape( 101 | inputs, 102 | (-1, length // self.pool_size, self.pool_size, num_features)) 103 | 104 | return tf.reduce_sum( 105 | inputs * tf.nn.softmax(tf.matmul(inputs, self.w), axis=-2), 106 | axis=-2) 107 | 108 | 109 | Code adapted from `Enformer `_. Note that pooling layers can contain weights. 110 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | pydot==1.4.1 2 | tensorflow==2.5.3 3 | nbsphinx==0.7.1 4 | ipython==7.16.3 5 | -------------------------------------------------------------------------------- /docs/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | ../notebooks/colab/DeepSEA.ipynb -------------------------------------------------------------------------------- /fastISM/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_ism import FastISM 2 | from .ism_base import NaiveISM 3 | -------------------------------------------------------------------------------- /fastISM/change_range.py: -------------------------------------------------------------------------------- 1 | from math import ceil, floor 2 | 3 | 4 | def get_int_if_tuple(param, idx=0): 5 | if isinstance(param, tuple): 6 | return param[idx] 7 | return param 8 | 9 | 10 | def not_supported_error(message): 11 | raise NotImplementedError("""{} not supported yet, 
please post an Issue with 12 | your architecture and the authors will try their 13 | best to help you!""".format(message)) 14 | 15 | 16 | class ChangeRangesBase(): 17 | """ 18 | Base class for layer-specific computations of which indices of the output 19 | are changed when list of input changed indices are specified. Conversely, given 20 | output ranges of indices that need to be produced by the layer, compute the input 21 | ranges that will be required for the same. 22 | 23 | In addition, given an input.... 24 | 25 | TODO: document better and with examples! 26 | """ 27 | 28 | def __init__(self, config): 29 | self.config = config 30 | self.validate_config() 31 | 32 | def validate_config(self): 33 | pass 34 | 35 | def forward(self, input_seqlen, input_change_ranges): 36 | """ 37 | list of tuples. e.g. [(0,1), (1,2), (2,3)...] if single bp ISM 38 | """ 39 | pass 40 | 41 | def backward(self, output_select_ranges): 42 | pass 43 | 44 | @staticmethod 45 | def forward_compose(change_ranges_objects_list, input_seqlen, input_change_ranges): 46 | # multiple ChangeRanges objects (in order), e.g. in a segment with conv->maxpool 47 | 48 | # base case which should generally only happen with input tensor 49 | if len(change_ranges_objects_list) == 0: 50 | return input_change_ranges, (0, 0), input_seqlen, input_change_ranges 51 | 52 | if len(change_ranges_objects_list) == 1: 53 | return change_ranges_objects_list[0].forward(input_seqlen, input_change_ranges) 54 | 55 | # chain forwards 56 | seqlen = input_seqlen 57 | affected_range = input_change_ranges 58 | for change_range_object in change_ranges_objects_list: 59 | input_range_corrected, _, seqlen, affected_range = change_range_object.forward( 60 | seqlen, affected_range) 61 | 62 | # chain backwards 63 | # this computes the input region to entire segment that is required 64 | # to obtain the final affected_range output. 
This ensures the segment 65 | # does not need any SliceAssign internally 66 | for change_range_object in reversed(change_ranges_objects_list[:-1]): 67 | input_range_corrected = change_range_object.backward( 68 | input_range_corrected) 69 | 70 | # compute new offsets 71 | # they will be relative to initial offsets 72 | initial_input_range_corrected, initial_padding, \ 73 | _, _ = change_ranges_objects_list[0].forward( 74 | input_seqlen, input_change_ranges) 75 | 76 | # all other except the first must have 0 padding 77 | # as there is no provision to bad within a segment 78 | assert(all([x.forward(input_seqlen, input_change_ranges)[1] == (0, 0) 79 | for x in change_ranges_objects_list[1:]])) 80 | 81 | # modified range_corrected should span at least the initial range 82 | # unless a Cropping1D layer is present, in which case this does not 83 | # need to hold true (e.g. at the edges the initial range can be cropped out) 84 | # this may also happen if MaxPooling1D is placed in a segment of its own 85 | # e.g. 
segment 0 before any other conv -- this is not handled for now 86 | if not any([isinstance(x, Cropping1DChangeRanges) for 87 | x in change_ranges_objects_list]): 88 | assert(all([x_new <= x_old for (x_new, y_new), (x_old, y_old) in zip( 89 | input_range_corrected, initial_input_range_corrected)])) 90 | 91 | return input_range_corrected, initial_padding, seqlen, affected_range 92 | 93 | 94 | class Conv1DChangeRanges(ChangeRangesBase): 95 | def __init__(self, config): 96 | ChangeRangesBase.__init__(self, config) 97 | 98 | # in case dilation_rate > 1, compute effective kernel size 99 | self.effective_kernel_size = (get_int_if_tuple( 100 | config['kernel_size'])-1) * get_int_if_tuple(config['dilation_rate']) + 1 101 | 102 | # assuming "same" if not "valid" (checked in validate_config) 103 | # if valid and effective size is even then keras will pad with more zeros 104 | # on the right (used to be left before) 105 | self.padding_num = (0, 0) if config['padding'] == 'valid' else \ 106 | (floor((self.effective_kernel_size-1)/2), 107 | ceil((self.effective_kernel_size-1)/2)) 108 | 109 | def validate_config(self): 110 | if self.config['data_format'] != "channels_last": 111 | not_supported_error("data_format \"{}\"".format( 112 | self.config['data_format'])) 113 | 114 | strides = get_int_if_tuple(self.config['strides']) 115 | if strides != 1: 116 | not_supported_error("Conv1D strides!=1") 117 | 118 | if self.config['groups'] > 1: 119 | not_supported_error("Groups > 1") 120 | 121 | if self.config['padding'] not in ['valid', 'same']: 122 | not_supported_error( 123 | "Padding \"{}\" for Conv1D".format(self.config['padding'])) 124 | 125 | def forward(self, input_seqlen, input_change_ranges): 126 | # NB: returned input_range_corrected, offsets are wrt padded input 127 | # MAKE VERY CLEAR 128 | assert(all([(0 <= x < input_seqlen and 0 < y <= input_seqlen and y > x) 129 | for x, y in input_change_ranges])) 130 | seqlen_with_padding = input_seqlen + sum(self.padding_num) 131 | 132 | 
# assuming input ranges have same width 133 | if (len(set([y-x for x, y in input_change_ranges])) != 1): 134 | not_supported_error("Input Change Ranges of different sizes") 135 | 136 | # required input range will involve regions around input_change_range 137 | input_change_range_with_filter = [ 138 | (x-self.effective_kernel_size+1, y+self.effective_kernel_size-1) for 139 | x, y in input_change_ranges] 140 | 141 | # there will be self.padding_num[0] zeros in the beginning 142 | input_change_range_padded = [ 143 | (x+self.padding_num[0], y+self.padding_num[0]) for 144 | x, y in input_change_range_with_filter] 145 | 146 | # account for edge effects 147 | input_range_corrected = [] 148 | for x, y in input_change_range_padded: 149 | # this can happen e.g. in dilated convs where the effective 150 | # width gets as wide as input sequence 151 | if y-x > seqlen_with_padding: 152 | #import pdb;pdb.set_trace() 153 | x, y = 0, seqlen_with_padding 154 | if x < 0: 155 | x, y = 0, y-x 156 | elif y > seqlen_with_padding: 157 | x, y = x-(y-seqlen_with_padding), seqlen_with_padding 158 | 159 | input_range_corrected.append((x, y)) 160 | 161 | # follows from requirement above 162 | assert(len(set([y-x for x, y in input_range_corrected])) == 1) 163 | 164 | # corrected change ranges must include input_change_ranges 165 | assert([x_c <= x and y_c >= y for (x, y), (x_c, y_c) in zip( 166 | input_change_ranges, input_range_corrected)]) 167 | 168 | # output affected ranges 169 | output_affected_ranges = [(x, y-self.effective_kernel_size+1) for 170 | x, y in input_range_corrected] 171 | 172 | # output sequence length 173 | outseqlen = seqlen_with_padding - self.effective_kernel_size + 1 174 | 175 | return input_range_corrected, self.padding_num, outseqlen, output_affected_ranges 176 | 177 | def backward(self, output_select_ranges): 178 | assert(len(set([y-x for x, y in output_select_ranges])) == 1) 179 | 180 | ranges = [(x, y+self.effective_kernel_size-1) 181 | for x, y in 
output_select_ranges] 182 | 183 | assert(all([(x >= 0 and y >= 0 and y > x) 184 | for x, y in ranges])) 185 | 186 | return ranges 187 | 188 | 189 | class Pooling1DChangeRanges(ChangeRangesBase): 190 | def __init__(self, config): 191 | ChangeRangesBase.__init__(self, config) 192 | self.pool_size = get_int_if_tuple(self.config['pool_size']) 193 | self.strides = get_int_if_tuple(self.config['strides']) 194 | 195 | def validate_config(self): 196 | if self.config['data_format'] != "channels_last": 197 | not_supported_error("data_format \"{}\"".format( 198 | self.config['data_format'])) 199 | 200 | pool_size = get_int_if_tuple(self.config['pool_size']) 201 | strides = get_int_if_tuple(self.config['strides']) 202 | 203 | if pool_size != strides: 204 | not_supported_error("pool_size != strides") 205 | 206 | if self.config['padding'] != 'valid': 207 | not_supported_error( 208 | "Padding \"{}\" for Maxpooling1D".format(self.config['padding'])) 209 | 210 | def forward(self, input_seqlen, input_change_ranges): 211 | # assuming input ranges have same width 212 | if (len(set([y-x for x, y in input_change_ranges])) != 1): 213 | not_supported_error("Input Change Ranges of different sizes") 214 | 215 | # shift to edges of nearest maxpool block 216 | input_change_range_shifted = [(self.pool_size*(x//self.pool_size), 217 | self.pool_size*ceil(y/self.pool_size)) for 218 | x, y in input_change_ranges] 219 | 220 | # sizes can change, calculate maxwidth and set all to same 221 | maxwidth = max([y-x for x, y in input_change_range_shifted]) 222 | 223 | # set to same length 224 | input_range_corrected = [(x, x+maxwidth) if y <= input_seqlen 225 | else (y-maxwidth, y) 226 | for x, y in input_change_range_shifted] 227 | # NOTE: the below code ignores the last block when seqlen is not a multiple 228 | # of pool_size. This works only when padding == 'valid' and strides==pool_size. 
229 | input_range_corrected = [(x, y) if y <= input_seqlen else 230 | (x-self.pool_size, y-self.pool_size) for x, y in input_range_corrected] 231 | assert([y <= input_seqlen for _, y in input_range_corrected]) 232 | 233 | # corrected change ranges must include input_change_ranges 234 | assert([x_c <= x and y_c >= y for (x, y), (x_c, y_c) in zip( 235 | input_change_ranges, input_range_corrected)]) 236 | 237 | output_affected_ranges = [(x//self.pool_size, y//self.pool_size) for 238 | (x, y) in input_range_corrected] 239 | 240 | # output sequence length (assumes "valid" paddng) 241 | assert(self.config["padding"] == "valid") 242 | outseqlen = input_seqlen // self.pool_size 243 | 244 | # (0,0) for no padding -- this would change if padding="same" is allowed 245 | return input_range_corrected, (0, 0), outseqlen, output_affected_ranges 246 | 247 | def backward(self, output_select_ranges): 248 | assert(len(set([y-x for x, y in output_select_ranges])) == 1) 249 | return [(x*self.pool_size, y*self.pool_size) for x, y in output_select_ranges] 250 | 251 | 252 | class Cropping1DChangeRanges(ChangeRangesBase): 253 | def __init__(self, config): 254 | ChangeRangesBase.__init__(self, config) 255 | self.cropping = self.config['cropping'] 256 | 257 | def validate_config(self): 258 | # all configs accepted 259 | return True 260 | 261 | def forward(self, input_seqlen, input_change_ranges): 262 | # assuming input ranges have same width 263 | if (len(set([y-x for x, y in input_change_ranges])) != 1): 264 | not_supported_error("Input Change Ranges of different sizes") 265 | 266 | # push right if within left cropping 267 | input_range_corrected = [(x, y) if x >= self.cropping[0] else 268 | (self.cropping[0], self.cropping[0] + (y-x)) 269 | for x, y in input_change_ranges] 270 | 271 | # push left if within right cropping 272 | right_edge = (input_seqlen-self.cropping[1]) 273 | input_range_corrected = [(x, y) if y < right_edge else 274 | (max(right_edge - (y-x), self.cropping[0]), 275 | 
right_edge) 276 | for x, y in input_range_corrected] 277 | 278 | output_affected_ranges = [(x-self.cropping[0], y-self.cropping[0]) for 279 | (x, y) in input_range_corrected] 280 | 281 | outseqlen = input_seqlen - sum(self.cropping) 282 | 283 | assert(len(set([y-x for x, y in input_range_corrected])) == 1) 284 | 285 | # (0,0) for no padding 286 | return input_range_corrected, (0, 0), outseqlen, output_affected_ranges 287 | 288 | def backward(self, output_select_ranges): 289 | assert(len(set([y-x for x, y in output_select_ranges])) == 1) 290 | return [(x+self.cropping[0], y+self.cropping[0]) for x, y in output_select_ranges] 291 | -------------------------------------------------------------------------------- /fastISM/fast_ism.py: -------------------------------------------------------------------------------- 1 | from .ism_base import ISMBase, NaiveISM 2 | from .fast_ism_utils import generate_models 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | 8 | class FastISM(ISMBase): 9 | def __init__(self, model, seq_input_idx=0, change_ranges=None, 10 | early_stop_layers=None, test_correctness=True): 11 | super().__init__(model, seq_input_idx, change_ranges) 12 | 13 | self.output_nodes, self.intermediate_output_model, self.intout_output_tensors, \ 14 | self.fast_ism_model, self.input_specs = generate_models( 15 | self.model, self.seqlen, self.num_chars, self.seq_input_idx, 16 | self.change_ranges, early_stop_layers) 17 | 18 | self.intout_output_tensor_to_idx = { 19 | x: i for i, x in enumerate(self.intout_output_tensors)} 20 | 21 | if test_correctness: 22 | if not self.test_correctness(): 23 | raise ValueError("""Fast ISM model built is incorrect, likely 24 | due to internal errors. 
Please post an Issue 25 | with your architecture and the authors will 26 | try their best to help you!""") 27 | 28 | def pre_change_range_loop_prep(self, inp_batch, num_seqs): 29 | # run intermediate output on unperturbed sequence 30 | intout_output = self.intermediate_output_model( 31 | inp_batch, training=False) 32 | 33 | self.padded_inputs = self.prepare_intout_output( 34 | intout_output, num_seqs) # better name than padded? 35 | 36 | # return output on unperturbed input 37 | if self.num_outputs == 1: 38 | return intout_output[self.intout_output_tensor_to_idx[self.output_nodes[0]]] 39 | else: 40 | return [intout_output[self.intout_output_tensor_to_idx[self.output_nodes[i]]] for 41 | i in range(self.num_outputs)] 42 | 43 | def prepare_intout_output(self, intout_output, num_seqs): 44 | inputs = [] 45 | 46 | for input_spec in self.input_specs: 47 | if input_spec[0] == "SEQ_PERTURB": 48 | inputs.append(tf.tile(self.perturbation, [num_seqs, 1, 1])) 49 | elif input_spec[0] == "INTOUT_SEQ": 50 | # pad the output if required 51 | to_pad = intout_output[self.intout_output_tensor_to_idx[input_spec[1]['node']]] 52 | padded = tf.keras.layers.ZeroPadding1D( 53 | input_spec[1]['padding'])(to_pad) 54 | inputs.append(padded) 55 | elif input_spec[0] == "INTOUT_ALT": 56 | # descendant of alternate input -- copy through 57 | inputs.append( 58 | intout_output[self.intout_output_tensor_to_idx[input_spec[1]['node']]]) 59 | elif input_spec[0] == "OFFSET": 60 | # nothing for now, add i specific offset later 61 | inputs.append(None) 62 | else: 63 | raise ValueError( 64 | "{}: what is this input spec?".format(input_spec[0])) 65 | 66 | return inputs 67 | 68 | def run_model(self, inputs): 69 | return self.fast_ism_model(inputs, training=False) 70 | 71 | def get_ith_output(self, inp_batch, i, idxs_to_mutate): 72 | fast_ism_inputs = self.prepare_ith_input( 73 | self.padded_inputs, i, idxs_to_mutate) 74 | 75 | return self.run_model(fast_ism_inputs) 76 | 77 | def prepare_ith_input(self, 
padded_inputs, i, idxs_to_mutate): 78 | num_to_mutate = idxs_to_mutate.shape[0] 79 | inputs = [] 80 | 81 | for input_idx, input_spec in enumerate(self.input_specs): 82 | if input_spec[0] == "SEQ_PERTURB": 83 | inputs.append(padded_inputs[input_idx][:num_to_mutate]) 84 | elif input_spec[0] == "INTOUT_SEQ": 85 | # slice 86 | inputs.append( 87 | tf.gather(padded_inputs[input_idx][:, 88 | input_spec[1]['slices'][i][0]: input_spec[1]['slices'][i][1]], 89 | idxs_to_mutate)) 90 | elif input_spec[0] == "INTOUT_ALT": 91 | inputs.append( 92 | tf.gather(padded_inputs[input_idx], idxs_to_mutate)) 93 | elif input_spec[0] == "OFFSET": 94 | inputs.append(tf.convert_to_tensor(input_spec[1]['offsets'][i], dtype=tf.int64)) 95 | else: 96 | raise ValueError( 97 | "{}: what is this input spec?".format(input_spec[0])) 98 | 99 | return inputs 100 | 101 | def cleanup(self): 102 | # padded inputs no longer required, take up GPU memory 103 | # if not deleted, have led to memory leaks 104 | del self.padded_inputs 105 | 106 | def test_correctness(self, batch_size=10, replace_with=0, atol=1e-6): 107 | """ 108 | Verify that outputs are correct by matching with Naive ISM. Running on small 109 | examples so as to not take too long. 110 | 111 | Hence not comparing runtime against Naive ISM implementation, which requires 112 | bigger inputs to offset overheads. 113 | 114 | TODO: ensure generated data is on GPU already before calling either method (for speedup) 115 | """ 116 | 117 | # TODO: better way to do this? 
118 | naive_ism = NaiveISM(self.model, self.seq_input_idx, 119 | self.change_ranges) 120 | 121 | # test batch 122 | if self.num_inputs == 1: 123 | x = tf.constant(np.random.random( 124 | (batch_size,) + self.model.input_shape[1:]), 125 | dtype=self.model.inputs[self.seq_input_idx].dtype) 126 | else: 127 | x = [] 128 | for j in range(self.num_inputs): 129 | x.append( 130 | tf.constant(np.random.random( 131 | (batch_size,) + self.model.input_shape[j][1:]), 132 | dtype=self.model.inputs[j].dtype) 133 | ) 134 | 135 | naive_out = naive_ism(x, replace_with=replace_with) 136 | fast_out = self(x, replace_with=replace_with) 137 | 138 | if self.num_outputs == 1: 139 | return np.all(np.isclose(naive_out, fast_out, atol=atol)) 140 | else: 141 | return all([np.allclose(naive_out[j], fast_out[j], atol=atol) and 142 | np.allclose(fast_out[j], naive_out[j], atol=atol) for 143 | j in range(self.num_outputs)]) 144 | 145 | def time_batch(self, seq_batch): 146 | pass 147 | -------------------------------------------------------------------------------- /fastISM/flatten_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.util import nest 3 | from collections import defaultdict 4 | import pydot 5 | 6 | 7 | def is_input_layer(layer): 8 | """Checks if layer is an input layer 9 | 10 | :param layer: A Keras layer 11 | :type layer: tf.keras.layers 12 | :return: True if layer is input layer, else False 13 | :rtype: bool 14 | """ 15 | return isinstance(layer, tf.keras.layers.InputLayer) 16 | 17 | 18 | def strip_subgraph_names(name, subgraph_names): 19 | "subgraph_name1/subgraph_name2/layer/name -> layer/name" 20 | while name.split("/")[0] in subgraph_names: 21 | name = name[name.find("/")+1:] 22 | return name 23 | 24 | 25 | def node_is_layer(node_name): 26 | # if False then Tensor 27 | return node_name.startswith("LAYER") 28 | 29 | 30 | def list_replace(l, old, new): 31 | return [new if t == old else 
t for t in l] 32 | 33 | 34 | def is_bipartite(edges): 35 | # only layer->tensor/tensor->layer connections allowed 36 | for node in edges: 37 | is_layer = node_is_layer(node) 38 | for ngb in edges[node]: 39 | if node_is_layer(ngb) == is_layer: 40 | return False 41 | return True 42 | 43 | 44 | def is_consistent(edges, inbound_edges): 45 | inbound_edges_from_edges = defaultdict(set) 46 | for x in edges: 47 | for y in edges[x]: 48 | inbound_edges_from_edges[y].add(x) 49 | if set(inbound_edges_from_edges.keys()) != set(inbound_edges.keys()): 50 | return False 51 | 52 | for x in inbound_edges: 53 | if set(inbound_edges[x]) != inbound_edges_from_edges[x]: 54 | return False 55 | return True 56 | 57 | 58 | def get_flattened_graph(model, is_subgraph=False): 59 | """[summary] 60 | 61 | :param model: [description] 62 | :type model: [type] 63 | :param is_subgraph: [description], defaults to False 64 | :type is_subgraph: bool, optional 65 | :return: [description] 66 | :rtype: [type] 67 | """ 68 | # Inspired by: https://github.com/tensorflow/tensorflow/blob/b36436b/tensorflow/python/keras/utils/vis_utils.py#L70 69 | # Wrapper support like in model_to_dot?? 
70 | # MORE comments 71 | # gets rid of intermediate inputlayers and makes graph bipartite 72 | layers = model.layers 73 | nodes = dict() 74 | edges = defaultdict(list) 75 | inbound_edges = defaultdict(list) 76 | subgraph_names = set() 77 | 78 | if isinstance(model, tf.keras.Sequential): 79 | if not model.built: 80 | model.build() 81 | # same as in model_to_dot, without this the layers don't contain 82 | # the input layer for some reason 83 | layers = super(tf.keras.Sequential, model).layers 84 | 85 | for _, layer in enumerate(layers): 86 | layer_name = "LAYER/{}".format(layer.name) 87 | 88 | if isinstance(layer, tf.keras.Sequential) or layer.__class__.__name__=="Functional": 89 | subgraph_nodes, subgraph_edges, subgraph_inbound_edges, \ 90 | subsubgraph_names, _ = get_flattened_graph( 91 | layer, is_subgraph=True) 92 | 93 | nodes.update(subgraph_nodes) 94 | edges.update(subgraph_edges) 95 | inbound_edges.update(subgraph_inbound_edges) 96 | subgraph_names.add(layer.name) 97 | subgraph_names.update(subsubgraph_names) 98 | 99 | else: 100 | for o in nest.flatten(layer.output): 101 | nodes["TENSOR/{}".format(o.name)] = None 102 | 103 | if not (is_subgraph and isinstance(layer, tf.keras.layers.InputLayer)): 104 | # TBD if necessary 105 | nodes[layer_name] = layer 106 | # layer -> tensor edge (trivial) 107 | edges[layer_name].append("TENSOR/{}".format(o.name)) 108 | inbound_edges["TENSOR/{}".format(o.name) 109 | ].append(layer_name) 110 | 111 | # tensor -> inputLayer tensor edges 112 | # tensor -> Layer edges 113 | for _, layer in enumerate(layers): 114 | # len(layer.inbound_nodes) is > 1 when models are nested 115 | # however, it seems like all different layer.inbound_nodes[i].input_tensors 116 | # point to the same tensors, through different scopes 117 | # using the 1st seems to work along with stripping subgraph names 118 | # assert(len(layer.inbound_nodes) == 1) 119 | 120 | layer_input_tensors = [x.name for x in nest.flatten( 121 | layer.inbound_nodes[0].call_args)] 
#input_tensors)] 122 | # if inbound node comes from a subgraph, it will start with "subgraph_name/" 123 | # if it comes from subgraph within subgraph, it will start with "subgraph_name1/subgraph_name2/" 124 | # but "nodes" do not have subgraph_name in them 125 | for i in range(len(layer_input_tensors)): 126 | layer_input_tensors[i] = strip_subgraph_names( 127 | layer_input_tensors[i], subgraph_names) 128 | 129 | layer_input_tensors = [ 130 | "TENSOR/{}".format(x) for x in layer_input_tensors] 131 | 132 | assert(all([x in nodes for x in layer_input_tensors])) 133 | 134 | if isinstance(layer, tf.keras.Sequential) or layer.__class__.__name__=="Functional": 135 | layer_inputlayer_names = [ 136 | "TENSOR/{}".format(x.name) for x in layer.inputs] 137 | 138 | assert(all([x in nodes for x in layer_inputlayer_names])) 139 | assert(len(layer_input_tensors) == len(layer_inputlayer_names)) 140 | 141 | # assuming order of inputs is preserved 142 | # inbound_edges should store inputs in correct order for multi 143 | # input layers 144 | for x, y in zip(layer_input_tensors, layer_inputlayer_names): 145 | # transfering edges of y to x and deleting y 146 | for e in edges[y]: 147 | edges[x].append(e) 148 | # replace y by x in inbound_edges 149 | inbound_edges[e] = list_replace(inbound_edges[e], y, x) 150 | 151 | del edges[y] 152 | del nodes[y] 153 | 154 | elif not isinstance(layer, tf.keras.layers.InputLayer): 155 | layer_name = "LAYER/{}".format(layer.name) 156 | 157 | for x in layer_input_tensors: 158 | edges[x].append(layer_name) 159 | # this preserves order of inputs 160 | inbound_edges[layer_name].append(x) 161 | 162 | assert(is_bipartite(edges)) 163 | 164 | # ensure edges and inbound_edges agree 165 | assert(is_consistent(edges, inbound_edges)) 166 | 167 | # strip model output names 168 | output_nodes = ["TENSOR/{}".format(strip_subgraph_names(x.name, subgraph_names)) 169 | for x in model.outputs] 170 | assert(all([o in nodes for o in output_nodes])) 171 | 172 | return nodes, 
def viz_graph(nodes, edges, outpath):
    """Render the flattened bipartite graph to a PNG file via pydot.

    Args:
        nodes: mapping whose keys are node identifiers ("LAYER/..." or
            "TENSOR/..." strings) — only the keys are used here.
        edges: mapping from node identifier to a list of successor
            node identifiers.
        outpath: destination path for the rendered PNG image.
    """
    def _safe(identifier):
        # pydot/graphviz identifiers choke on ':' (present in tensor
        # names like "input:0"), so remap them to '/'.
        return identifier.replace(":", "/")

    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    # dot.set('concentrate', True)
    # dot.set_node_defaults(shape='record')
    dot.set('dpi', 96)

    for node_id in nodes:
        safe_id = _safe(node_id)
        dot.add_node(pydot.Node(safe_id, label=safe_id))

    for src, successors in edges.items():
        for dst in successors:
            dot.add_edge(pydot.Edge(_safe(src), _safe(dst)))

    dot.write(outpath, format='png')
    def __call__(self, inp_batch, replace_with=0):
        """Run in-silico mutagenesis over every change range of a batch.

        Args:
            inp_batch: a single batch tensor/array when the model has one
                input, else a list of per-input batches (sequence input at
                ``self.seq_input_idx``). Sequence input is assumed to be
                (batch, seqlen, num_chars) one-hot — matches ``__init__``'s
                use of shape[1]/shape[2].
            replace_with: value to substitute into each change range; 0
                means the all-zeros perturbation (see ``set_perturbation``).

        Returns:
            For a single-output model: an array of shape
            (batch, len(change_ranges), *output_dims), initialised with the
            unperturbed output and overwritten per perturbation. For a
            multi-output model: a list of such arrays, one per output.
        """
        self.set_perturbation(replace_with)

        if self.num_inputs == 1:
            num_seqs = inp_batch.shape[0]
        else:
            num_seqs = inp_batch[self.seq_input_idx].shape[0]

        # set up bookkeeping (subclass hook) and get output on unperturbed input
        unperturbed_output = self.pre_change_range_loop_prep(
            inp_batch, num_seqs)

        # set up ISM output tensors by initialising to unperturbed_output;
        # positions already equal to the perturbation are then left untouched
        if self.num_outputs == 1:
            # take off GPU
            unperturbed_output = unperturbed_output.numpy()

            # batch_size x num_perturb x output_dim
            ism_outputs = np.repeat(np.expand_dims(unperturbed_output, 1),
                                    len(self.change_ranges), 1)
        else:
            unperturbed_output = [x.numpy() for x in unperturbed_output]

            ism_outputs = []
            for j in range(self.num_outputs):
                ism_outputs.append(np.repeat(np.expand_dims(unperturbed_output[j], 1),
                                             len(self.change_ranges), 1))

        for i, change_range in enumerate(self.change_ranges):
            # only run model on seqs whose slice differs from the perturbation;
            # seqs that already match would produce the unperturbed output
            if self.num_inputs == 1:
                idxs_to_mutate = tf.squeeze(tf.where(tf.logical_not(tf.reduce_all(
                    inp_batch[:, change_range[0]:change_range[1]] == self.perturbation[0], axis=(1, 2)))),
                    axis=1)
            else:
                idxs_to_mutate = tf.squeeze(tf.where(tf.logical_not(tf.reduce_all(
                    inp_batch[self.seq_input_idx][:, change_range[0]:change_range[1]] == self.perturbation[0], axis=(1, 2)))),
                    axis=1)

            num_to_mutate = idxs_to_mutate.shape[0]
            if num_to_mutate > 0:
                # compute model output only on idxs_to_mutate (subclass hook)
                ism_ith_output = self.get_ith_output(inp_batch, i, idxs_to_mutate)

                if self.num_outputs == 1:
                    # NOTE(review): fancy-indexing a numpy array with a tf
                    # index tensor — relies on numpy coercing the tensor;
                    # presumably ism_ith_output is array-compatible here
                    ism_outputs[idxs_to_mutate, i] = ism_ith_output
                else:
                    for j in range(self.num_outputs):
                        ism_outputs[j][idxs_to_mutate,
                                       i] = ism_ith_output[j].numpy()

        # cleanup tensors that have been used (subclass hook; NaiveISM no-ops)
        self.cleanup()

        return ism_outputs
def basset_model(seqlen=1000, numchars=4, num_outputs=1, name='basset_model'):
    """Build the Basset architecture as a Keras functional model.

    Three conv/batch-norm/max-pool stages followed by three dense layers.

    Args:
        seqlen: input sequence length.
        numchars: alphabet size of the one-hot input (4 for DNA).
        num_outputs: width of the final dense output layer.
        name: name given to the returned ``tf.keras.Model``.

    Returns:
        A ``tf.keras.Model`` mapping (seqlen, numchars) -> (num_outputs,).
    """
    # (filters, kernel width, pool size, layer name) per conv-maxpool stage
    conv_stages = [
        (300, 19, 3, 'conv1'),
        (200, 11, 4, 'conv2'),
        (200, 7, 4, 'conv3'),
    ]

    inp = tf.keras.Input(shape=(seqlen, numchars))

    x = inp
    for filters, width, pool, conv_name in conv_stages:
        x = tf.keras.layers.Conv1D(
            filters, width, strides=1, padding='same', activation='relu',
            name=conv_name)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.MaxPool1D(pool)(x)

    # fully connected head
    x = tf.keras.layers.Flatten()(x)
    for fc_name in ('fc1', 'fc2'):
        x = tf.keras.layers.Dense(1000, activation='relu', name=fc_name)(x)
        x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(num_outputs, name='fc3')(x)

    return tf.keras.Model(inputs=inp, outputs=x, name=name)
def bpnet_model(seqlen=1000, numchars=4, num_dilated_convs=9, num_tasks=1,
                name='bpnet_model'):
    """Build a BPNet-style Keras model.

    Original architecture as per
    https://www.biorxiv.org/content/10.1101/737981v1.full.pdf

    Args:
        seqlen: input sequence length.
        numchars: alphabet size of the one-hot input (4 for DNA).
        num_dilated_convs: number of residual dilated conv layers
            (dilation rate doubles each layer).
        num_tasks: number of output heads; each task contributes a
            profile output and a counts output.
        name: name given to the returned ``tf.keras.Model``.

    Returns:
        A ``tf.keras.Model`` with 2 * num_tasks outputs
        (profile, counts per task, in that order).
    """
    # BUG FIX: input shape previously hard-coded 4 channels,
    # silently ignoring the `numchars` parameter.
    inp = tf.keras.layers.Input(shape=(seqlen, numchars))
    x = tf.keras.layers.Conv1D(
        64, kernel_size=25, padding='same', activation='relu')(inp)

    # residual stack of dilated convolutions
    for i in range(num_dilated_convs):
        conv_x = tf.keras.layers.Conv1D(
            64, kernel_size=3, padding='same', activation='relu',
            dilation_rate=2**i)(x)
        x = tf.keras.layers.Add()([conv_x, x])
    bottleneck = x

    # heads
    outputs = []
    for _ in range(num_tasks):
        # profile shape head: deconvolution over the bottleneck
        px = tf.keras.layers.Reshape((-1, 1, 64))(bottleneck)
        px = tf.keras.layers.Conv2DTranspose(
            1, kernel_size=(25, 1), padding='same')(px)
        outputs.append(tf.keras.layers.Flatten()(px))

        # total counts head
        cx = tf.keras.layers.GlobalAvgPool1D()(bottleneck)
        outputs.append(tf.keras.layers.Dense(1)(cx))

    # BUG FIX: the `name` parameter was previously accepted but never
    # forwarded to the Model constructor.
    model = tf.keras.Model(inputs=inp, outputs=outputs, name=name)

    return model
def bpnet_dense_model(inlen=1346, outlen=1000, filters=32, ndl=6):
    """
    Anna's BPNet architecture based on
    https://github.com/kundajelab/kerasAC/blob/1b483c4/kerasAC/architectures/profile_bpnet_chipseq.py

    Builds a BPNet-style model with dense residual connections: each dilated
    conv receives the sum of ALL previous (cropped) feature maps, and the
    model has a profile head and a counts head, each conditioned on a
    control (bias) input.

    Args:
        inlen: input sequence length; effective length is 2*(inlen//2).
        outlen: profile output length; effective length is 2*(outlen//2).
            A minimum difference of 346 is required between input and
            output length (consumed by the 'valid' convolutions).
        filters: number of filters in each convolution.
        ndl: number of dilated conv layers (dilation rate doubles per layer).

    Returns:
        A ``tf.keras.Model`` with inputs
        [sequence, control_profile, control_logcount] and outputs
        [profile_predictions, logcount_predictions].

    Note:
        This previously built an ``argparse.ArgumentParser`` internally and
        parsed an empty argv just to recover values derivable from the
        function's own parameters; every parsed value reduced to a constant
        or a simple expression, so the parser (and the unused seed /
        init_weights / loss-weight variables) has been removed. The
        constructed layer graph, layer names, and the fixed num_tasks=2
        are unchanged.
    """
    n_dil_layers = ndl
    conv1_kernel_size = 21
    profile_kernel_size = 75
    num_tasks = 2  # matches the old argparse default

    # derived lengths (even-rounded, exactly as the argparse path computed)
    sequence_flank = inlen // 2
    seq_len = 2 * sequence_flank
    out_flank = outlen // 2
    out_pred_len = 2 * out_flank

    # define inputs
    inp = Input(shape=(seq_len, 4), name='sequence')
    bias_counts_input = Input(shape=(num_tasks,), name='control_logcount')
    bias_profile_input = Input(
        shape=(out_pred_len, num_tasks), name='control_profile')

    # first convolution without dilation
    first_conv = Conv1D(filters,
                        kernel_size=conv1_kernel_size,
                        padding='valid',
                        activation='relu',
                        name='1st_conv')(inp)

    # dilated convolutions with dense resnet-style additions:
    # each layer receives the sum of feature maps from ALL previous layers
    res_layers = [(first_conv, '1stconv')]  # on a quest to have meaningful
    # layer names
    layer_names = [str(i) + "_dil" for i in range(n_dil_layers)]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            res_layers_sum = first_conv
        else:
            res_layers_sum = Add(name='add_{}'.format(i))(
                [l for l, _ in res_layers])

        # dilated convolution
        conv_layer_name = '{}conv'.format(layer_names[i-1])
        conv_output = Conv1D(filters,
                             kernel_size=3,
                             padding='valid',
                             activation='relu',
                             dilation_rate=2**i,
                             name=conv_layer_name)(res_layers_sum)

        # 'valid' padding shrinks the output; crop all previous layers
        # in the list to the latest layer's length so they can be summed
        conv_output_shape = int_shape(conv_output)
        cropped_layers = []
        for lyr, name in res_layers:
            lyr_shape = int_shape(lyr)
            cropsize = int(lyr_shape[1]/2) - int(conv_output_shape[1]/2)
            lyr_name = '{}-crop_{}th_dconv'.format(name.split('-')[0], i)
            cropped_layers.append((Cropping1D(cropsize,
                                              name=lyr_name)(lyr),
                                   lyr_name))

        # append to the list of previous layers
        cropped_layers.append((conv_output, conv_layer_name))
        res_layers = cropped_layers

    # the final output from the dilated convolutions
    # with resnet-style connections
    combined_conv = Add(name='combined_conv')([l for l, _ in res_layers])

    # Branch 1. Profile prediction
    # Step 1.1 - 1D convolution with a very large kernel
    profile_out_prebias = Conv1D(filters=num_tasks,
                                 kernel_size=profile_kernel_size,
                                 padding='valid',
                                 name='profile_out_prebias')(combined_conv)
    # Step 1.2 - Crop to match the required output size
    profile_out_prebias_shape = int_shape(profile_out_prebias)
    cropsize = int(profile_out_prebias_shape[1]/2) - int(out_pred_len/2)
    profile_out_prebias = Cropping1D(cropsize,
                                     name='prof_out_crop2match_output')(profile_out_prebias)
    # Step 1.3 - concatenate with the control profile
    concat_pop_bpi = Concatenate(axis=-1, name='concat_with_bias_prof')([profile_out_prebias,
                                                                         bias_profile_input])
    # Step 1.4 - Final 1x1 convolution
    profile_out = Conv1D(filters=num_tasks,
                         kernel_size=1,
                         name="profile_predictions")(concat_pop_bpi)

    # Branch 2. Counts prediction
    # Step 2.1 - Global average pooling along the "length"; result width
    # equals the `filters` parameter
    gap_combined_conv = GlobalAveragePooling1D(
        name='gap')(combined_conv)  # acronym - gapcc
    # Step 2.2 - Concatenate the output of GAP with bias counts
    concat_gapcc_bci = Concatenate(
        name="concat_with_bias_cnts", axis=-1)([gap_combined_conv, bias_counts_input])
    # Step 2.3 - Dense layer to predict final counts
    count_out = Dense(num_tasks, name="logcount_predictions")(concat_gapcc_bci)

    # instantiate keras Model with inputs and outputs
    model = Model(inputs=[inp, bias_profile_input, bias_counts_input],
                  outputs=[profile_out, count_out])

    # if only want counts output and input without control tracks
    #ct = Dense(1)(gap_combined_conv)
    #model = Model(inputs=inp, outputs=ct)

    return model
def factorized_basset_model(seqlen=1000, numchars=4, num_outputs=1, name='factorized_basset_model'):
    """Build the Factorized Basset architecture as a Keras functional model.

    Like Basset, but each large conv kernel is factorized into a stack of
    smaller conv/batch-norm/ReLU triples before each max-pool.

    Args:
        seqlen: input sequence length.
        numchars: alphabet size of the one-hot input (4 for DNA).
        num_outputs: width of the final dense output layer.
        name: name given to the returned ``tf.keras.Model``.

    Returns:
        A ``tf.keras.Model`` mapping (seqlen, numchars) -> (num_outputs,).
    """
    # per stage: list of (filters, kernel width, conv name), then pool size
    stages = [
        ([(48, 3, 'conv1a'), (64, 3, 'conv1b'), (100, 3, 'conv1c'),
          (150, 7, 'conv1d'), (300, 7, 'conv1e')], 3),
        ([(200, 7, 'conv2a'), (200, 3, 'conv2b'), (200, 3, 'conv2c')], 4),
        ([(200, 7, 'conv3')], 4),
    ]

    inp = tf.keras.Input(shape=(seqlen, numchars))

    x = inp
    for conv_specs, pool in stages:
        for filters, width, conv_name in conv_specs:
            x = tf.keras.layers.Conv1D(
                filters, width, padding='same', name=conv_name)(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.MaxPool1D(pool)(x)

    # fully connected head
    x = tf.keras.layers.Flatten()(x)
    for fc_name in ('fc1', 'fc2'):
        x = tf.keras.layers.Dense(1000, activation='relu', name=fc_name)(x)
        x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(num_outputs, name='fc3')(x)

    return tf.keras.Model(inputs=inp, outputs=x, name=name)
https://raw.githubusercontent.com/kundajelab/fastISM/4fc1f44b3107720f50f1c37f067df0ffe6974093/images/logo.jpeg -------------------------------------------------------------------------------- /images/logo_1280x640.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/fastISM/4fc1f44b3107720f50f1c37f067df0ffe6974093/images/logo_1280x640.jpeg -------------------------------------------------------------------------------- /notebooks/DeepSHAPBenchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DeepSHAP Benchmark" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys\n", 17 | "sys.path.append(\"../\")\n", 18 | "\n", 19 | "import fastISM\n", 20 | "from fastISM.models.basset import basset_model\n", 21 | "\n", 22 | "from fastISM.models.factorized_basset import factorized_basset_model\n", 23 | "from fastISM.models.bpnet import bpnet_model\n", 24 | "import tensorflow as tf\n", 25 | "import numpy as np\n", 26 | "from importlib import reload\n", 27 | "import time\n", 28 | "\n", 29 | "import shap" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "" 41 | ] 42 | }, 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "reload(fastISM.flatten_model)\n", 50 | "reload(fastISM.models)\n", 51 | "reload(fastISM.ism_base)\n", 52 | "reload(fastISM.change_range)\n", 53 | "reload(fastISM.fast_ism_utils)\n", 54 | "reload(fastISM)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "'2.3.0'" 66 
| ] 67 | }, 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "tf.__version__" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "Num GPUs Available: 1\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "'GPU:0'" 103 | ] 104 | }, 105 | "execution_count": 5, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "device = 'GPU:0' if tf.config.experimental.list_physical_devices('GPU') else '/device:CPU:0'\n", 112 | "device" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# https://github.com/kundajelab/tfmodisco_tf_models/blob/bd449328b/src/extract/dinuc_shuffle.py\n", 122 | "\n", 123 | "def string_to_char_array(seq):\n", 124 | " \"\"\"\n", 125 | " Converts an ASCII string to a NumPy array of byte-long ASCII codes.\n", 126 | " e.g. \"ACGT\" becomes [65, 67, 71, 84].\n", 127 | " \"\"\"\n", 128 | " return np.frombuffer(bytes(seq, \"utf8\"), dtype=np.int8)\n", 129 | "\n", 130 | "\n", 131 | "def char_array_to_string(arr):\n", 132 | " \"\"\"\n", 133 | " Converts a NumPy array of byte-long ASCII codes into an ASCII string.\n", 134 | " e.g. 
[65, 67, 71, 84] becomes \"ACGT\".\n", 135 | " \"\"\"\n", 136 | " return arr.tostring().decode(\"ascii\")\n", 137 | "\n", 138 | "\n", 139 | "def one_hot_to_tokens(one_hot):\n", 140 | " \"\"\"\n", 141 | " Converts an L x D one-hot encoding into an L-vector of integers in the range\n", 142 | " [0, D], where the token D is used when the one-hot encoding is all 0. This\n", 143 | " assumes that the one-hot encoding is well-formed, with at most one 1 in each\n", 144 | " column (and 0s elsewhere).\n", 145 | " \"\"\"\n", 146 | " tokens = np.tile(one_hot.shape[1], one_hot.shape[0]) # Vector of all D\n", 147 | " seq_inds, dim_inds = np.where(one_hot)\n", 148 | " tokens[seq_inds] = dim_inds\n", 149 | " return tokens\n", 150 | "\n", 151 | "\n", 152 | "def tokens_to_one_hot(tokens, one_hot_dim):\n", 153 | " \"\"\"\n", 154 | " Converts an L-vector of integers in the range [0, D] to an L x D one-hot\n", 155 | " encoding. The value `D` must be provided as `one_hot_dim`. A token of D\n", 156 | " means the one-hot encoding is all 0s.\n", 157 | " \"\"\"\n", 158 | " identity = np.identity(one_hot_dim + 1)[:, :-1] # Last row is all 0s\n", 159 | " return identity[tokens]\n", 160 | "\n", 161 | "\n", 162 | "def dinuc_shuffle(seq, num_shufs, rng=None):\n", 163 | " \"\"\"\n", 164 | " Creates shuffles of the given sequence, in which dinucleotide frequencies\n", 165 | " are preserved.\n", 166 | " Arguments:\n", 167 | " `seq`: either a string of length L, or an L x D NumPy array of one-hot\n", 168 | " encodings\n", 169 | " `num_shufs`: the number of shuffles to create, N\n", 170 | " `rng`: a NumPy RandomState object, to use for performing shuffles\n", 171 | " If `seq` is a string, returns a list of N strings of length L, each one\n", 172 | " being a shuffled version of `seq`. 
If `seq` is a 2D NumPy array, then the\n", 173 | " result is an N x L x D NumPy array of shuffled versions of `seq`, also\n", 174 | " one-hot encoded.\n", 175 | " \"\"\"\n", 176 | " if type(seq) is str:\n", 177 | " arr = string_to_char_array(seq)\n", 178 | " elif type(seq) is np.ndarray and len(seq.shape) == 2:\n", 179 | " seq_len, one_hot_dim = seq.shape\n", 180 | " arr = one_hot_to_tokens(seq)\n", 181 | " else:\n", 182 | " raise ValueError(\"Expected string or one-hot encoded array\")\n", 183 | "\n", 184 | " if not rng:\n", 185 | " rng = np.random.RandomState()\n", 186 | " \n", 187 | " # Get the set of all characters, and a mapping of which positions have which\n", 188 | " # characters; use `tokens`, which are integer representations of the\n", 189 | " # original characters\n", 190 | " chars, tokens = np.unique(arr, return_inverse=True)\n", 191 | "\n", 192 | " # For each token, get a list of indices of all the tokens that come after it\n", 193 | " shuf_next_inds = []\n", 194 | " for t in range(len(chars)):\n", 195 | " mask = tokens[:-1] == t # Excluding last char\n", 196 | " inds = np.where(mask)[0]\n", 197 | " shuf_next_inds.append(inds + 1) # Add 1 for next token\n", 198 | " \n", 199 | " if type(seq) is str:\n", 200 | " all_results = []\n", 201 | " else:\n", 202 | " all_results = np.empty(\n", 203 | " (num_shufs, seq_len, one_hot_dim), dtype=seq.dtype\n", 204 | " )\n", 205 | "\n", 206 | " for i in range(num_shufs):\n", 207 | " # Shuffle the next indices\n", 208 | " for t in range(len(chars)):\n", 209 | " inds = np.arange(len(shuf_next_inds[t]))\n", 210 | " inds[:-1] = rng.permutation(len(inds) - 1) # Keep last index same\n", 211 | " shuf_next_inds[t] = shuf_next_inds[t][inds]\n", 212 | "\n", 213 | " counters = [0] * len(chars)\n", 214 | " \n", 215 | " # Build the resulting array\n", 216 | " ind = 0\n", 217 | " result = np.empty_like(tokens)\n", 218 | " result[0] = tokens[ind]\n", 219 | " for j in range(1, len(tokens)):\n", 220 | " t = tokens[ind]\n", 221 | " 
ind = shuf_next_inds[t][counters[t]]\n", 222 | " counters[t] += 1\n", 223 | " result[j] = tokens[ind]\n", 224 | "\n", 225 | " if type(seq) is str:\n", 226 | " all_results.append(char_array_to_string(chars[result]))\n", 227 | " else:\n", 228 | " all_results[i] = tokens_to_one_hot(chars[result], one_hot_dim)\n", 229 | " return all_results" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 7, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# based on https://github.com/kundajelab/tfmodisco_tf_models/blob/bd449328b22/src/extract/compute_profile_shap.py\n", 239 | "\n", 240 | "def create_background(model_inputs, bg_size=10, seed=20191206):\n", 241 | " \"\"\"\n", 242 | " From a pair of single inputs to the model, generates the set of background\n", 243 | " inputs to perform interpretation against.\n", 244 | " Arguments:\n", 245 | " `model_inputs`: a pair of two entries; the first is a single one-hot\n", 246 | " encoded input sequence of shape I x 4; the second is the set of\n", 247 | " control profiles for the model, shaped T x O x 2\n", 248 | " `bg_size`: the number of background examples to generate.\n", 249 | " Returns a pair of arrays as a list, where the first array is G x I x 4, and\n", 250 | " the second array is G x T x O x 2; these are the background inputs. The\n", 251 | " background for the input sequences is randomly dinuceotide-shuffles of the\n", 252 | " original sequence. 
The background for the control profiles is the same as\n", 253 | " the originals.\n", 254 | " \"\"\"\n", 255 | " input_seq = model_inputs[0]\n", 256 | " rng = np.random.RandomState(seed)\n", 257 | " input_seq_bg = dinuc_shuffle(input_seq, bg_size, rng=rng)\n", 258 | " return input_seq_bg" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 8, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "data": { 268 | "text/plain": [ 269 | "'0.36.0'" 270 | ] 271 | }, 272 | "execution_count": 8, 273 | "metadata": {}, 274 | "output_type": "execute_result" 275 | } 276 | ], 277 | "source": [ 278 | "shap.__version__" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 9, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "shap.explainers._deep.deep_tf.op_handlers[\"AddV2\"] = shap.explainers._deep.deep_tf.passthrough" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "## Benchmark" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "### Basset/Factorized Basset" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 15, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "BATCH_SIZES = [1,32,64,128,256,512, 1024]" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 16, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "\n", 323 | "------------------\n", 324 | "MODEL: \n", 325 | "SEQLEN: 1000\n", 326 | "BATCH SIZE: 1\tTIME: 0.03\tPER 100: 2.56\n", 327 | "BATCH SIZE: 32\tTIME: 0.67\tPER 100: 2.09\n", 328 | "BATCH SIZE: 64\tTIME: 1.19\tPER 100: 1.86\n", 329 | "BATCH SIZE: 128\tTIME: 2.24\tPER 100: 1.75\n", 330 | "BATCH SIZE: 256\tTIME: 4.48\tPER 100: 1.75\n", 331 | "BATCH SIZE: 512\tTIME: 8.95\tPER 100: 1.75\n", 332 | "BATCH SIZE: 1024\tTIME: 17.97\tPER 100: 1.75\n", 333 | 
"BEST PER 100: 1.75\n", 334 | "\n", 335 | "------------------\n", 336 | "MODEL: \n", 337 | "SEQLEN: 2000\n", 338 | "BATCH SIZE: 1\tTIME: 0.03\tPER 100: 3.14\n", 339 | "BATCH SIZE: 32\tTIME: 0.97\tPER 100: 3.04\n", 340 | "BATCH SIZE: 64\tTIME: 1.95\tPER 100: 3.05\n", 341 | "BATCH SIZE: 128\tTIME: 3.89\tPER 100: 3.04\n", 342 | "BATCH SIZE: 256\tTIME: 7.79\tPER 100: 3.04\n", 343 | "BATCH SIZE: 512\tTIME: 15.61\tPER 100: 3.05\n", 344 | "BATCH SIZE: 1024\tTIME: 31.29\tPER 100: 3.06\n", 345 | "BEST PER 100: 3.04\n", 346 | "\n", 347 | "------------------\n", 348 | "MODEL: \n", 349 | "SEQLEN: 1000\n", 350 | "BATCH SIZE: 1\tTIME: 0.03\tPER 100: 2.69\n", 351 | "BATCH SIZE: 32\tTIME: 0.85\tPER 100: 2.64\n", 352 | "BATCH SIZE: 64\tTIME: 1.69\tPER 100: 2.64\n", 353 | "BATCH SIZE: 128\tTIME: 3.38\tPER 100: 2.64\n", 354 | "BATCH SIZE: 256\tTIME: 6.77\tPER 100: 2.64\n", 355 | "BATCH SIZE: 512\tTIME: 13.81\tPER 100: 2.70\n", 356 | "BATCH SIZE: 1024\tTIME: 27.64\tPER 100: 2.70\n", 357 | "BEST PER 100: 2.64\n", 358 | "\n", 359 | "------------------\n", 360 | "MODEL: \n", 361 | "SEQLEN: 2000\n", 362 | "BATCH SIZE: 1\tTIME: 0.05\tPER 100: 4.76\n", 363 | "BATCH SIZE: 32\tTIME: 1.50\tPER 100: 4.68\n", 364 | "BATCH SIZE: 64\tTIME: 2.97\tPER 100: 4.63\n", 365 | "BATCH SIZE: 128\tTIME: 5.97\tPER 100: 4.66\n", 366 | "BATCH SIZE: 256\tTIME: 11.98\tPER 100: 4.68\n", 367 | "BATCH SIZE: 512\tTIME: 23.94\tPER 100: 4.68\n", 368 | "BATCH SIZE: 1024\tTIME: 47.92\tPER 100: 4.68\n", 369 | "BEST PER 100: 4.63\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "# shap_values most likely internally creates a batch for each example\n", 375 | "# thus time per 100 examples stays near constant with batch size\n", 376 | "\n", 377 | "for model_type in [basset_model, factorized_basset_model]:\n", 378 | " for seqlen in [1000, 2000]:\n", 379 | " print(\"\\n------------------\")\n", 380 | " print(\"MODEL: {}\".format(model_type))\n", 381 | " print(\"SEQLEN: {}\".format(seqlen))\n", 382 | " model = 
model_type(seqlen=seqlen, num_outputs=1)\n", 383 | " \n", 384 | " # dry run \n", 385 | " e = shap.DeepExplainer(model, data=create_background)\n", 386 | " o = e.shap_values(np.random.random((10,seqlen,4)), check_additivity=False)\n", 387 | " \n", 388 | " times = []\n", 389 | " per_100 = []\n", 390 | " for b in BATCH_SIZES:\n", 391 | " x = np.random.random((b,seqlen,4))\n", 392 | " t = time.time()\n", 393 | " e.shap_values(x, check_additivity=False)\n", 394 | " times.append(time.time()-t)\n", 395 | " per_100.append((times[-1]/b)*100)\n", 396 | " print(\"BATCH SIZE: {}\\tTIME: {:.2f}\\tPER 100: {:.2f}\".format(b, times[-1], (times[-1]/b)*100))\n", 397 | " \n", 398 | " print(\"BEST PER 100: {:.2f}\".format(min(per_100)))" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "### BPNet" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 10, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "BATCH_SIZES = [1,8]" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 11, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "# linear ops\n", 424 | "shap.explainers._deep.deep_tf.op_handlers[\"BatchToSpaceND\"] = shap.explainers._deep.deep_tf.passthrough\n", 425 | "shap.explainers._deep.deep_tf.op_handlers[\"SpaceToBatchND\"] = shap.explainers._deep.deep_tf.passthrough\n", 426 | "shap.explainers._deep.deep_tf.op_handlers[\"Conv2DBackpropInput\"] = shap.explainers._deep.deep_tf.passthrough" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 12, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "# tensorflow throws warnings that stem from creating lots of explainers\n", 436 | "# suppress them\n", 437 | "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 13, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": 
"stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "\n", 450 | "------------------\n", 451 | "SEQLEN: 1000\n", 452 | "BATCH SIZE: 1\tTIME: 17.80\tPER 100: 1779.57\n", 453 | "BATCH SIZE: 8\tTIME: 139.47\tPER 100: 1743.35\n", 454 | "BEST PER 100: 1743.35\n", 455 | "\n", 456 | "------------------\n", 457 | "SEQLEN: 2000\n", 458 | "BATCH SIZE: 1\tTIME: 65.06\tPER 100: 6506.26\n", 459 | "BATCH SIZE: 8\tTIME: 514.22\tPER 100: 6427.77\n", 460 | "BEST PER 100: 6427.77\n" 461 | ] 462 | } 463 | ], 464 | "source": [ 465 | "for seqlen in [1000, 2000]:\n", 466 | " print(\"\\n------------------\")\n", 467 | " print(\"SEQLEN: {}\".format(seqlen))\n", 468 | " model = bpnet_model(seqlen=seqlen, num_dilated_convs=9)\n", 469 | "\n", 470 | " # run explainers for each position\n", 471 | " times = [0 for _ in BATCH_SIZES]\n", 472 | " for i in range(seqlen):\n", 473 | " e = shap.DeepExplainer((model.input, model.output[0][:,i]), data=create_background)\n", 474 | " # dry run\n", 475 | " o = e.shap_values(np.random.random((1,seqlen,4)), check_additivity=False)\n", 476 | " \n", 477 | " # batch sizes in inner loop to make explainers only once for diff batch sizes\n", 478 | " # making explainers is the bottleneck\n", 479 | " for b_idx, b in enumerate(BATCH_SIZES):\n", 480 | " x = np.random.random((b,seqlen,4))\n", 481 | "\n", 482 | " # time taken for this position (excluding time taken for setting up explainers) \n", 483 | " t = time.time()\n", 484 | " e.shap_values(x, check_additivity=False)\n", 485 | " times[b_idx] += time.time()-t\n", 486 | "\n", 487 | " # counts output\n", 488 | " e = shap.DeepExplainer((model.input, model.output[1]), data=create_background)\n", 489 | " # dry run\n", 490 | " o = e.shap_values(np.random.random((1,seqlen,4)), check_additivity=False)\n", 491 | "\n", 492 | " for b_idx, b in enumerate(BATCH_SIZES):\n", 493 | " x = np.random.random((b,seqlen,4))\n", 494 | "\n", 495 | " # time taken for this position (excluding time taken for setting up explainers) 
\n", 496 | " t = time.time()\n", 497 | " e.shap_values(x, check_additivity=False)\n", 498 | " times[b_idx] += time.time()-t\n", 499 | "\n", 500 | " per_100 = [(x/BATCH_SIZES[i])*100 for i,x in enumerate(times)]\n", 501 | " \n", 502 | " for i,x in enumerate(times): \n", 503 | " print(\"BATCH SIZE: {}\\tTIME: {:.2f}\\tPER 100: {:.2f}\".format(BATCH_SIZES[i], x, per_100[i]))\n", 504 | "\n", 505 | " print(\"BEST PER 100: {:.2f}\".format(min(per_100)))" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [] 514 | } 515 | ], 516 | "metadata": { 517 | "kernelspec": { 518 | "display_name": "Python 3", 519 | "language": "python", 520 | "name": "python3" 521 | }, 522 | "language_info": { 523 | "codemirror_mode": { 524 | "name": "ipython", 525 | "version": 3 526 | }, 527 | "file_extension": ".py", 528 | "mimetype": "text/x-python", 529 | "name": "python", 530 | "nbconvert_exporter": "python", 531 | "pygments_lexer": "ipython3", 532 | "version": "3.7.9" 533 | } 534 | }, 535 | "nbformat": 4, 536 | "nbformat_minor": 4 537 | } 538 | -------------------------------------------------------------------------------- /notebooks/GradxInputBenchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Grad x Input Benchmark" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys\n", 17 | "sys.path.append(\"../\")\n", 18 | "\n", 19 | "import fastISM\n", 20 | "from fastISM.models.basset import basset_model\n", 21 | "\n", 22 | "from fastISM.models.factorized_basset import factorized_basset_model\n", 23 | "from fastISM.models.bpnet import bpnet_model\n", 24 | "import tensorflow as tf\n", 25 | "import numpy as np\n", 26 | "from importlib import reload\n", 27 | "import time" 28 | ] 
29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "" 39 | ] 40 | }, 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "reload(fastISM.flatten_model)\n", 48 | "reload(fastISM.models)\n", 49 | "reload(fastISM.ism_base)\n", 50 | "reload(fastISM.change_range)\n", 51 | "reload(fastISM.fast_ism_utils)\n", 52 | "reload(fastISM)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "'2.3.0'" 64 | ] 65 | }, 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "tf.__version__" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Tue Sep 8 09:09:35 2020 \n", 85 | "+-----------------------------------------------------------------------------+\n", 86 | "| NVIDIA-SMI 450.51.05 Driver Version: 450.51.05 CUDA Version: 11.0 |\n", 87 | "|-------------------------------+----------------------+----------------------+\n", 88 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 89 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 90 | "| | | MIG M. |\n", 91 | "|===============================+======================+======================|\n", 92 | "| 0 Tesla P100-PCIE... On | 00000000:82:00.0 Off | 0 |\n", 93 | "| N/A 30C P0 27W / 250W | 0MiB / 16280MiB | 0% E. 
Process |\n", 94 | "| | | N/A |\n", 95 | "+-------------------------------+----------------------+----------------------+\n", 96 | " \n", 97 | "+-----------------------------------------------------------------------------+\n", 98 | "| Processes: |\n", 99 | "| GPU GI CI PID Type Process name GPU Memory |\n", 100 | "| ID ID Usage |\n", 101 | "|=============================================================================|\n", 102 | "| No running processes found |\n", 103 | "+-----------------------------------------------------------------------------+\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "!nvidia-smi" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-0d9a859c-ce19-78f3-2f87-aade11d14bae)\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "!nvidia-smi -L" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "nvcc: NVIDIA (R) Cuda compiler driver\n", 138 | "Copyright (c) 2005-2019 NVIDIA Corporation\n", 139 | "Built on Wed_Apr_24_19:10:27_PDT_2019\n", 140 | "Cuda compilation tools, release 10.1, V10.1.168\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "!nvcc --version" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "Num GPUs Available: 1\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | 
"'GPU:0'" 174 | ] 175 | }, 176 | "execution_count": 8, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "device = 'GPU:0' if tf.config.experimental.list_physical_devices('GPU') else '/device:CPU:0'\n", 183 | "device" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "## Benchmark" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "### Basset/Factorized Basset" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "BATCH_SIZES = [1,32,64,128,256]" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 10, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "\n", 219 | "------------------\n", 220 | "MODEL: \n", 221 | "SEQLEN: 1000\n", 222 | "BATCH SIZE: 1\tTIME: 0.03\tPER 100: 2.54\n", 223 | "BATCH SIZE: 32\tTIME: 0.02\tPER 100: 0.06\n", 224 | "BATCH SIZE: 64\tTIME: 0.03\tPER 100: 0.05\n", 225 | "BATCH SIZE: 128\tTIME: 0.05\tPER 100: 0.04\n", 226 | "BATCH SIZE: 256\tTIME: 0.10\tPER 100: 0.04\n", 227 | "BEST PER 100: 0.04\n", 228 | "\n", 229 | "------------------\n", 230 | "MODEL: \n", 231 | "SEQLEN: 2000\n", 232 | "BATCH SIZE: 1\tTIME: 0.01\tPER 100: 1.26\n", 233 | "BATCH SIZE: 32\tTIME: 0.03\tPER 100: 0.09\n", 234 | "BATCH SIZE: 64\tTIME: 0.05\tPER 100: 0.08\n", 235 | "BATCH SIZE: 128\tTIME: 0.10\tPER 100: 0.08\n", 236 | "BATCH SIZE: 256\tTIME: 0.20\tPER 100: 0.08\n", 237 | "BEST PER 100: 0.08\n", 238 | "\n", 239 | "------------------\n", 240 | "MODEL: \n", 241 | "SEQLEN: 1000\n", 242 | "BATCH SIZE: 1\tTIME: 0.03\tPER 100: 2.78\n", 243 | "BATCH SIZE: 32\tTIME: 0.04\tPER 100: 0.13\n", 244 | "BATCH SIZE: 64\tTIME: 0.06\tPER 100: 0.10\n", 245 | "BATCH SIZE: 128\tTIME: 0.12\tPER 100: 0.09\n", 246 | "BATCH SIZE: 256\tTIME: 0.23\tPER 
100: 0.09\n", 247 | "BEST PER 100: 0.09\n", 248 | "\n", 249 | "------------------\n", 250 | "MODEL: \n", 251 | "SEQLEN: 2000\n", 252 | "BATCH SIZE: 1\tTIME: 0.03\tPER 100: 2.74\n", 253 | "BATCH SIZE: 32\tTIME: 0.06\tPER 100: 0.20\n", 254 | "BATCH SIZE: 64\tTIME: 0.12\tPER 100: 0.18\n", 255 | "BATCH SIZE: 128\tTIME: 0.22\tPER 100: 0.18\n", 256 | "BATCH SIZE: 256\tTIME: 0.45\tPER 100: 0.17\n", 257 | "BEST PER 100: 0.17\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "# shap_values most likely internally creates a batch for each example\n", 263 | "# thus time per 100 examples stays near constant with batch size\n", 264 | "\n", 265 | "NUM_TO_AVG = 100\n", 266 | "\n", 267 | "for model_type in [basset_model, factorized_basset_model]:\n", 268 | " for seqlen in [1000, 2000]:\n", 269 | " print(\"\\n------------------\")\n", 270 | " print(\"MODEL: {}\".format(model_type))\n", 271 | " print(\"SEQLEN: {}\".format(seqlen))\n", 272 | " model = model_type(seqlen=seqlen, num_outputs=1)\n", 273 | " \n", 274 | " # dry run \n", 275 | " p = model(np.random.random((10,seqlen,4)))\n", 276 | " \n", 277 | " times = []\n", 278 | " per_100 = []\n", 279 | " for b in BATCH_SIZES:\n", 280 | " tot = 0\n", 281 | " for i in range(NUM_TO_AVG):\n", 282 | " x = np.random.random((b,seqlen,4))\n", 283 | " \n", 284 | " t = time.time()\n", 285 | " x = tf.constant(x)\n", 286 | " with tf.GradientTape() as tape:\n", 287 | " tape.watch(x)\n", 288 | " pred = model(x)\n", 289 | " g = (x*tape.gradient(pred, x)).numpy()\n", 290 | " \n", 291 | " tot+= time.time() - t\n", 292 | " \n", 293 | " times.append(tot/NUM_TO_AVG)\n", 294 | " per_100.append((times[-1]/b)*100)\n", 295 | " print(\"BATCH SIZE: {}\\tTIME: {:.2f}\\tPER 100: {:.2f}\".format(b, times[-1], (times[-1]/b)*100))\n", 296 | " \n", 297 | " print(\"BEST PER 100: {:.2f}\".format(min(per_100)))" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "### BPNet" 305 | ] 306 | }, 307 | { 308 | "cell_type": 
"code", 309 | "execution_count": 17, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "BATCH_SIZES = [64, 128]" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 18, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "\n", 326 | "------------------\n", 327 | "SEQLEN: 1000\n", 328 | "BATCH SIZE: 64\tTIME: 27.25\tPER 100: 42.58\n", 329 | "BATCH SIZE: 128\tTIME: 54.40\tPER 100: 42.50\n", 330 | "BEST PER 100: 42.50\n", 331 | "\n", 332 | "------------------\n", 333 | "SEQLEN: 2000\n", 334 | "BATCH SIZE: 64\tTIME: 80.90\tPER 100: 126.41\n", 335 | "BATCH SIZE: 128\tTIME: 164.60\tPER 100: 128.60\n", 336 | "BEST PER 100: 126.41\n" 337 | ] 338 | } 339 | ], 340 | "source": [ 341 | "for seqlen in [1000, 2000]:\n", 342 | " print(\"\\n------------------\")\n", 343 | " print(\"SEQLEN: {}\".format(seqlen))\n", 344 | " model = bpnet_model(seqlen=seqlen, num_dilated_convs=9)\n", 345 | "\n", 346 | " # run explainers for each position\n", 347 | " times = []\n", 348 | " per_100 = []\n", 349 | "\n", 350 | " # dry run \n", 351 | " p = model(np.random.random((10,seqlen,4)))\n", 352 | "\n", 353 | " for b_idx, b in enumerate(BATCH_SIZES):\n", 354 | " x = np.random.random((b,seqlen,4))\n", 355 | "\n", 356 | " t = time.time()\n", 357 | " x = tf.constant(x)\n", 358 | " g=[]\n", 359 | " \n", 360 | " with tf.GradientTape(persistent=True) as tape:\n", 361 | " tape.watch(x)\n", 362 | " prof, ct = model(x)\n", 363 | " prof = [prof[:,i:i+1] for i in range(seqlen)]\n", 364 | "\n", 365 | " for i in range(seqlen):\n", 366 | " g.append((x*tape.gradient(prof[i], x)).numpy())\n", 367 | " g.append((x*tape.gradient(ct, x)).numpy())\n", 368 | " times.append(time.time()-t)\n", 369 | "\n", 370 | " per_100.append((times[-1]/b)*100)\n", 371 | " print(\"BATCH SIZE: {}\\tTIME: {:.2f}\\tPER 100: {:.2f}\".format(b, times[-1], per_100[-1]))\n", 372 | "\n", 373 | " print(\"BEST PER 100: 
{:.2f}\".format(min(per_100)))" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [] 382 | } 383 | ], 384 | "metadata": { 385 | "kernelspec": { 386 | "display_name": "Python 3", 387 | "language": "python", 388 | "name": "python3" 389 | }, 390 | "language_info": { 391 | "codemirror_mode": { 392 | "name": "ipython", 393 | "version": 3 394 | }, 395 | "file_extension": ".py", 396 | "mimetype": "text/x-python", 397 | "name": "python", 398 | "nbconvert_exporter": "python", 399 | "pygments_lexer": "ipython3", 400 | "version": "3.7.9" 401 | } 402 | }, 403 | "nbformat": 4, 404 | "nbformat_minor": 4 405 | } 406 | -------------------------------------------------------------------------------- /notebooks/IntegratedGradientsBenchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Integrated Gradients Benchmark" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys\n", 17 | "sys.path.append(\"../\")\n", 18 | "\n", 19 | "import fastISM\n", 20 | "from fastISM.models.basset import basset_model\n", 21 | "\n", 22 | "from fastISM.models.factorized_basset import factorized_basset_model\n", 23 | "from fastISM.models.bpnet import bpnet_model\n", 24 | "import tensorflow as tf\n", 25 | "import numpy as np\n", 26 | "from importlib import reload\n", 27 | "import time" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "" 39 | ] 40 | }, 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "reload(fastISM.flatten_model)\n", 48 | "reload(fastISM.models)\n", 49 | "reload(fastISM.ism_base)\n", 50 | 
"reload(fastISM.change_range)\n", 51 | "reload(fastISM.fast_ism_utils)\n", 52 | "reload(fastISM)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "'2.3.0'" 64 | ] 65 | }, 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "tf.__version__" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Tue Sep 8 09:53:00 2020 \n", 85 | "+-----------------------------------------------------------------------------+\n", 86 | "| NVIDIA-SMI 450.51.05 Driver Version: 450.51.05 CUDA Version: 11.0 |\n", 87 | "|-------------------------------+----------------------+----------------------+\n", 88 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 89 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 90 | "| | | MIG M. |\n", 91 | "|===============================+======================+======================|\n", 92 | "| 0 Tesla P100-PCIE... On | 00000000:82:00.0 Off | 0 |\n", 93 | "| N/A 32C P0 27W / 250W | 0MiB / 16280MiB | 0% E. 
Process |\n", 94 | "| | | N/A |\n", 95 | "+-------------------------------+----------------------+----------------------+\n", 96 | " \n", 97 | "+-----------------------------------------------------------------------------+\n", 98 | "| Processes: |\n", 99 | "| GPU GI CI PID Type Process name GPU Memory |\n", 100 | "| ID ID Usage |\n", 101 | "|=============================================================================|\n", 102 | "| No running processes found |\n", 103 | "+-----------------------------------------------------------------------------+\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "!nvidia-smi" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-0d9a859c-ce19-78f3-2f87-aade11d14bae)\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "!nvidia-smi -L" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "nvcc: NVIDIA (R) Cuda compiler driver\n", 138 | "Copyright (c) 2005-2019 NVIDIA Corporation\n", 139 | "Built on Wed_Apr_24_19:10:27_PDT_2019\n", 140 | "Cuda compilation tools, release 10.1, V10.1.168\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "!nvcc --version" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "Num GPUs Available: 1\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | 
"'GPU:0'" 174 | ] 175 | }, 176 | "execution_count": 8, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "device = 'GPU:0' if tf.config.experimental.list_physical_devices('GPU') else '/device:CPU:0'\n", 183 | "device" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 9, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "import alibi\n", 193 | "from alibi.explainers import IntegratedGradients" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 10, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "data": { 203 | "text/plain": [ 204 | "'0.5.4'" 205 | ] 206 | }, 207 | "execution_count": 10, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "alibi.__version__" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## Benchmark" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 11, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def time_ig(model, batch_sizes, seqlen, num_examples=500, n_steps=50, targets = [None]):\n", 230 | " x = np.random.random((num_examples,seqlen,4))\n", 231 | " times = []\n", 232 | " per_100 = []\n", 233 | " for b in batch_sizes:\n", 234 | " ig = IntegratedGradients(model,\n", 235 | " layer=None,\n", 236 | " method=\"gausslegendre\",\n", 237 | " n_steps=n_steps,\n", 238 | " internal_batch_size=b)\n", 239 | " # dry run\n", 240 | " ig.explain(x[:10], baselines=None,\n", 241 | " target=targets[0])\n", 242 | " \n", 243 | " t = time.time()\n", 244 | " for tgt in targets:\n", 245 | " ig.explain(x, baselines=None,\n", 246 | " target=tgt)\n", 247 | " times.append(time.time()-t)\n", 248 | " per_100.append((times[-1]/num_examples)*100)\n", 249 | " print(\"BATCH: {}\\tTIME: {:.2f}\\tPER 100: {:.2f}\".format(b, times[-1], (times[-1]/num_examples)*100))\n", 250 | " \n", 251 | " print(\"BEST 
PER 100: {:.2f}\".format(min(per_100)))" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "### Basset (1000)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 12, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "model = basset_model(seqlen=1000, num_outputs=1)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 13, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "name": "stdout", 277 | "output_type": "stream", 278 | "text": [ 279 | "BATCH: 100\tTIME: 0.26\tPER 100: 2.62\n", 280 | "BATCH: 200\tTIME: 0.24\tPER 100: 2.37\n", 281 | "BATCH: 500\tTIME: 0.25\tPER 100: 2.50\n", 282 | "BEST PER 100: 2.37\n" 283 | ] 284 | } 285 | ], 286 | "source": [ 287 | "%%capture --no-stdout \n", 288 | "# hide warning about scalar output\n", 289 | "\n", 290 | "time_ig(model, [100, 200, 500], 1000, num_examples=10, targets=[None]) # targets None since only one scalar output" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 14, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "BATCH: 100\tTIME: 2.77\tPER 100: 2.77\n", 303 | "BATCH: 200\tTIME: 2.39\tPER 100: 2.39\n", 304 | "BATCH: 500\tTIME: 2.34\tPER 100: 2.34\n", 305 | "BATCH: 1000\tTIME: 3.28\tPER 100: 3.28\n", 306 | "BEST PER 100: 2.34\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "%%capture --no-stdout\n", 312 | "\n", 313 | "time_ig(model, [100, 200, 500, 1000], 1000, num_examples=100, targets=[None]) " 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "### Basset (2000)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 12, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "model = basset_model(seqlen=2000, num_outputs=1)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 
13, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "BATCH: 100\tTIME: 0.48\tPER 100: 4.80\n", 342 | "BATCH: 200\tTIME: 0.48\tPER 100: 4.76\n", 343 | "BATCH: 500\tTIME: 0.48\tPER 100: 4.82\n", 344 | "BEST PER 100: 4.76\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "%%capture --no-stdout\n", 350 | "\n", 351 | "time_ig(model, [100, 200, 500], 2000, num_examples=10, targets=[None]) " 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 14, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "BATCH: 100\tTIME: 4.77\tPER 100: 4.77\n", 364 | "BATCH: 200\tTIME: 4.63\tPER 100: 4.63\n", 365 | "BATCH: 500\tTIME: 4.61\tPER 100: 4.61\n", 366 | "BEST PER 100: 4.61\n" 367 | ] 368 | } 369 | ], 370 | "source": [ 371 | "%%capture --no-stdout\n", 372 | "\n", 373 | "time_ig(model, [100, 200, 500], 2000, num_examples=100, targets=[None]) " 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": {}, 379 | "source": [ 380 | "### Factorized Basset (1000)" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 12, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "model = factorized_basset_model(seqlen=1000, num_outputs=1)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 13, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "name": "stdout", 399 | "output_type": "stream", 400 | "text": [ 401 | "BATCH: 100\tTIME: 0.54\tPER 100: 5.40\n", 402 | "BATCH: 200\tTIME: 0.51\tPER 100: 5.14\n", 403 | "BATCH: 500\tTIME: 0.50\tPER 100: 5.00\n", 404 | "BEST PER 100: 5.00\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "%%capture --no-stdout\n", 410 | "time_ig(model, [100, 200, 500], 1000, num_examples=10, targets=[None])" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 14, 416 | "metadata": {}, 417 | 
"outputs": [ 418 | { 419 | "name": "stdout", 420 | "output_type": "stream", 421 | "text": [ 422 | "BATCH: 100\tTIME: 5.18\tPER 100: 5.18\n", 423 | "BATCH: 200\tTIME: 4.92\tPER 100: 4.92\n", 424 | "BATCH: 500\tTIME: 4.82\tPER 100: 4.82\n", 425 | "BEST PER 100: 4.82\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "%%capture --no-stdout\n", 431 | "time_ig(model, [100, 200, 500], 1000, num_examples=100, targets=[None])" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "### Factorized Basset (2000)" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 12, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "model = factorized_basset_model(seqlen=2000, num_outputs=1)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 13, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "name": "stdout", 457 | "output_type": "stream", 458 | "text": [ 459 | "BATCH: 100\tTIME: 0.99\tPER 100: 9.89\n", 460 | "BATCH: 200\tTIME: 0.97\tPER 100: 9.73\n", 461 | "BATCH: 300\tTIME: 0.95\tPER 100: 9.47\n", 462 | "BEST PER 100: 9.47\n" 463 | ] 464 | } 465 | ], 466 | "source": [ 467 | "%%capture --no-stdout\n", 468 | "time_ig(model, [100, 200, 300], 2000, num_examples=10, targets=[None])" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 14, 474 | "metadata": {}, 475 | "outputs": [ 476 | { 477 | "name": "stdout", 478 | "output_type": "stream", 479 | "text": [ 480 | "BATCH: 100\tTIME: 9.90\tPER 100: 9.90\n", 481 | "BATCH: 200\tTIME: 9.56\tPER 100: 9.56\n", 482 | "BATCH: 300\tTIME: 9.53\tPER 100: 9.53\n", 483 | "BEST PER 100: 9.53\n" 484 | ] 485 | } 486 | ], 487 | "source": [ 488 | "%%capture --no-stdout\n", 489 | "time_ig(model, [100, 200, 300], 2000, num_examples=100, targets=[None])" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 15, 495 | "metadata": {}, 496 | "outputs": [ 497 | { 498 | "name": "stdout", 499 | 
"output_type": "stream", 500 | "text": [ 501 | "BATCH: 100\tTIME: 19.84\tPER 100: 9.92\n", 502 | "BATCH: 200\tTIME: 19.14\tPER 100: 9.57\n", 503 | "BATCH: 300\tTIME: 19.07\tPER 100: 9.54\n", 504 | "BEST PER 100: 9.54\n" 505 | ] 506 | } 507 | ], 508 | "source": [ 509 | "%%capture --no-stdout\n", 510 | "time_ig(model, [100, 200, 300], 2000, num_examples=200, targets=[None])" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "### BPNet (1000)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 12, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "model = bpnet_model(seqlen=1000, num_dilated_convs=9)\n", 527 | "\n", 528 | "# flatten and concat outputs\n", 529 | "inp = tf.keras.Input(shape=model.input_shape[1:])\n", 530 | "prof, cts = model(inp)\n", 531 | "prof = tf.keras.layers.Flatten()(prof)\n", 532 | "cts = tf.keras.layers.Flatten()(cts)\n", 533 | "out = tf.keras.layers.Concatenate()([prof, cts])\n", 534 | "model_ig = tf.keras.Model(inputs=inp, outputs=out)\n", 535 | "\n", 536 | "# flattened outputs\n", 537 | "model = model_ig" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 13, 543 | "metadata": {}, 544 | "outputs": [ 545 | { 546 | "data": { 547 | "text/plain": [ 548 | "" 549 | ] 550 | }, 551 | "execution_count": 13, 552 | "metadata": {}, 553 | "output_type": "execute_result" 554 | } 555 | ], 556 | "source": [ 557 | "model.output" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 14, 563 | "metadata": {}, 564 | "outputs": [ 565 | { 566 | "name": "stdout", 567 | "output_type": "stream", 568 | "text": [ 569 | "BATCH: 500\tTIME: 439.95\tPER 100: 4399.53\n", 570 | "BEST PER 100: 4399.53\n" 571 | ] 572 | } 573 | ], 574 | "source": [ 575 | "time_ig(model, [500], 1000, num_examples=10, targets=range(1001)) # all 1000 profile outs + 1 count out" 576 | ] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "metadata": {}, 
581 | "source": [ 582 | "### BPNet (2000)" 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": 12, 588 | "metadata": {}, 589 | "outputs": [], 590 | "source": [ 591 | "model = bpnet_model(seqlen=2000, num_dilated_convs=9)\n", 592 | "\n", 593 | "# flatten and concat outputs\n", 594 | "inp = tf.keras.Input(shape=model.input_shape[1:])\n", 595 | "prof, cts = model(inp)\n", 596 | "prof = tf.keras.layers.Flatten()(prof)\n", 597 | "cts = tf.keras.layers.Flatten()(cts)\n", 598 | "out = tf.keras.layers.Concatenate()([prof, cts])\n", 599 | "model_ig = tf.keras.Model(inputs=inp, outputs=out)\n", 600 | "\n", 601 | "# flattened outputs\n", 602 | "model = model_ig" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 13, 608 | "metadata": {}, 609 | "outputs": [ 610 | { 611 | "data": { 612 | "text/plain": [ 613 | "" 614 | ] 615 | }, 616 | "execution_count": 13, 617 | "metadata": {}, 618 | "output_type": "execute_result" 619 | } 620 | ], 621 | "source": [ 622 | "model.output" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 14, 628 | "metadata": {}, 629 | "outputs": [ 630 | { 631 | "name": "stdout", 632 | "output_type": "stream", 633 | "text": [ 634 | "BATCH: 500\tTIME: 622.02\tPER 100: 12440.47\n", 635 | "BEST PER 100: 12440.47\n" 636 | ] 637 | } 638 | ], 639 | "source": [ 640 | "time_ig(model, [500], 2000, num_examples=5, targets=range(2001)) # all 2000 profile outs + 1 count out" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": null, 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [] 649 | } 650 | ], 651 | "metadata": { 652 | "kernelspec": { 653 | "display_name": "Python 3", 654 | "language": "python", 655 | "name": "python3" 656 | }, 657 | "language_info": { 658 | "codemirror_mode": { 659 | "name": "ipython", 660 | "version": 3 661 | }, 662 | "file_extension": ".py", 663 | "mimetype": "text/x-python", 664 | "name": "python", 665 | "nbconvert_exporter": "python", 
666 | "pygments_lexer": "ipython3", 667 | "version": "3.7.9" 668 | } 669 | }, 670 | "nbformat": 4, 671 | "nbformat_minor": 4 672 | } 673 | -------------------------------------------------------------------------------- /notebooks/TimeBassetParts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Time Basset Parts\n", 8 | "\n", 9 | "Estimate how much time is spent in convolutional vs fc layers for Basset." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import tensorflow as tf\n", 19 | "import numpy as np" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "text/plain": [ 30 | "'2.3.0'" 31 | ] 32 | }, 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "output_type": "execute_result" 36 | } 37 | ], 38 | "source": [ 39 | "tf.__version__" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Num GPUs Available: 1\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "'GPU:0'" 68 | ] 69 | }, 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "device = 'GPU:0' if tf.config.experimental.list_physical_devices('GPU') else '/device:CPU:0'\n", 77 | "device" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Conv Layers" 85 | ] 86 | }, 87 | { 88 | "cell_type": 
"code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "inp = tf.keras.Input(shape=(1000, 4))\n", 94 | "\n", 95 | "# conv mxp 1\n", 96 | "x = tf.keras.layers.Conv1D(\n", 97 | " 300, 19, strides=1, padding='same', activation='relu', name='conv1')(inp)\n", 98 | "x = tf.keras.layers.BatchNormalization()(x)\n", 99 | "x = tf.keras.layers.MaxPool1D(3)(x)\n", 100 | "\n", 101 | "# conv mxp 2\n", 102 | "x = tf.keras.layers.Conv1D(\n", 103 | " 200, 11, strides=1, padding='same', activation='relu', name='conv2')(x)\n", 104 | "x = tf.keras.layers.BatchNormalization()(x)\n", 105 | "x = tf.keras.layers.MaxPool1D(4)(x)\n", 106 | "\n", 107 | "# conv mxp 3\n", 108 | "x = tf.keras.layers.Conv1D(\n", 109 | " 200, 7, strides=1, padding='same', activation='relu', name='conv3')(x)\n", 110 | "x = tf.keras.layers.BatchNormalization()(x)\n", 111 | "x = tf.keras.layers.MaxPool1D(4)(x)\n", 112 | "\n", 113 | "# fc\n", 114 | "x = tf.keras.layers.Flatten()(x)\n", 115 | "\n", 116 | "# sum it up\n", 117 | "x = tf.keras.backend.sum(x, axis= -1, keepdims=True)\n", 118 | "\n", 119 | "conv_model = tf.keras.Model(inp, x)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "'/job:localhost/replica:0/task:0/device:GPU:0'" 131 | ] 132 | }, 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "x = tf.constant(np.random.random((512,1000,4)))*1\n", 140 | "x.device " 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 7, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "" 162 | ] 163 | }, 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | "# dry run\n", 171 | "conv_model(x[:10], training=False)" 172 | ] 173 | }, 174 | { 175 | 
"cell_type": "code", 176 | "execution_count": 9, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "75.9 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "%timeit conv_model(x, training=False)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## FC Layers" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 10, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "inp = tf.keras.Input(shape=(4000))\n", 205 | "\n", 206 | "x = tf.keras.layers.Dense(1000, activation='relu', name='fc1')(inp)\n", 207 | "x = tf.keras.layers.BatchNormalization()(x)\n", 208 | "x = tf.keras.layers.Dense(1000, activation='relu', name='fc2')(x)\n", 209 | "x = tf.keras.layers.BatchNormalization()(x)\n", 210 | "x = tf.keras.layers.Dense(1, name='fc3')(x)\n", 211 | "\n", 212 | "fc_model = tf.keras.Model(inputs=inp, outputs=x)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 11, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "data": { 222 | "text/plain": [ 223 | "'/job:localhost/replica:0/task:0/device:GPU:0'" 224 | ] 225 | }, 226 | "execution_count": 11, 227 | "metadata": {}, 228 | "output_type": "execute_result" 229 | } 230 | ], 231 | "source": [ 232 | "x = tf.constant(np.random.random((512, 4000)))*1\n", 233 | "x.device " 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 12, 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "data": { 243 | "text/plain": [ 244 | "" 255 | ] 256 | }, 257 | "execution_count": 12, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "# dry run\n", 264 | "fc_model(x[:10], training=False)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 
273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "1.91 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "%timeit fc_model(x, training=False)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "## Full Model" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 15, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "inp = tf.keras.Input(shape=(1000, 4))\n", 298 | "\n", 299 | "# conv mxp 1\n", 300 | "x = tf.keras.layers.Conv1D(\n", 301 | " 300, 19, strides=1, padding='same', activation='relu', name='conv1')(inp)\n", 302 | "x = tf.keras.layers.BatchNormalization()(x)\n", 303 | "x = tf.keras.layers.MaxPool1D(3)(x)\n", 304 | "\n", 305 | "# conv mxp 2\n", 306 | "x = tf.keras.layers.Conv1D(\n", 307 | " 200, 11, strides=1, padding='same', activation='relu', name='conv2')(x)\n", 308 | "x = tf.keras.layers.BatchNormalization()(x)\n", 309 | "x = tf.keras.layers.MaxPool1D(4)(x)\n", 310 | "\n", 311 | "# conv mxp 3\n", 312 | "x = tf.keras.layers.Conv1D(\n", 313 | " 200, 7, strides=1, padding='same', activation='relu', name='conv3')(x)\n", 314 | "x = tf.keras.layers.BatchNormalization()(x)\n", 315 | "x = tf.keras.layers.MaxPool1D(4)(x)\n", 316 | "\n", 317 | "# fc\n", 318 | "x = tf.keras.layers.Flatten()(x)\n", 319 | "x = tf.keras.layers.Dense(1000, activation='relu', name='fc1')(x)\n", 320 | "x = tf.keras.layers.BatchNormalization()(x)\n", 321 | "x = tf.keras.layers.Dense(1000, activation='relu', name='fc2')(x)\n", 322 | "x = tf.keras.layers.BatchNormalization()(x)\n", 323 | "x = tf.keras.layers.Dense(1, name='fc3')(x)\n", 324 | "\n", 325 | "\n", 326 | "full_model = tf.keras.Model(inp, x)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 17, 332 | "metadata": {}, 333 | "outputs": [ 334 | { 335 | "data": { 336 | "text/plain": [ 337 | 
"'/job:localhost/replica:0/task:0/device:GPU:0'" 338 | ] 339 | }, 340 | "execution_count": 17, 341 | "metadata": {}, 342 | "output_type": "execute_result" 343 | } 344 | ], 345 | "source": [ 346 | "x = tf.constant(np.random.random((512,1000,4)))*1\n", 347 | "x.device " 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 18, 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "text/plain": [ 358 | "" 369 | ] 370 | }, 371 | "execution_count": 18, 372 | "metadata": {}, 373 | "output_type": "execute_result" 374 | } 375 | ], 376 | "source": [ 377 | "# dry run\n", 378 | "full_model(x[:10], training=False)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 19, 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "77.2 ms ± 66.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 391 | ] 392 | } 393 | ], 394 | "source": [ 395 | "%timeit full_model(x, training=False)" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [] 404 | } 405 | ], 406 | "metadata": { 407 | "kernelspec": { 408 | "display_name": "Python 3", 409 | "language": "python", 410 | "name": "python3" 411 | }, 412 | "language_info": { 413 | "codemirror_mode": { 414 | "name": "ipython", 415 | "version": 3 416 | }, 417 | "file_extension": ".py", 418 | "mimetype": "text/x-python", 419 | "name": "python", 420 | "nbconvert_exporter": "python", 421 | "pygments_lexer": "ipython3", 422 | "version": "3.7.8" 423 | } 424 | }, 425 | "nbformat": 4, 426 | "nbformat_minor": 4 427 | } 428 | -------------------------------------------------------------------------------- /notebooks/seq_to_np.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Seq to Numpy" 8 
| ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 5, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "f = open('test.seq.txt')\n", 26 | "d = [x.strip() for x in f]\n", 27 | "f.close()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 5, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "def dna_to_one_hot(seqs):\n", 37 | " \"\"\"\n", 38 | " Converts a list of DNA (\"ACGT\") sequences to one-hot encodings, where the\n", 39 | " position of 1s is ordered alphabetically by \"ACGT\". `seqs` must be a list\n", 40 | " of N strings, where every string is the same length L. Returns an N x L x 4\n", 41 | " NumPy array of one-hot encodings, in the same order as the input sequences.\n", 42 | " All bases will be converted to upper-case prior to performing the encoding.\n", 43 | " Any bases that are not \"ACGT\" will be given an encoding of all 0s.\n", 44 | " \"\"\"\n", 45 | " seq_len = len(seqs[0])\n", 46 | " assert np.all(np.array([len(s) for s in seqs]) == seq_len)\n", 47 | "\n", 48 | " # Join all sequences together into one long string, all uppercase\n", 49 | " seq_concat = \"\".join(seqs).upper()\n", 50 | "\n", 51 | " one_hot_map = np.identity(5)[:, :-1]\n", 52 | "\n", 53 | " # Convert string into array of ASCII character codes;\n", 54 | " base_vals = np.frombuffer(bytearray(seq_concat, \"utf8\"), dtype=np.int8)\n", 55 | "\n", 56 | " # Anything that's not an A, C, G, or T gets assigned a higher code\n", 57 | " base_vals[~np.isin(base_vals, np.array([65, 67, 71, 84]))] = 85\n", 58 | "\n", 59 | " # Convert the codes into indices in [0, 4], in ascending order by code\n", 60 | " _, base_inds = np.unique(base_vals, return_inverse=True)\n", 61 | "\n", 62 | " # Get the one-hot encoding for those indices, and reshape back to separate\n", 63 | " return 
one_hot_map[base_inds].reshape((len(seqs), seq_len, 4))" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 6, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "seqs = dna_to_one_hot(d)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 7, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "np.save(\"test.seq.npy\", seqs)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [] 90 | } 91 | ], 92 | "metadata": { 93 | "kernelspec": { 94 | "display_name": "Python 3", 95 | "language": "python", 96 | "name": "python3" 97 | }, 98 | "language_info": { 99 | "codemirror_mode": { 100 | "name": "ipython", 101 | "version": 3 102 | }, 103 | "file_extension": ".py", 104 | "mimetype": "text/x-python", 105 | "name": "python", 106 | "nbconvert_exporter": "python", 107 | "pygments_lexer": "ipython3", 108 | "version": "3.7.9" 109 | } 110 | }, 111 | "nbformat": 4, 112 | "nbformat_minor": 4 113 | } 114 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "fastism" 3 | version = "0.5.1" 4 | description = "Fast In-silico Mutagenesis for Convolution-based Neural Networks" 5 | authors = ["Surag Nair "] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.6" 10 | tensorflow = "^2.3.0" 11 | pydot = "^1.4.1" 12 | 13 | [tool.poetry.dev-dependencies] 14 | 15 | [build-system] 16 | requires = ["poetry>=0.12"] 17 | build-backend = "poetry.masonry.api" 18 | -------------------------------------------------------------------------------- /test/context.py: -------------------------------------------------------------------------------- 1 | # based on https://docs.python-guide.org/writing/structure/#test-suite 2 | import os 3 | import sys 4 | sys.path.insert(0, 
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | 6 | import fastISM -------------------------------------------------------------------------------- /test/test_cropping.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | 7 | class TestCropping(unittest.TestCase): 8 | # many tests modified from test_simple_skip_conn_arcitectures.py 9 | def test_conv_crop_even_fc(self): 10 | # inp -> C -> Crop -> D -> y 11 | inp = tf.keras.Input((100, 4)) 12 | x = tf.keras.layers.Conv1D(20, 3)(inp) 13 | x = tf.keras.layers.Cropping1D((4,4))(x) 14 | x = tf.keras.layers.Flatten()(x) 15 | x = tf.keras.layers.Dense(1)(x) 16 | model = tf.keras.Model(inputs=inp, outputs=x) 17 | 18 | fast_ism_model = fastISM.FastISM( 19 | model, test_correctness=False) 20 | 21 | self.assertTrue(fast_ism_model.test_correctness()) 22 | 23 | def test_conv_crop_odd_fc(self): 24 | # inp -> C -> Crop -> D -> y 25 | inp = tf.keras.Input((100, 4)) 26 | x = tf.keras.layers.Conv1D(20, 3)(inp) 27 | x = tf.keras.layers.Cropping1D((5,3))(x) 28 | x = tf.keras.layers.Flatten()(x) 29 | x = tf.keras.layers.Dense(1)(x) 30 | model = tf.keras.Model(inputs=inp, outputs=x) 31 | 32 | fast_ism_model = fastISM.FastISM( 33 | model, test_correctness=False) 34 | 35 | self.assertTrue(fast_ism_model.test_correctness()) 36 | 37 | def test_conv_crop_even_odd_fc(self): 38 | # inp -> C -> Crop -> D -> y 39 | inp = tf.keras.Input((100, 4)) 40 | x = tf.keras.layers.Conv1D(20, 3)(inp) 41 | x = tf.keras.layers.Cropping1D((4,5))(x) 42 | x = tf.keras.layers.Flatten()(x) 43 | x = tf.keras.layers.Dense(1)(x) 44 | model = tf.keras.Model(inputs=inp, outputs=x) 45 | 46 | fast_ism_model = fastISM.FastISM( 47 | model, test_correctness=False) 48 | 49 | self.assertTrue(fast_ism_model.test_correctness()) 50 | 51 | def test_conv_double_crop_fc(self): 52 | # inp -> C -> Crop -> Crop -> D -> y 53 | inp = 
tf.keras.Input((100, 4)) 54 | x = tf.keras.layers.Conv1D(20, 3)(inp) 55 | x = tf.keras.layers.Cropping1D((2,0))(x) 56 | x = tf.keras.layers.Cropping1D((1,3))(x) 57 | x = tf.keras.layers.Flatten()(x) 58 | x = tf.keras.layers.Dense(1)(x) 59 | model = tf.keras.Model(inputs=inp, outputs=x) 60 | 61 | fast_ism_model = fastISM.FastISM( 62 | model, test_correctness=False) 63 | 64 | self.assertTrue(fast_ism_model.test_correctness()) 65 | 66 | def test_conv_crop_add_two_fc(self): 67 | # inp -> C -> C-> Add -> D -> y 68 | # |_Crop__^ 69 | inp = tf.keras.Input((100, 4)) 70 | x = tf.keras.layers.Conv1D(20, 3)(inp) 71 | x1 = tf.keras.layers.Conv1D(20, 3, padding='valid')(x) 72 | x = tf.keras.layers.Cropping1D((1,1))(x) 73 | x = tf.keras.layers.Add()([x, x1]) 74 | x = tf.keras.layers.Flatten()(x) 75 | y = tf.keras.layers.Dense(1)(x) 76 | model = tf.keras.Model(inputs=inp, outputs=y) 77 | 78 | fast_ism_model = fastISM.FastISM( 79 | model, test_correctness=False) 80 | 81 | self.assertTrue(fast_ism_model.test_correctness()) 82 | 83 | def test_conv_add_three_fc(self): 84 | # ^-C-Crop-| 85 | # inp -> C -> C-> Add -> D -> y 86 | # |_Crop__^ 87 | inp = tf.keras.Input((100, 4)) 88 | x = tf.keras.layers.Conv1D(20, 3)(inp) 89 | x1 = tf.keras.layers.Conv1D(20, 3, padding='valid')(x) 90 | x2 = tf.keras.layers.Conv1D(20, 5, padding='valid')(x) 91 | 92 | x = tf.keras.layers.Cropping1D((1,3))(x) 93 | x1 = tf.keras.layers.Cropping1D((2,0))(x1) 94 | 95 | x = tf.keras.layers.Add()([x, x1, x2]) 96 | x = tf.keras.layers.Flatten()(x) 97 | y = tf.keras.layers.Dense(1)(x) 98 | model = tf.keras.Model(inputs=inp, outputs=y) 99 | 100 | fast_ism_model = fastISM.FastISM( 101 | model, test_correctness=False) 102 | 103 | self.assertTrue(fast_ism_model.test_correctness()) 104 | 105 | def test_skip_crop_then_crop_mxp(self): 106 | # __Crop___ 107 | # ^ | 108 | # inp -> C -> C-> Add -> Crop -> MXP -> [without Flatten!] 
D -> y 109 | # y has output dim [_,5] per example 110 | inp = tf.keras.Input((100, 4)) 111 | x = tf.keras.layers.Conv1D(20, 3)(inp) 112 | x1 = tf.keras.layers.Conv1D(20, 3, padding='valid')(x) 113 | x = tf.keras.layers.Cropping1D((0,2))(x) 114 | x1 = tf.keras.layers.Add()([x, x1]) 115 | x1 = tf.keras.layers.Cropping1D((1,1))(x1) 116 | x2 = tf.keras.layers.MaxPooling1D(3)(x1) 117 | 118 | y = tf.keras.layers.Dense(5)(x2) 119 | model = tf.keras.Model(inputs=inp, outputs=y) 120 | 121 | fast_ism_model = fastISM.FastISM( 122 | model, test_correctness=False) 123 | 124 | self.assertTrue(fast_ism_model.test_correctness()) 125 | 126 | def test_mini_dense_net_1(self): 127 | # __Crop___ ___Crop____ 128 | # ^ | ^ | 129 | # inp -> C -> C-> Add -> C -> Add -> D -> y 130 | # |________Crop________^ 131 | inp = tf.keras.Input((100, 4)) 132 | x = tf.keras.layers.Conv1D(20, 3)(inp) 133 | x1 = tf.keras.layers.Conv1D(20, 3, padding='valid')(x) 134 | x_crop1 = tf.keras.layers.Cropping1D((1,1))(x) 135 | x1 = tf.keras.layers.Add()([x_crop1, x1]) 136 | x2 = tf.keras.layers.Conv1D(20, 5, padding='valid')(x1) 137 | 138 | x1 = tf.keras.layers.Cropping1D((2,2))(x1) 139 | x_crop2 = tf.keras.layers.Cropping1D((2,4))(x) 140 | 141 | x2 = tf.keras.layers.Add()([x_crop2, x1, x2]) 142 | x2 = tf.keras.layers.Flatten()(x2) 143 | y = tf.keras.layers.Dense(1)(x2) 144 | model = tf.keras.Model(inputs=inp, outputs=y) 145 | 146 | fast_ism_model = fastISM.FastISM( 147 | model, test_correctness=False) 148 | 149 | self.assertTrue(fast_ism_model.test_correctness()) 150 | 151 | def test_mini_dense_net_2(self): 152 | # __Crop___ ___Crop____ _________ ___________ 153 | # ^ | ^ | ^ | ^ | 154 | # inp -> C -> C-> Add -> C -> Add -> MXP -> C -> C-> Add -> C -> Add -> D -> y 155 | # |______Crop__________^ |____________________^ 156 | # |_________________________Crop_________________________^ 157 | inp = tf.keras.Input((100, 4)) 158 | x = tf.keras.layers.Conv1D(10, 2)(inp) 159 | x1 = tf.keras.layers.Conv1D(10, 3, 
padding='valid')(x) 160 | x_crop1 = tf.keras.layers.Cropping1D((1,1))(x) 161 | x1 = tf.keras.layers.Add()([x_crop1, x1]) 162 | x2 = tf.keras.layers.Conv1D(10, 5, padding='valid')(x1) 163 | 164 | x1 = tf.keras.layers.Cropping1D((2,2))(x1) 165 | x_crop2 = tf.keras.layers.Cropping1D((2,4))(x) 166 | 167 | x2 = tf.keras.layers.Add()([x_crop2, x1, x2]) 168 | 169 | x2 = tf.keras.layers.MaxPooling1D(3)(x2) 170 | x2 = tf.keras.layers.Conv1D(10, 2)(x2) 171 | 172 | x3 = tf.keras.layers.Conv1D(10, 7, padding='same')(x2) 173 | x3 = tf.keras.layers.Maximum()([x2, x3]) 174 | x4 = tf.keras.layers.Conv1D(10, 4, padding='same')(x3) 175 | 176 | # 99 -> 30 177 | x_crop3 = tf.keras.layers.Cropping1D((20,49))(x) 178 | x4 = tf.keras.layers.Add()([x_crop3, x2, x3, x4]) 179 | 180 | x4 = tf.keras.layers.Flatten()(x4) 181 | y = tf.keras.layers.Dense(1)(x4) 182 | model = tf.keras.Model(inputs=inp, outputs=y) 183 | 184 | fast_ism_model = fastISM.FastISM( 185 | model, test_correctness=False) 186 | 187 | self.assertTrue(fast_ism_model.test_correctness()) 188 | 189 | 190 | if __name__ == '__main__': 191 | unittest.main() 192 | -------------------------------------------------------------------------------- /test/test_custom_stop_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | from fastISM.models.bpnet import bpnet_model 6 | 7 | 8 | class TestCustomStopLayer(unittest.TestCase): 9 | # testing introducing stop layers at intermediate nodes (early_stop_layers) 10 | # that are not necessarily in STOP_LAYERS (i.e. not a dense/flatten etc 11 | # layer) could be conv/add layers as well. 
Would be useful if perturbed 12 | # width increases quickly at spans large fraction of intermediate conv 13 | # layers 14 | 15 | def test_two_conv_addstop_fc(self): 16 | # inp --> C -> Add (stop) -> D -> y 17 | # |-> C ----^ 18 | inp = tf.keras.Input((100, 4)) 19 | x1 = tf.keras.layers.Conv1D(20, 3)(inp) 20 | x2 = tf.keras.layers.Conv1D(20, 3)(inp) 21 | 22 | stop_add = tf.keras.layers.Add() 23 | x = stop_add([x1, x2]) 24 | 25 | x = tf.keras.layers.Flatten()(x) 26 | x = tf.keras.layers.Dense(1)(x) 27 | model = tf.keras.Model(inputs=inp, outputs=x) 28 | 29 | fast_ism_model = fastISM.FastISM( 30 | model, 31 | early_stop_layers=stop_add.name, 32 | test_correctness=False) 33 | 34 | self.assertTrue(fast_ism_model.test_correctness()) 35 | 36 | def test_two_conv_addstop_skip_fc(self): 37 | # inp --> C -> Add (stop) -> D --> Add -> y 38 | # |-> C -> ^ ----------> D -----^ 39 | inp = tf.keras.Input((100, 4)) 40 | x1 = tf.keras.layers.Conv1D(20, 3)(inp) 41 | x2 = tf.keras.layers.Conv1D(20, 3)(inp) 42 | 43 | stop_add = tf.keras.layers.Add() 44 | x = stop_add([x1, x2]) 45 | 46 | x = tf.keras.layers.Flatten()(x) 47 | x = tf.keras.layers.Dense(10)(x) 48 | 49 | x2 = tf.keras.layers.Flatten()(x2) 50 | x2 = tf.keras.layers.Dense(10)(x2) 51 | 52 | x = tf.keras.layers.Add()([x, x2]) 53 | 54 | model = tf.keras.Model(inputs=inp, outputs=x) 55 | 56 | fast_ism_model = fastISM.FastISM( 57 | model, 58 | early_stop_layers=stop_add.name, 59 | test_correctness=False) 60 | 61 | self.assertTrue(fast_ism_model.test_correctness()) 62 | 63 | def test_conv_into_stop_segment(self): 64 | # inp --> C -> C (stop) -> Add -> D --> y 65 | # |--> C -----------^ 66 | inp = tf.keras.Input((100, 4)) 67 | x = tf.keras.layers.Conv1D(20, 3)(inp) 68 | 69 | x1 = tf.keras.layers.Conv1D(20, 3)(x) 70 | 71 | stop_conv = tf.keras.layers.Conv1D(20, 3) 72 | x = stop_conv(x) 73 | 74 | x = tf.keras.layers.Add()([x, x1]) 75 | 76 | x = tf.keras.layers.Flatten()(x) 77 | x = tf.keras.layers.Dense(1)(x) 78 | 79 | model = 
tf.keras.Model(inputs=inp, outputs=x) 80 | 81 | fast_ism_model = fastISM.FastISM( 82 | model, 83 | early_stop_layers=stop_conv.name, 84 | test_correctness=False) 85 | 86 | self.assertTrue(fast_ism_model.test_correctness()) 87 | 88 | def test_two_conv_maxpool_fc(self): 89 | # inp -> C -> MXP -> C -> MXP -> D -> y 90 | inp = tf.keras.Input((100, 4)) 91 | x = tf.keras.layers.Conv1D(10, 7, padding='same')(inp) 92 | x = tf.keras.layers.MaxPooling1D(3)(x) 93 | x = tf.keras.layers.Conv1D(10, 3)(x) 94 | x = tf.keras.layers.MaxPooling1D(2)(x) 95 | x = tf.keras.layers.Flatten()(x) 96 | x = tf.keras.layers.Dense(2)(x) 97 | model = tf.keras.Model(inputs=inp, outputs=x) 98 | 99 | for layer in model.layers[1:5]: 100 | fast_ism_model = fastISM.FastISM( 101 | model, 102 | early_stop_layers=layer.name, 103 | test_correctness=False) 104 | 105 | self.assertTrue(fast_ism_model.test_correctness()) 106 | 107 | def test_mini_dense_net(self): 108 | # early stops added at layers with "x" 109 | # _________ _____________ __________________ _____________ 110 | # ^ | ^ | ^ | ^ | 111 | # inp -> C -> C-> Add1 (x)-> C -> Add2(x) -> MXP (x) -> C1 (x) -> C2 -> Max1 (x) -> C -> Add -> D -> y 112 | # |_______________________^ |___________________________________^ 113 | inp = tf.keras.Input((100, 4)) 114 | x = tf.keras.layers.Conv1D(20, 3)(inp) 115 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 116 | 117 | add1 = tf.keras.layers.Add() 118 | x1 = add1([x, x1]) 119 | 120 | x2 = tf.keras.layers.Conv1D(20, 5, padding='same')(x1) 121 | 122 | add2 = tf.keras.layers.Add() 123 | x2 = add2([x, x1, x2]) 124 | 125 | mxp = tf.keras.layers.MaxPooling1D(3) 126 | x2 = mxp(x2) 127 | 128 | c1 = tf.keras.layers.Conv1D(10, 2) 129 | x2 = c1(x2) 130 | 131 | x3 = tf.keras.layers.Conv1D(10, 7, padding='same')(x2) 132 | 133 | max1 = tf.keras.layers.Maximum() 134 | x3 = max1([x2, x3]) 135 | 136 | x4 = tf.keras.layers.Conv1D(10, 4, padding='same')(x3) 137 | x4 = tf.keras.layers.Add()([x2, x3, x4]) 138 | 139 | x4 = 
tf.keras.layers.Flatten()(x4) 140 | y = tf.keras.layers.Dense(1)(x4) 141 | model = tf.keras.Model(inputs=inp, outputs=y) 142 | 143 | for layer in [add1, add2, mxp, c1, max1]: 144 | fast_ism_model = fastISM.FastISM( 145 | model, 146 | early_stop_layers=layer.name, 147 | test_correctness=False) 148 | 149 | self.assertTrue(fast_ism_model.test_correctness()) 150 | 151 | def test_bpnet_5_dilated_100(self): 152 | model = bpnet_model(seqlen=100, num_dilated_convs=5) 153 | 154 | conv_layers = [x.name for x in model.layers if 'conv1d' in x.name] 155 | 156 | for conv_layer in conv_layers: 157 | # try with an early stop at each of conv layers 158 | fast_ism_model = fastISM.FastISM( 159 | model, 160 | early_stop_layers=conv_layer, 161 | test_correctness=False) 162 | 163 | # seems to need lower numerical to always pass 164 | self.assertTrue(fast_ism_model.test_correctness(atol=1e-5)) 165 | 166 | def test_bpnet_9_dilated_100(self): 167 | model = bpnet_model(seqlen=100, num_dilated_convs=9) 168 | 169 | conv_layers = [x.name for x in model.layers if 'conv1d' in x.name] 170 | 171 | for conv_layer in conv_layers[-4:]: 172 | # try with an early stop at each of the last 4 conv layers 173 | fast_ism_model = fastISM.FastISM( 174 | model, 175 | early_stop_layers=conv_layer, 176 | test_correctness=False) 177 | 178 | # seems to need lower numerical to always pass 179 | self.assertTrue(fast_ism_model.test_correctness(atol=1e-5)) 180 | 181 | 182 | if __name__ == '__main__': 183 | unittest.main() 184 | -------------------------------------------------------------------------------- /test/test_example_architectures.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | 4 | from context import fastISM 5 | from fastISM.models.basset import basset_model 6 | from fastISM.models.factorized_basset import factorized_basset_model 7 | from fastISM.models.bpnet import bpnet_model 8 | from fastISM.models.bpnet_dense import 
bpnet_dense_model 9 | from random import sample 10 | 11 | # Takes a few mins! 12 | 13 | 14 | class TestExampleArchitectures(unittest.TestCase): 15 | def test_basset_200(self): 16 | model = basset_model(seqlen=200) 17 | 18 | fast_ism_model = fastISM.FastISM( 19 | model, test_correctness=False) 20 | 21 | self.assertTrue(fast_ism_model.test_correctness()) 22 | 23 | def test_basset_500(self): 24 | model = basset_model(seqlen=500) 25 | 26 | fast_ism_model = fastISM.FastISM( 27 | model, test_correctness=False) 28 | 29 | self.assertTrue(fast_ism_model.test_correctness()) 30 | 31 | def test_factorized_basset_200(self): 32 | model = factorized_basset_model(seqlen=200) 33 | 34 | fast_ism_model = fastISM.FastISM( 35 | model, test_correctness=False) 36 | 37 | self.assertTrue(fast_ism_model.test_correctness()) 38 | 39 | def test_factorized_basset_500(self): 40 | model = factorized_basset_model(seqlen=500) 41 | 42 | fast_ism_model = fastISM.FastISM( 43 | model, test_correctness=False) 44 | 45 | self.assertTrue(fast_ism_model.test_correctness()) 46 | 47 | def test_bpnet_5_dilated_500(self): 48 | model = bpnet_model(seqlen=500, num_dilated_convs=5) 49 | 50 | fast_ism_model = fastISM.FastISM( 51 | model, test_correctness=False) 52 | 53 | self.assertTrue(fast_ism_model.test_correctness()) 54 | 55 | def test_bpnet_9_dilated_100(self): 56 | model = bpnet_model(seqlen=100, num_dilated_convs=9) 57 | 58 | fast_ism_model = fastISM.FastISM( 59 | model, test_correctness=False) 60 | 61 | # seems to need lower numerical to always pass 62 | self.assertTrue(fast_ism_model.test_correctness(atol=1e-5)) 63 | 64 | def test_bpnet_9_dilated_500(self): 65 | model = bpnet_model(seqlen=500, num_dilated_convs=9) 66 | 67 | fast_ism_model = fastISM.FastISM( 68 | model, test_correctness=False) 69 | 70 | # seems to need lower numerical to always pass 71 | self.assertTrue(fast_ism_model.test_correctness(atol=1e-5)) 72 | 73 | def test_bpnet_dense_2_672(self): 74 | model = bpnet_dense_model(inlen=672, 
outlen=500, ndl=2) 75 | 76 | # too slow for all, so select randomly (preference for edges) 77 | change_ranges = [(x, x+1) for x in range(0, 10)] + \ 78 | [(x, x+1) for x in sorted(sample(range(10, 662), 50))] + \ 79 | [(x, x+1) for x in range(662, 672)] 80 | 81 | fast_ism_model = fastISM.FastISM( 82 | model, 83 | change_ranges=change_ranges, 84 | early_stop_layers='profile_out_prebias', 85 | test_correctness=False) 86 | 87 | # seems to need lower numerical to always pass 88 | self.assertTrue(fast_ism_model.test_correctness(atol=1e-5)) 89 | 90 | def test_bpnet_dense_6_1346(self): 91 | model = bpnet_dense_model(inlen=1346, outlen=1000, filters=8, ndl=6) 92 | 93 | # too slow for all, so select randomly (preference for edges) 94 | change_ranges = [(x, x+1) for x in range(0, 10)] + \ 95 | [(x, x+1) for x in sorted(sample(range(10, 1336), 50))] + \ 96 | [(x, x+1) for x in range(1336, 1346)] 97 | 98 | fast_ism_model = fastISM.FastISM( 99 | model, 100 | change_ranges=change_ranges, 101 | early_stop_layers='profile_out_prebias', 102 | test_correctness=False) 103 | 104 | # seems to need lower numerical to always pass 105 | self.assertTrue(fast_ism_model.test_correctness(atol=1e-5)) 106 | 107 | 108 | if __name__ == '__main__': 109 | unittest.main() 110 | -------------------------------------------------------------------------------- /test/test_simple_multi_in_architectures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | 7 | class TestSimpleMultiInArchitectures(unittest.TestCase): 8 | def test_one_alt_inp_conv_add_error(self): 9 | # inp_seq -> C -> Add -> D -> y 10 | # inp_alt -> C -----^ 11 | # 12 | # Currently not supported to mix alternate with seq 13 | # before a STOP_LAYER. 
Should raise NotImplementedError 14 | inp_seq = tf.keras.Input((100, 4)) 15 | inp_alt = tf.keras.Input((100, 4)) 16 | x1 = tf.keras.layers.Conv1D(20, 3)(inp_seq) 17 | x2 = tf.keras.layers.Conv1D(20, 3)(inp_alt) 18 | x = tf.keras.layers.Add()([x1, x2]) 19 | x = tf.keras.layers.Dense(1)(x) 20 | 21 | # both order of inputs 22 | model1 = tf.keras.Model(inputs=[inp_seq, inp_alt], outputs=x) 23 | model2 = tf.keras.Model(inputs=[inp_alt, inp_seq], outputs=x) 24 | 25 | with self.assertRaises(NotImplementedError): 26 | fastISM.FastISM(model1, seq_input_idx=0, test_correctness=False) 27 | 28 | with self.assertRaises(NotImplementedError): 29 | fastISM.FastISM(model2, seq_input_idx=0, test_correctness=False) 30 | 31 | def test_one_alt_inp_conv_cat_fc(self): 32 | # inp_seq -> C -> Concat -> D -> y 33 | # inp_alt --^ 34 | inp_seq = tf.keras.Input((100, 4)) 35 | inp_alt = tf.keras.Input((10,)) 36 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 37 | x = tf.keras.layers.Flatten()(x) 38 | x = tf.keras.layers.Concatenate()([x, inp_alt]) 39 | x = tf.keras.layers.Dense(1)(x) 40 | 41 | # both order of inputs 42 | model1 = tf.keras.Model(inputs=[inp_seq, inp_alt], outputs=x) 43 | model2 = tf.keras.Model(inputs=[inp_alt, inp_seq], outputs=x) 44 | 45 | fast_ism_model1 = fastISM.FastISM( 46 | model1, seq_input_idx=0, test_correctness=False) 47 | fast_ism_model2 = fastISM.FastISM( 48 | model2, seq_input_idx=1, test_correctness=False) 49 | 50 | self.assertTrue(fast_ism_model1.test_correctness()) 51 | self.assertTrue(fast_ism_model2.test_correctness()) 52 | 53 | def test_one_alt_inp_process_conv_cat_fc(self): 54 | # inp_seq -> C -> Concat -> D -> y 55 | # inp_alt -> D -^ 56 | inp_seq = tf.keras.Input((100, 4)) 57 | inp_alt = tf.keras.Input((10,)) 58 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 59 | x = tf.keras.layers.Flatten()(x) 60 | x_alt = tf.keras.layers.Dense(10)(inp_alt) 61 | x = tf.keras.layers.Concatenate()([x, x_alt]) 62 | x = tf.keras.layers.Dense(1)(x) 63 | 64 | # both order of inputs 
65 | model1 = tf.keras.Model(inputs=[inp_seq, inp_alt], outputs=x) 66 | model2 = tf.keras.Model(inputs=[inp_alt, inp_seq], outputs=x) 67 | 68 | fast_ism_model1 = fastISM.FastISM( 69 | model1, seq_input_idx=0, test_correctness=False) 70 | fast_ism_model2 = fastISM.FastISM( 71 | model2, seq_input_idx=1, test_correctness=False) 72 | 73 | self.assertTrue(fast_ism_model1.test_correctness()) 74 | self.assertTrue(fast_ism_model2.test_correctness()) 75 | 76 | def test_two_alt_inp_conv_cat_fc(self): 77 | # inp_alt1 -| 78 | # inp_seq -> C -> Concat -> D -> y 79 | # inp_alt2 --^ 80 | inp_seq = tf.keras.Input((100, 4)) 81 | inp_alt1 = tf.keras.Input((10,)) 82 | inp_alt2 = tf.keras.Input((10,)) 83 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 84 | x = tf.keras.layers.Flatten()(x) 85 | x = tf.keras.layers.Concatenate()([x, inp_alt1, inp_alt2]) 86 | x = tf.keras.layers.Dense(1)(x) 87 | 88 | # different order of inputs 89 | model1 = tf.keras.Model( 90 | inputs=[inp_seq, inp_alt1, inp_alt2], outputs=x) 91 | model2 = tf.keras.Model( 92 | inputs=[inp_alt2, inp_seq, inp_alt1], outputs=x) 93 | model3 = tf.keras.Model( 94 | inputs=[inp_alt2, inp_alt1, inp_seq], outputs=x) 95 | 96 | fast_ism_model1 = fastISM.FastISM( 97 | model1, seq_input_idx=0, test_correctness=False) 98 | fast_ism_model2 = fastISM.FastISM( 99 | model2, seq_input_idx=1, test_correctness=False) 100 | fast_ism_model3 = fastISM.FastISM( 101 | model3, seq_input_idx=2, test_correctness=False) 102 | 103 | self.assertTrue(fast_ism_model1.test_correctness()) 104 | self.assertTrue(fast_ism_model2.test_correctness()) 105 | self.assertTrue(fast_ism_model3.test_correctness()) 106 | 107 | def test_two_alt_inp_conv_stagger(self): 108 | # inp_alt1 -| 109 | # inp_seq -> C -> Concat -> D -> Concat -> D -> y 110 | # inp_alt2 --^ 111 | inp_seq = tf.keras.Input((100, 4)) 112 | inp_alt1 = tf.keras.Input((10,)) 113 | inp_alt2 = tf.keras.Input((10,)) 114 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 115 | x = tf.keras.layers.Flatten()(x) 116 | 
x = tf.keras.layers.Concatenate()([x, inp_alt1]) 117 | x = tf.keras.layers.Dense(10)(x) 118 | x = tf.keras.layers.Concatenate()([x, inp_alt2]) 119 | x = tf.keras.layers.Dense(1)(x) 120 | 121 | # different order of inputs 122 | model1 = tf.keras.Model( 123 | inputs=[inp_seq, inp_alt1, inp_alt2], outputs=x) 124 | model2 = tf.keras.Model( 125 | inputs=[inp_alt2, inp_seq, inp_alt1], outputs=x) 126 | model3 = tf.keras.Model( 127 | inputs=[inp_alt2, inp_alt1, inp_seq], outputs=x) 128 | 129 | fast_ism_model1 = fastISM.FastISM( 130 | model1, seq_input_idx=0, test_correctness=False) 131 | fast_ism_model2 = fastISM.FastISM( 132 | model2, seq_input_idx=1, test_correctness=False) 133 | fast_ism_model3 = fastISM.FastISM( 134 | model3, seq_input_idx=2, test_correctness=False) 135 | 136 | self.assertTrue(fast_ism_model1.test_correctness()) 137 | self.assertTrue(fast_ism_model2.test_correctness()) 138 | self.assertTrue(fast_ism_model3.test_correctness()) 139 | 140 | def test_two_alt_interact(self): 141 | # inp_seq -> C -> Concat -> D -> y 142 | # inp_alt1 -> ADD ---^ 143 | # inp_alt2 ----^ 144 | inp_seq = tf.keras.Input((100, 4)) 145 | inp_alt1 = tf.keras.Input((10,)) 146 | inp_alt2 = tf.keras.Input((10,)) 147 | inp_sum = tf.keras.layers.Add()([inp_alt1, inp_alt2]) 148 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 149 | x = tf.keras.layers.Flatten()(x) 150 | x = tf.keras.layers.Concatenate()([x, inp_sum]) 151 | x = tf.keras.layers.Dense(1)(x) 152 | 153 | # different order of inputs 154 | model1 = tf.keras.Model( 155 | inputs=[inp_seq, inp_alt1, inp_alt2], outputs=x) 156 | model2 = tf.keras.Model( 157 | inputs=[inp_alt2, inp_seq, inp_alt1], outputs=x) 158 | model3 = tf.keras.Model( 159 | inputs=[inp_alt2, inp_alt1, inp_seq], outputs=x) 160 | 161 | fast_ism_model1 = fastISM.FastISM( 162 | model1, seq_input_idx=0, test_correctness=False) 163 | fast_ism_model2 = fastISM.FastISM( 164 | model2, seq_input_idx=1, test_correctness=False) 165 | fast_ism_model3 = fastISM.FastISM( 166 | 
model3, seq_input_idx=2, test_correctness=False) 167 | 168 | self.assertTrue(fast_ism_model1.test_correctness()) 169 | self.assertTrue(fast_ism_model2.test_correctness()) 170 | self.assertTrue(fast_ism_model3.test_correctness()) 171 | 172 | def test_one_alt_conv_cat_twice_fc(self): 173 | # inp_seq -> C -> Concat -> D -> Concat -> D -> y 174 | # inp_alt --^-----------^ 175 | inp_seq = tf.keras.Input((100, 4)) 176 | inp_alt = tf.keras.Input((10,)) 177 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 178 | x = tf.keras.layers.Flatten()(x) 179 | x = tf.keras.layers.Concatenate()([x, inp_alt]) 180 | x = tf.keras.layers.Dense(10)(x) 181 | x = tf.keras.layers.Concatenate()([inp_alt, x]) 182 | x = tf.keras.layers.Dense(1)(x) 183 | 184 | # both order of inputs 185 | model1 = tf.keras.Model(inputs=[inp_seq, inp_alt], outputs=x) 186 | model2 = tf.keras.Model(inputs=[inp_alt, inp_seq], outputs=x) 187 | 188 | fast_ism_model1 = fastISM.FastISM( 189 | model1, seq_input_idx=0, test_correctness=False) 190 | fast_ism_model2 = fastISM.FastISM( 191 | model2, seq_input_idx=1, test_correctness=False) 192 | 193 | self.assertTrue(fast_ism_model1.test_correctness()) 194 | self.assertTrue(fast_ism_model2.test_correctness()) 195 | 196 | def test_two_alt_interact_complex(self): 197 | # inp_seq -> C -> Concat -> D ---> Concat -> D -> y 198 | # inp_alt1 -> ADD ---^--| ^ 199 | # inp_alt2 ----^----> Concat --> D --> 200 | inp_seq = tf.keras.Input((100, 4)) 201 | inp_alt1 = tf.keras.Input((10,)) 202 | inp_alt2 = tf.keras.Input((10,)) 203 | inp_sum = tf.keras.layers.Add()([inp_alt1, inp_alt2]) 204 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 205 | x = tf.keras.layers.Flatten()(x) 206 | x = tf.keras.layers.Concatenate()([x, inp_sum]) 207 | x = tf.keras.layers.Dense(20)(x) 208 | inp_alt2_sum = tf.keras.layers.Concatenate()([inp_sum, inp_alt2]) 209 | inp_alt2_sum = tf.keras.layers.Dense(10)(inp_alt2_sum) 210 | x = tf.keras.layers.Concatenate()([inp_alt2_sum, x]) 211 | x = tf.keras.layers.Dense(1)(x) 212 | 
213 | # different order of inputs 214 | model1 = tf.keras.Model( 215 | inputs=[inp_seq, inp_alt1, inp_alt2], outputs=x) 216 | model2 = tf.keras.Model( 217 | inputs=[inp_alt2, inp_seq, inp_alt1], outputs=x) 218 | model3 = tf.keras.Model( 219 | inputs=[inp_alt2, inp_alt1, inp_seq], outputs=x) 220 | 221 | fast_ism_model1 = fastISM.FastISM( 222 | model1, seq_input_idx=0, test_correctness=False) 223 | fast_ism_model2 = fastISM.FastISM( 224 | model2, seq_input_idx=1, test_correctness=False) 225 | fast_ism_model3 = fastISM.FastISM( 226 | model3, seq_input_idx=2, test_correctness=False) 227 | 228 | self.assertTrue(fast_ism_model1.test_correctness()) 229 | self.assertTrue(fast_ism_model2.test_correctness()) 230 | self.assertTrue(fast_ism_model3.test_correctness()) 231 | 232 | def test_one_alt_double_cat_three_out(self): 233 | # test multiple outputs 234 | # |----> D -> y1 235 | # inp_seq -> C -> Concat ----> D -> D -> y2 236 | # ^ | 237 | # inp_alt -> D -|-> D ->Concat -> D -> y3 238 | inp_seq = tf.keras.Input((100, 4)) 239 | inp_alt = tf.keras.Input((10,)) 240 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 241 | x = tf.keras.layers.Flatten()(x) 242 | y1 = tf.keras.layers.Dense(1)(x) 243 | x_alt = tf.keras.layers.Dense(10)(inp_alt) 244 | x = tf.keras.layers.Concatenate()([x, x_alt]) 245 | x_alt = tf.keras.layers.Dense(10)(x_alt) 246 | x = tf.keras.layers.Dense(10)(x) 247 | x_alt = tf.keras.layers.Concatenate()([x, x_alt]) 248 | y2 = tf.keras.layers.Dense(1)(x) 249 | y3 = tf.keras.layers.Dense(1)(x_alt) 250 | 251 | # both order of inputs 252 | model1 = tf.keras.Model( 253 | inputs=[inp_seq, inp_alt], outputs=[y1, y2, y3]) 254 | model2 = tf.keras.Model( 255 | inputs=[inp_alt, inp_seq], outputs=[y1, y2, y3]) 256 | 257 | fast_ism_model1 = fastISM.FastISM( 258 | model1, seq_input_idx=0, test_correctness=False) 259 | fast_ism_model2 = fastISM.FastISM( 260 | model2, seq_input_idx=1, test_correctness=False) 261 | 262 | self.assertTrue(fast_ism_model1.test_correctness()) 263 | 
self.assertTrue(fast_ism_model2.test_correctness()) 264 | 265 | def test_one_alt_double_cat_three_out_10bp_change_range(self): 266 | # test multiple outputs 267 | # |----> D -> y1 268 | # inp_seq -> C -> Concat ----> D -> D -> y2 269 | # ^ | 270 | # inp_alt -> D -|-> D ->Concat -> D -> y3 271 | inp_seq = tf.keras.Input((100, 4)) 272 | inp_alt = tf.keras.Input((10,)) 273 | x = tf.keras.layers.Conv1D(20, 3)(inp_seq) 274 | x = tf.keras.layers.Flatten()(x) 275 | y1 = tf.keras.layers.Dense(1)(x) 276 | x_alt = tf.keras.layers.Dense(10)(inp_alt) 277 | x = tf.keras.layers.Concatenate()([x, x_alt]) 278 | x_alt = tf.keras.layers.Dense(10)(x_alt) 279 | x = tf.keras.layers.Dense(10)(x) 280 | x_alt = tf.keras.layers.Concatenate()([x, x_alt]) 281 | y2 = tf.keras.layers.Dense(1)(x) 282 | y3 = tf.keras.layers.Dense(1)(x_alt) 283 | 284 | # both order of inputs 285 | model1 = tf.keras.Model( 286 | inputs=[inp_seq, inp_alt], outputs=[y1, y2, y3]) 287 | model2 = tf.keras.Model( 288 | inputs=[inp_alt, inp_seq], outputs=[y1, y2, y3]) 289 | 290 | fast_ism_model1 = fastISM.FastISM( 291 | model1, seq_input_idx=0, 292 | change_ranges=[(i, i+10) for i in range(0, 100, 10)], 293 | test_correctness=False) 294 | fast_ism_model2 = fastISM.FastISM( 295 | model2, seq_input_idx=1, 296 | change_ranges=[(i, i+10) for i in range(0, 100, 10)], 297 | test_correctness=False) 298 | 299 | self.assertTrue(fast_ism_model1.test_correctness()) 300 | self.assertTrue(fast_ism_model2.test_correctness()) 301 | 302 | 303 | if __name__ == '__main__': 304 | unittest.main() 305 | -------------------------------------------------------------------------------- /test/test_simple_nested_architectures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | def conv_block(input_shape=(108,4)): 7 | inp = tf.keras.Input(shape=input_shape) 8 | x = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 9 | x = 
tf.keras.layers.MaxPooling1D(3)(x) 10 | x = tf.keras.layers.Conv1D(20, 5, padding='same')(x) 11 | x = tf.keras.layers.MaxPooling1D(3)(x) 12 | x = tf.keras.layers.Conv1D(20, 9, padding='same')(x) 13 | x = tf.keras.layers.MaxPooling1D(3)(x) 14 | model = tf.keras.Model(inputs=inp, outputs=x) 15 | return model 16 | 17 | def res_block(input_shape=(108,20)): 18 | inp = tf.keras.Input(shape=input_shape) 19 | x = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 20 | x = tf.keras.layers.Add()([inp, x]) 21 | model = tf.keras.Model(inputs=inp, outputs=x) 22 | return model 23 | 24 | def doub_res_block(input_shape=(108,20)): 25 | inp = tf.keras.Input(shape=input_shape) 26 | x = res_block()(inp) 27 | x = res_block()(x) 28 | model = tf.keras.Model(inputs=inp, outputs=x) 29 | return model 30 | 31 | def fc_block(input_shape=(80,)): 32 | inp = tf.keras.Input(shape=input_shape) 33 | x = tf.keras.layers.Dense(10)(inp) 34 | x = tf.keras.layers.Dense(1)(x) 35 | 36 | model = tf.keras.Model(inputs=inp, outputs=x) 37 | return model 38 | 39 | def my_add_block(input_shape=(108,20)): 40 | x1 = tf.keras.Input(shape=input_shape) 41 | x2 = tf.keras.Input(shape=input_shape) 42 | y = tf.keras.layers.Add()([x1,x2]) 43 | 44 | model = tf.keras.Model(inputs=[x1,x2], outputs=y) 45 | return model 46 | 47 | def my_add_max_block(input_shape=(108,20)): 48 | x1 = tf.keras.Input(shape=input_shape) 49 | x2 = tf.keras.Input(shape=input_shape) 50 | y1 = tf.keras.layers.Add()([x1,x2]) 51 | y2 = tf.keras.layers.Maximum()([x1,x2]) 52 | 53 | model = tf.keras.Model(inputs=[x1,x2], outputs=[y1, y2]) 54 | return model 55 | 56 | 57 | def my_sub_block(input_shape=(108,20)): 58 | x1 = tf.keras.Input(shape=input_shape) 59 | x2 = tf.keras.Input(shape=input_shape) 60 | y = tf.keras.layers.Subtract()([x2,x1]) 61 | 62 | model = tf.keras.Model(inputs=[x1,x2], outputs=y) 63 | return model 64 | 65 | 66 | class TestSimpleSingleNestedArchitectures(unittest.TestCase): 67 | def test_three_conv_two_fc(self): 68 | # inp -> [ C -> M 
-> C -> M -> C -> M ] -> [ D -> D -> y ] 69 | convs = conv_block() 70 | fcs = fc_block() 71 | 72 | inp = tf.keras.Input((108, 4)) 73 | x = convs(inp) 74 | x = tf.keras.layers.Flatten()(x) 75 | x = fcs(x) 76 | 77 | model = tf.keras.Model(inputs=inp, outputs=x) 78 | 79 | fast_ism_model = fastISM.FastISM( 80 | model, test_correctness=False) 81 | 82 | self.assertTrue(fast_ism_model.test_correctness()) 83 | 84 | def test_conv_res_mxp_two_fc(self): 85 | # _________ 86 | # ^ | 87 | # inp -> C [ -> C -> Add ] -> M -> [ D -> D -> y ] 88 | res = res_block() 89 | fcs = fc_block(input_shape=(36*20,)) 90 | 91 | inp = tf.keras.Input((108, 4)) 92 | x = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 93 | x = res(x) 94 | x = tf.keras.layers.MaxPooling1D(3)(x) 95 | x = tf.keras.layers.Flatten()(x) 96 | x = fcs(x) 97 | 98 | model = tf.keras.Model(inputs=inp, outputs=x) 99 | 100 | fast_ism_model = fastISM.FastISM( 101 | model, test_correctness=False) 102 | 103 | self.assertTrue(fast_ism_model.test_correctness()) 104 | 105 | def test_conv_my_add_mxp_two_fc(self): 106 | # _________ 107 | # ^ | 108 | # inp -> C -> C ->[ Add ] -> M -> [ D -> D -> y ] 109 | # testing a nested block that takes in multiple inputs 110 | my_add = my_add_block() 111 | fcs = fc_block(input_shape=(36*20,)) 112 | 113 | inp = tf.keras.Input((108, 4)) 114 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 115 | x2 = tf.keras.layers.Conv1D(20, 3, padding='same')(x1) 116 | y = my_add([x1,x2]) 117 | y = tf.keras.layers.MaxPooling1D(3)(y) 118 | y = tf.keras.layers.Flatten()(y) 119 | y = fcs(y) 120 | 121 | model = tf.keras.Model(inputs=inp, outputs=y) 122 | 123 | fast_ism_model = fastISM.FastISM( 124 | model, test_correctness=False) 125 | 126 | self.assertTrue(fast_ism_model.test_correctness()) 127 | 128 | def test_conv_my_add_max_mxp_two_fc(self): 129 | # _________ __________ 130 | # ^ | ^ | 131 | # inp -> C -> C ->[ Add/Max ] -> Add -> M -> [ D -> D -> y ] 132 | # testing a nested block that takes in 
multiple inputs 133 | # and returns multiple outputs 134 | my_add_max = my_add_max_block() 135 | fcs = fc_block(input_shape=(36*20,)) 136 | 137 | inp = tf.keras.Input((108, 4)) 138 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 139 | x2 = tf.keras.layers.Conv1D(20, 3, padding='same')(x1) 140 | y1, y2 = my_add_max([x1,x2]) 141 | y = tf.keras.layers.Add()([y1,y2]) 142 | y = tf.keras.layers.MaxPooling1D(3)(y) 143 | y = tf.keras.layers.Flatten()(y) 144 | y = fcs(y) 145 | 146 | model = tf.keras.Model(inputs=inp, outputs=y) 147 | 148 | fast_ism_model = fastISM.FastISM( 149 | model, test_correctness=False) 150 | 151 | self.assertTrue(fast_ism_model.test_correctness()) 152 | 153 | def test_conv_my_sub_mxp_two_fc(self): 154 | # TODO: fails as of now since inbound_edges does not contain 155 | # the correct node order 156 | # _________ 157 | # ^ | 158 | # inp -> C -> C ->[ Sub ] -> M -> [ D -> D -> y ] 159 | # testing a nested block that takes in multiple inputs 160 | my_sub = my_sub_block() 161 | fcs = fc_block(input_shape=(36*20,)) 162 | 163 | inp = tf.keras.Input((108, 4)) 164 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 165 | x2 = tf.keras.layers.Conv1D(20, 3, padding='same')(x1) 166 | y = my_sub([x1,x2]) 167 | y = tf.keras.layers.MaxPooling1D(3)(y) 168 | y = tf.keras.layers.Flatten()(y) 169 | y = fcs(y) 170 | 171 | model = tf.keras.Model(inputs=inp, outputs=y) 172 | 173 | fast_ism_model = fastISM.FastISM( 174 | model, test_correctness=False) 175 | 176 | self.assertTrue(fast_ism_model.test_correctness()) 177 | 178 | def test_conv_doub_res_mxp_two_fc(self): 179 | # _________ _________ 180 | # ^ | ^ | 181 | # inp -> C [ [ -> C -> Add ] -> [ -> C -> Add ] ] -> M -> [ D -> D -> y ] 182 | # doub_res_block contains 2 res_blocks within it -> double nesting 183 | doub_res = doub_res_block() 184 | fcs = fc_block(input_shape=(36*20,)) 185 | 186 | inp = tf.keras.Input((108, 4)) 187 | x = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 188 | x = doub_res(x) 
189 | x = tf.keras.layers.MaxPooling1D(3)(x) 190 | x = tf.keras.layers.Flatten()(x) 191 | x = fcs(x) 192 | 193 | model = tf.keras.Model(inputs=inp, outputs=x) 194 | 195 | fast_ism_model = fastISM.FastISM( 196 | model, test_correctness=False) 197 | 198 | self.assertTrue(fast_ism_model.test_correctness()) 199 | 200 | if __name__ == '__main__': 201 | unittest.main() 202 | -------------------------------------------------------------------------------- /test/test_simple_single_in_multi_out_architectures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | 7 | class TestSimpleSingleInMultiOutArchitectures(unittest.TestCase): 8 | def test_conv_two_fc(self): 9 | # /- D -> y1 10 | # inp -> C 11 | # \_ D -> y2 12 | inp = tf.keras.Input((100, 4)) 13 | x = tf.keras.layers.Conv1D(20, 3)(inp) 14 | x = tf.keras.layers.Flatten()(x) 15 | y1 = tf.keras.layers.Dense(1)(x) 16 | y2 = tf.keras.layers.Dense(1)(x) 17 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2]) 18 | 19 | fast_ism_model = fastISM.FastISM( 20 | model, test_correctness=False) 21 | 22 | self.assertTrue(fast_ism_model.test_correctness()) 23 | 24 | def test_conv_three_fc(self): 25 | # /- D -> y1 26 | # inp -> C - D -> y2 27 | # \_ D -> y3 28 | inp = tf.keras.Input((100, 4)) 29 | x = tf.keras.layers.Conv1D(20, 3)(inp) 30 | x = tf.keras.layers.Flatten()(x) 31 | y1 = tf.keras.layers.Dense(1)(x) 32 | y2 = tf.keras.layers.Dense(1)(x) 33 | y3 = tf.keras.layers.Dense(1)(x) 34 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2, y3]) 35 | 36 | fast_ism_model = fastISM.FastISM( 37 | model, test_correctness=False) 38 | 39 | self.assertTrue(fast_ism_model.test_correctness()) 40 | 41 | def test_conv_fc_two_head(self): 42 | # inp -> C -> D -> D -> y1 43 | # \_ D -> y2 44 | inp = tf.keras.Input((100, 4)) 45 | x = tf.keras.layers.Conv1D(20, 3)(inp) 46 | x = tf.keras.layers.Flatten()(x) 47 | x = 
tf.keras.layers.Dense(10)(x) 48 | y1 = tf.keras.layers.Dense(1)(x) 49 | y2 = tf.keras.layers.Dense(1)(x) 50 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2]) 51 | 52 | fast_ism_model = fastISM.FastISM( 53 | model, test_correctness=False) 54 | 55 | self.assertTrue(fast_ism_model.test_correctness()) 56 | 57 | def test_two_conv_fc_per_conv(self): 58 | # /- D -> y1 59 | # inp -> C 60 | # \_ C -> D -> y2 61 | inp = tf.keras.Input((100, 4)) 62 | x1 = tf.keras.layers.Conv1D(20, 3)(inp) 63 | x2 = tf.keras.layers.Conv1D(20, 3)(x1) 64 | x1f = tf.keras.layers.Flatten()(x1) 65 | x2f = tf.keras.layers.Flatten()(x2) 66 | y1 = tf.keras.layers.Dense(1)(x1f) 67 | y2 = tf.keras.layers.Dense(1)(x2f) 68 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2]) 69 | 70 | fast_ism_model = fastISM.FastISM( 71 | model, test_correctness=False) 72 | 73 | self.assertTrue(fast_ism_model.test_correctness()) 74 | 75 | def test_three_conv_maxpool_fc_per_conv(self): 76 | # /- D -> y1 77 | # inp -> C -> MX -> C -> MX -> C -> MX -> D -> y2 78 | # \_ C -> D -> y3 79 | inp = tf.keras.Input((100, 4)) 80 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 81 | x1 = tf.keras.layers.MaxPool1D(2)(x1) 82 | x2 = tf.keras.layers.Conv1D(10, 4, padding='same')(x1) 83 | x2 = tf.keras.layers.MaxPool1D(2)(x2) 84 | x3 = tf.keras.layers.Conv1D(10, 3)(x2) 85 | x3 = tf.keras.layers.MaxPool1D(3)(x3) 86 | x1f = tf.keras.layers.Flatten()(x1) 87 | x2f = tf.keras.layers.Flatten()(x2) 88 | x3f = tf.keras.layers.Flatten()(x3) 89 | y1 = tf.keras.layers.Dense(1)(x1f) 90 | y2 = tf.keras.layers.Dense(1)(x2f) 91 | y3 = tf.keras.layers.Dense(1)(x3f) 92 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2, y3]) 93 | 94 | fast_ism_model = fastISM.FastISM( 95 | model, test_correctness=False) 96 | 97 | self.assertTrue(fast_ism_model.test_correctness()) 98 | 99 | def test_input_split_conv_fc(self): 100 | # /- C -> D -> y1 101 | # inp 102 | # \_ C -> D -> y2 103 | inp = tf.keras.Input((100, 4)) 104 | x1 = 
tf.keras.layers.Conv1D(20, 3)(inp) 105 | x2 = tf.keras.layers.Conv1D(10, 4)(inp) 106 | x1f = tf.keras.layers.Flatten()(x1) 107 | x2f = tf.keras.layers.Flatten()(x2) 108 | y1 = tf.keras.layers.Dense(1)(x1f) 109 | y2 = tf.keras.layers.Dense(1)(x2f) 110 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2]) 111 | 112 | fast_ism_model = fastISM.FastISM( 113 | model, test_correctness=False) 114 | 115 | self.assertTrue(fast_ism_model.test_correctness()) 116 | 117 | def test_input_split_complex(self): 118 | # /- C -> MXP -> C -> MXP -> D -> y1 119 | # inp \_ C -> MXP -> D -> D -> y2 120 | # \_ C -> MXP -> D -> y3 121 | inp = tf.keras.Input((100, 4)) 122 | 123 | # first row 124 | x1 = tf.keras.layers.Conv1D(20, 3, dilation_rate=2)(inp) 125 | x1 = tf.keras.layers.MaxPooling1D(2)(x1) 126 | x11 = tf.keras.layers.Conv1D(20, 3, dilation_rate=3)(x1) 127 | x11 = tf.keras.layers.MaxPooling1D(2)(x11) 128 | x11f = tf.keras.layers.Flatten()(x11) 129 | y1 = tf.keras.layers.Dense(5)(x11f) 130 | 131 | # second row 132 | x12 = tf.keras.layers.Conv1D( 133 | 15, 2, padding='same', activation='relu')(x1) 134 | x12 = tf.keras.layers.MaxPooling1D(2)(x12) 135 | x12f = tf.keras.layers.Flatten()(x12) 136 | y2 = tf.keras.layers.Dense(5)(x12f) 137 | y2 = tf.keras.layers.Dense(2, activation='tanh')(y2) 138 | 139 | # third row 140 | x2 = tf.keras.layers.Conv1D(10, 4, padding='same')(inp) 141 | x2 = tf.keras.layers.MaxPool1D(3)(x2) 142 | x2f = tf.keras.layers.Flatten()(x2) 143 | y3 = tf.keras.layers.Dense(1)(x2f) 144 | 145 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2, y3]) 146 | 147 | fast_ism_model = fastISM.FastISM( 148 | model, test_correctness=False) 149 | 150 | self.assertTrue(fast_ism_model.test_correctness()) 151 | 152 | def test_input_split_complex_10bp_change_range(self): 153 | # /- C -> MXP -> C -> MXP -> D -> y1 154 | # inp \_ C -> MXP -> D -> D -> y2 155 | # \_ C -> MXP -> D -> y3 156 | inp = tf.keras.Input((100, 4)) 157 | 158 | # first row 159 | x1 = tf.keras.layers.Conv1D(20, 3, 
dilation_rate=2)(inp) 160 | x1 = tf.keras.layers.MaxPooling1D(2)(x1) 161 | x11 = tf.keras.layers.Conv1D(20, 3, dilation_rate=3)(x1) 162 | x11 = tf.keras.layers.MaxPooling1D(2)(x11) 163 | x11f = tf.keras.layers.Flatten()(x11) 164 | y1 = tf.keras.layers.Dense(5)(x11f) 165 | 166 | # second row 167 | x12 = tf.keras.layers.Conv1D( 168 | 15, 2, padding='same', activation='relu')(x1) 169 | x12 = tf.keras.layers.MaxPooling1D(2)(x12) 170 | x12f = tf.keras.layers.Flatten()(x12) 171 | y2 = tf.keras.layers.Dense(5)(x12f) 172 | y2 = tf.keras.layers.Dense(2, activation='tanh')(y2) 173 | 174 | # third row 175 | x2 = tf.keras.layers.Conv1D(10, 4, padding='same')(inp) 176 | x2 = tf.keras.layers.MaxPool1D(3)(x2) 177 | x2f = tf.keras.layers.Flatten()(x2) 178 | y3 = tf.keras.layers.Dense(1)(x2f) 179 | 180 | model = tf.keras.Model(inputs=inp, outputs=[y1, y2, y3]) 181 | 182 | fast_ism_model = fastISM.FastISM( 183 | model, change_ranges=[(i, i+10) for i in range(0, 100, 10)], 184 | test_correctness=False) 185 | 186 | self.assertTrue(fast_ism_model.test_correctness()) 187 | 188 | 189 | if __name__ == '__main__': 190 | unittest.main() 191 | -------------------------------------------------------------------------------- /test/test_simple_single_in_single_out_architectures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | 7 | class TestSimpleSingleInSingleOutArchitectures(unittest.TestCase): 8 | def test_conv_fc(self): 9 | # inp -> C -> D -> y 10 | inp = tf.keras.Input((100, 4)) 11 | x = tf.keras.layers.Conv1D(20, 3)(inp) 12 | x = tf.keras.layers.Flatten()(x) 13 | x = tf.keras.layers.Dense(1)(x) 14 | model = tf.keras.Model(inputs=inp, outputs=x) 15 | 16 | fast_ism_model = fastISM.FastISM( 17 | model, test_correctness=False) 18 | 19 | self.assertTrue(fast_ism_model.test_correctness()) 20 | 21 | def test_conv_fc_sequential(self): 22 | # inp -> C -> D -> y 23 | # same 
as above but with Sequential 24 | model = tf.keras.Sequential() 25 | model.add(tf.keras.Input((100, 4))) 26 | model.add(tf.keras.layers.Conv1D(20, 3)) 27 | model.add(tf.keras.layers.Flatten()) 28 | model.add(tf.keras.layers.Dense(1)) 29 | 30 | fast_ism_model = fastISM.FastISM( 31 | model, test_correctness=False) 32 | 33 | self.assertTrue(fast_ism_model.test_correctness()) 34 | 35 | def test_conv_same_padding_fc(self): 36 | # inp -> C -> D -> y 37 | inp = tf.keras.Input((100, 4)) 38 | x = tf.keras.layers.Conv1D(20, 3, padding='same')(inp) 39 | x = tf.keras.layers.Flatten()(x) 40 | x = tf.keras.layers.Dense(1)(x) 41 | model = tf.keras.Model(inputs=inp, outputs=x) 42 | 43 | fast_ism_model = fastISM.FastISM( 44 | model, test_correctness=False) 45 | 46 | self.assertTrue(fast_ism_model.test_correctness()) 47 | 48 | def test_conv_even_kernel_fc(self): 49 | # inp -> C -> D -> y 50 | inp = tf.keras.Input((100, 4)) 51 | x = tf.keras.layers.Conv1D(20, 4)(inp) 52 | x = tf.keras.layers.Flatten()(x) 53 | x = tf.keras.layers.Dense(1)(x) 54 | model = tf.keras.Model(inputs=inp, outputs=x) 55 | 56 | fast_ism_model = fastISM.FastISM( 57 | model, test_correctness=False) 58 | 59 | self.assertTrue(fast_ism_model.test_correctness()) 60 | 61 | def test_conv_even_kernel_same_padding_fc(self): 62 | # inp -> C -> D -> y 63 | inp = tf.keras.Input((100, 4)) 64 | x = tf.keras.layers.Conv1D(20, 4, padding='same')(inp) 65 | x = tf.keras.layers.Flatten()(x) 66 | x = tf.keras.layers.Dense(1)(x) 67 | model = tf.keras.Model(inputs=inp, outputs=x) 68 | 69 | fast_ism_model = fastISM.FastISM( 70 | model, test_correctness=False) 71 | 72 | self.assertTrue(fast_ism_model.test_correctness()) 73 | 74 | def test_conv_dilated_fc(self): 75 | # inp -> C -> D -> y 76 | inp = tf.keras.Input((100, 4)) 77 | x = tf.keras.layers.Conv1D(20, 3, dilation_rate=3)(inp) 78 | x = tf.keras.layers.Flatten()(x) 79 | x = tf.keras.layers.Dense(1)(x) 80 | model = tf.keras.Model(inputs=inp, outputs=x) 81 | 82 | fast_ism_model = 
fastISM.FastISM( 83 | model, test_correctness=False) 84 | 85 | self.assertTrue(fast_ism_model.test_correctness()) 86 | 87 | def test_conv_maxpool_fc(self): 88 | # inp -> C -> MXP -> D -> y 89 | inp = tf.keras.Input((100, 4)) 90 | x = tf.keras.layers.Conv1D(10, 7)(inp) 91 | x = tf.keras.layers.MaxPooling1D(3)(x) 92 | x = tf.keras.layers.Flatten()(x) 93 | x = tf.keras.layers.Dense(2)(x) 94 | model = tf.keras.Model(inputs=inp, outputs=x) 95 | 96 | fast_ism_model = fastISM.FastISM( 97 | model, test_correctness=False) 98 | 99 | self.assertTrue(fast_ism_model.test_correctness()) 100 | 101 | def test_conv_two_maxpool_fc(self): 102 | # inp -> C -> MXP -> MXP -> D -> y 103 | inp = tf.keras.Input((100, 4)) 104 | x = tf.keras.layers.Conv1D(10, 7)(inp) 105 | x = tf.keras.layers.MaxPooling1D(3)(x) 106 | x = tf.keras.layers.MaxPooling1D(2)(x) 107 | x = tf.keras.layers.Flatten()(x) 108 | x = tf.keras.layers.Dense(2)(x) 109 | model = tf.keras.Model(inputs=inp, outputs=x) 110 | 111 | fast_ism_model = fastISM.FastISM( 112 | model, test_correctness=False) 113 | 114 | self.assertTrue(fast_ism_model.test_correctness()) 115 | 116 | def test_two_conv_maxpool_fc(self): 117 | # inp -> C -> MXP -> C -> MXP -> D -> y 118 | inp = tf.keras.Input((100, 4)) 119 | x = tf.keras.layers.Conv1D(10, 7, padding='same')(inp) 120 | x = tf.keras.layers.MaxPooling1D(3)(x) 121 | x = tf.keras.layers.Conv1D(10, 3)(x) 122 | x = tf.keras.layers.MaxPooling1D(2)(x) 123 | x = tf.keras.layers.Flatten()(x) 124 | x = tf.keras.layers.Dense(2)(x) 125 | model = tf.keras.Model(inputs=inp, outputs=x) 126 | 127 | fast_ism_model = fastISM.FastISM( 128 | model, test_correctness=False) 129 | 130 | self.assertTrue(fast_ism_model.test_correctness()) 131 | 132 | def test_four_conv_maxpool_two_fc_1(self): 133 | # inp -> C -> MXP -> C -> MXP -> C -> MXP -> C -> MXP -> D -> D -> y 134 | inp = tf.keras.Input((200, 4)) 135 | x = tf.keras.layers.Conv1D(10, 7, padding='same')(inp) 136 | x = tf.keras.layers.MaxPooling1D(2)(x) 137 | x = 
tf.keras.layers.Conv1D(20, 4, padding='same')(inp) 138 | x = tf.keras.layers.MaxPooling1D(2)(x) 139 | x = tf.keras.layers.Conv1D(30, 2, padding='valid')(x) 140 | x = tf.keras.layers.MaxPooling1D(2)(x) 141 | x = tf.keras.layers.Conv1D(10, 6, padding='same')(x) 142 | x = tf.keras.layers.MaxPooling1D(2)(x) 143 | x = tf.keras.layers.Flatten()(x) 144 | x = tf.keras.layers.Dense(20)(x) 145 | x = tf.keras.layers.Dense(1)(x) 146 | model = tf.keras.Model(inputs=inp, outputs=x) 147 | 148 | fast_ism_model = fastISM.FastISM( 149 | model, test_correctness=False) 150 | 151 | self.assertTrue(fast_ism_model.test_correctness()) 152 | 153 | def test_four_conv_maxpool_two_fc_2(self): 154 | # inp -> C -> MXP -> C -> MXP -> C -> MXP -> C -> MXP -> D -> D -> y 155 | inp = tf.keras.Input((200, 4)) 156 | x = tf.keras.layers.Conv1D(10, 3, dilation_rate=3, padding='same')(inp) 157 | x = tf.keras.layers.MaxPooling1D(2)(x) 158 | x = tf.keras.layers.Conv1D( 159 | 25, 4, padding='same', activation='relu')(inp) 160 | x = tf.keras.layers.MaxPooling1D(2)(x) 161 | x = tf.keras.layers.Conv1D( 162 | 30, 2, dilation_rate=2, padding='valid', activation='tanh')(x) 163 | x = tf.keras.layers.MaxPooling1D(2)(x) 164 | x = tf.keras.layers.Conv1D(10, 6, padding='same')(x) 165 | x = tf.keras.layers.MaxPooling1D(2)(x) 166 | x = tf.keras.layers.Flatten()(x) 167 | x = tf.keras.layers.Dense(20)(x) 168 | x = tf.keras.layers.Dense(1)(x) 169 | model = tf.keras.Model(inputs=inp, outputs=x) 170 | 171 | fast_ism_model = fastISM.FastISM( 172 | model, test_correctness=False) 173 | 174 | self.assertTrue(fast_ism_model.test_correctness()) 175 | 176 | def test_four_conv_maxpool_two_fc_3(self): 177 | # inp -> C -> MXP -> C -> MXP -> C -> MXP -> C -> MXP -> D -> D -> y 178 | inp = tf.keras.Input((200, 4)) 179 | x = tf.keras.layers.Conv1D(10, 5, use_bias=False, padding='same')(inp) 180 | x = tf.keras.layers.MaxPooling1D(2)(x) 181 | x = tf.keras.layers.Conv1D( 182 | 25, 4, padding='same', activation='relu')(inp) 183 | x = 
tf.keras.layers.MaxPooling1D(2)(x) 184 | x = tf.keras.layers.Conv1D(30, 2, dilation_rate=2, use_bias=False, 185 | padding='valid', activation='tanh')(x) 186 | x = tf.keras.layers.MaxPooling1D(2)(x) 187 | x = tf.keras.layers.Conv1D(10, 3, padding='same')(x) 188 | x = tf.keras.layers.MaxPooling1D(2)(x) 189 | x = tf.keras.layers.Flatten()(x) 190 | x = tf.keras.layers.Dense(10)(x) 191 | x = tf.keras.layers.Dense(1)(x) 192 | model = tf.keras.Model(inputs=inp, outputs=x) 193 | 194 | fast_ism_model = fastISM.FastISM( 195 | model, test_correctness=False) 196 | 197 | self.assertTrue(fast_ism_model.test_correctness()) 198 | 199 | def test_four_conv_maxpool_two_fc_4(self): 200 | # inp -> C -> MXP -> C -> MXP -> C -> MXP -> C -> MXP -> D -> D -> y 201 | # with Dropout and GlobalAveragePoolng1D 202 | inp = tf.keras.Input((200, 4)) 203 | x = tf.keras.layers.Conv1D(10, 5, use_bias=False, padding='same')(inp) 204 | x = tf.keras.layers.MaxPooling1D(2)(x) 205 | x = tf.keras.layers.Conv1D( 206 | 25, 4, padding='same', activation='relu')(inp) 207 | x = tf.keras.layers.Dropout(0.5)(x) 208 | x = tf.keras.layers.MaxPooling1D(2)(x) 209 | x = tf.keras.layers.Conv1D(30, 2, dilation_rate=2, use_bias=False, 210 | padding='valid', activation='tanh')(x) 211 | x = tf.keras.layers.MaxPooling1D(2)(x) 212 | x = tf.keras.layers.Dropout(0.8)(x) 213 | x = tf.keras.layers.Conv1D(10, 3, padding='same')(x) 214 | x = tf.keras.layers.MaxPooling1D(2)(x) 215 | x = tf.keras.layers.GlobalAveragePooling1D()(x) 216 | x = tf.keras.layers.Dense(10)(x) 217 | x = tf.keras.layers.Dropout(0.3)(x) 218 | x = tf.keras.layers.Dense(1)(x) 219 | model = tf.keras.Model(inputs=inp, outputs=x) 220 | 221 | fast_ism_model = fastISM.FastISM( 222 | model, test_correctness=False) 223 | 224 | self.assertTrue(fast_ism_model.test_correctness()) 225 | 226 | def test_pre_act_four_conv_maxpool_two_fc_4_10bp_change_range(self): 227 | # inp -> tanh -> C -> MXP -> C -> MXP -> C -> MXP -> C -> MXP -> D -> D -> y 228 | # with Dropout and 
GlobalAveragePoolng1D 229 | # activation before first conv! 230 | inp = tf.keras.Input((200, 4)) 231 | x = tf.keras.layers.Activation("tanh")(inp) 232 | x = tf.keras.layers.Conv1D(10, 5, use_bias=False, padding='same')(x) 233 | x = tf.keras.layers.MaxPooling1D(2)(x) 234 | x = tf.keras.layers.Conv1D( 235 | 25, 4, padding='same', activation='relu')(inp) 236 | x = tf.keras.layers.Dropout(0.5)(x) 237 | x = tf.keras.layers.MaxPooling1D(2)(x) 238 | x = tf.keras.layers.Conv1D(30, 2, dilation_rate=2, use_bias=False, 239 | padding='valid', activation='tanh')(x) 240 | x = tf.keras.layers.MaxPooling1D(2)(x) 241 | x = tf.keras.layers.Dropout(0.8)(x) 242 | x = tf.keras.layers.Conv1D(10, 3, padding='same')(x) 243 | x = tf.keras.layers.MaxPooling1D(2)(x) 244 | x = tf.keras.layers.GlobalAveragePooling1D()(x) 245 | x = tf.keras.layers.Dense(10)(x) 246 | x = tf.keras.layers.Dropout(0.3)(x) 247 | x = tf.keras.layers.Dense(1)(x) 248 | model = tf.keras.Model(inputs=inp, outputs=x) 249 | 250 | fast_ism_model = fastISM.FastISM( 251 | model, change_ranges=[(i, i+10) for i in range(0, 200, 10)], 252 | test_correctness=False) 253 | 254 | self.assertTrue(fast_ism_model.test_correctness()) 255 | 256 | def test_pre_act_four_conv_maxpool_two_fc_4_sequential(self): 257 | # inp -> tanh -> C -> MXP -> C -> MXP -> C -> MXP -> C -> MXP -> D -> D -> y 258 | # with Dropout and GlobalAveragePoolng1D 259 | # activation before first conv! 
260 | # same as above but with Sequential 261 | model = tf.keras.Sequential() 262 | model.add(tf.keras.Input((200, 4))) 263 | model.add(tf.keras.layers.Activation("tanh")) 264 | model.add(tf.keras.layers.Conv1D( 265 | 10, 5, use_bias=False, padding='same')) 266 | model.add(tf.keras.layers.MaxPooling1D(2)) 267 | model.add(tf.keras.layers.Conv1D( 268 | 25, 4, padding='same', activation='relu')) 269 | model.add(tf.keras.layers.Dropout(0.5)) 270 | model.add(tf.keras.layers.MaxPooling1D(2)) 271 | model.add(tf.keras.layers.Conv1D(30, 2, dilation_rate=2, use_bias=False, 272 | padding='valid', activation='tanh')) 273 | model.add(tf.keras.layers.MaxPooling1D(2)) 274 | model.add(tf.keras.layers.Dropout(0.8)) 275 | model.add(tf.keras.layers.Conv1D(10, 3, padding='same')) 276 | model.add(tf.keras.layers.MaxPooling1D(2)) 277 | model.add(tf.keras.layers.GlobalAveragePooling1D()) 278 | model.add(tf.keras.layers.Dense(10)) 279 | model.add(tf.keras.layers.Dropout(0.3)) 280 | model.add(tf.keras.layers.Dense(1)) 281 | 282 | fast_ism_model = fastISM.FastISM( 283 | model, test_correctness=False) 284 | 285 | self.assertTrue(fast_ism_model.test_correctness()) 286 | 287 | 288 | if __name__ == '__main__': 289 | unittest.main() 290 | -------------------------------------------------------------------------------- /test/test_simple_skip_conn_architectures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | 7 | class TestSimpleSkipConnArchitectures(unittest.TestCase): 8 | def test_conv_add_two_fc(self): 9 | # inp -> C -> C-> Add -> D -> y 10 | # |_______^ 11 | inp = tf.keras.Input((100, 4)) 12 | x = tf.keras.layers.Conv1D(20, 3)(inp) 13 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 14 | x = tf.keras.layers.Add()([x, x1]) 15 | x = tf.keras.layers.Flatten()(x) 16 | y = tf.keras.layers.Dense(1)(x) 17 | model = tf.keras.Model(inputs=inp, outputs=y) 18 | 19 | 
fast_ism_model = fastISM.FastISM( 20 | model, test_correctness=False) 21 | 22 | self.assertTrue(fast_ism_model.test_correctness()) 23 | 24 | def test_conv_self_add_two_fc(self): 25 | # inp -> C -> C-> Add -> D -> y 26 | # |____^ 27 | inp = tf.keras.Input((100, 4)) 28 | x = tf.keras.layers.Conv1D(20, 3)(inp) 29 | x = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 30 | x = tf.keras.layers.Add()([x, x]) 31 | x = tf.keras.layers.Flatten()(x) 32 | y = tf.keras.layers.Dense(1)(x) 33 | model = tf.keras.Model(inputs=inp, outputs=y) 34 | 35 | fast_ism_model = fastISM.FastISM( 36 | model, test_correctness=False) 37 | 38 | self.assertTrue(fast_ism_model.test_correctness()) 39 | 40 | def test_conv_add_three_fc(self): 41 | # ^-- C---| 42 | # inp -> C -> C-> Add -> D -> y 43 | # |_______^ 44 | inp = tf.keras.Input((100, 4)) 45 | x = tf.keras.layers.Conv1D(20, 3)(inp) 46 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 47 | x2 = tf.keras.layers.Conv1D(20, 5, padding='same')(x) 48 | x = tf.keras.layers.Add()([x, x1, x2]) 49 | x = tf.keras.layers.Flatten()(x) 50 | y = tf.keras.layers.Dense(1)(x) 51 | model = tf.keras.Model(inputs=inp, outputs=y) 52 | 53 | fast_ism_model = fastISM.FastISM( 54 | model, test_correctness=False) 55 | 56 | self.assertTrue(fast_ism_model.test_correctness()) 57 | 58 | def test_skip_then_mxp(self): 59 | # _________ 60 | # ^ | 61 | # inp -> C -> C-> Add -> MXP -> [without Flatten!] 
D -> y 62 | # y has output dim [32,5] per example 63 | inp = tf.keras.Input((100, 4)) 64 | x = tf.keras.layers.Conv1D(20, 3)(inp) 65 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 66 | x1 = tf.keras.layers.Add()([x, x1]) 67 | x2 = tf.keras.layers.MaxPooling1D(3)(x1) 68 | 69 | y = tf.keras.layers.Dense(5)(x2) 70 | model = tf.keras.Model(inputs=inp, outputs=y) 71 | 72 | fast_ism_model = fastISM.FastISM( 73 | model, test_correctness=False) 74 | 75 | self.assertTrue(fast_ism_model.test_correctness()) 76 | 77 | def test_mini_dense_net_1(self): 78 | # _________ ___________ 79 | # ^ | ^ | 80 | # inp -> C -> C-> Add -> C -> Add -> D -> y 81 | # |____________________^ 82 | inp = tf.keras.Input((100, 4)) 83 | x = tf.keras.layers.Conv1D(20, 3)(inp) 84 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 85 | x1 = tf.keras.layers.Add()([x, x1]) 86 | x2 = tf.keras.layers.Conv1D(20, 5, padding='same')(x1) 87 | x2 = tf.keras.layers.Add()([x, x1, x2]) 88 | x2 = tf.keras.layers.Flatten()(x2) 89 | y = tf.keras.layers.Dense(1)(x2) 90 | model = tf.keras.Model(inputs=inp, outputs=y) 91 | 92 | fast_ism_model = fastISM.FastISM( 93 | model, test_correctness=False) 94 | 95 | self.assertTrue(fast_ism_model.test_correctness()) 96 | 97 | def test_mini_dense_net_2(self): 98 | # _________ ___________ _________ ___________ 99 | # ^ | ^ | ^ | ^ | 100 | # inp -> C -> C-> Add -> C -> Add -> MXP -> C -> C-> Add -> C -> Add -> D -> y 101 | # |____________________^ |____________________^ 102 | inp = tf.keras.Input((100, 4)) 103 | x = tf.keras.layers.Conv1D(20, 3)(inp) 104 | x1 = tf.keras.layers.Conv1D(20, 3, padding='same')(x) 105 | x1 = tf.keras.layers.Add()([x, x1]) 106 | x2 = tf.keras.layers.Conv1D(20, 5, padding='same')(x1) 107 | x2 = tf.keras.layers.Add()([x, x1, x2]) 108 | x2 = tf.keras.layers.MaxPooling1D(3)(x2) 109 | x2 = tf.keras.layers.Conv1D(10, 2)(x2) 110 | 111 | x3 = tf.keras.layers.Conv1D(10, 7, padding='same')(x2) 112 | x3 = tf.keras.layers.Maximum()([x2, x3]) 113 | x4 = 
tf.keras.layers.Conv1D(10, 4, padding='same')(x3) 114 | x4 = tf.keras.layers.Add()([x2, x3, x4]) 115 | 116 | x4 = tf.keras.layers.Flatten()(x4) 117 | y = tf.keras.layers.Dense(1)(x4) 118 | model = tf.keras.Model(inputs=inp, outputs=y) 119 | 120 | fast_ism_model = fastISM.FastISM( 121 | model, test_correctness=False) 122 | 123 | self.assertTrue(fast_ism_model.test_correctness()) 124 | 125 | def test_two_conv_addstop_stop_skip_fc(self): 126 | # inp --> C -> Add -> D --> Add -> y 127 | # |-> C -> ^ ---> D -----^ 128 | # skip connection between stop segments 129 | inp = tf.keras.Input((100, 4)) 130 | x1 = tf.keras.layers.Conv1D(20, 3)(inp) 131 | x2 = tf.keras.layers.Conv1D(20, 3)(inp) 132 | 133 | x = tf.keras.layers.Add()([x1, x2]) 134 | x = tf.keras.layers.Flatten()(x) 135 | x = tf.keras.layers.Dense(10)(x) 136 | 137 | x2 = tf.keras.layers.Flatten()(x2) 138 | x2 = tf.keras.layers.Dense(10)(x2) 139 | 140 | x = tf.keras.layers.Add()([x, x2]) 141 | 142 | model = tf.keras.Model(inputs=inp, outputs=x) 143 | 144 | fast_ism_model = fastISM.FastISM( 145 | model, test_correctness=False) 146 | 147 | self.assertTrue(fast_ism_model.test_correctness()) 148 | 149 | if __name__ == '__main__': 150 | unittest.main() 151 | -------------------------------------------------------------------------------- /test/test_unresolved.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from context import fastISM 5 | 6 | class TestUnresolved(unittest.TestCase): 7 | """ 8 | These are outstanding issues that need to be solved 9 | with short descriptions of the problem and possible 10 | solutions. 11 | """ 12 | 13 | def test_stop_multi_input(self): 14 | """ 15 | In this case the stop segment has different inputs 16 | at different nodes, and this confuses the change range 17 | computation step and breaks assertions. 
18 | 19 | One way to fix it could be to modify the segmenting code 20 | such that nodes that are encountered again and are already 21 | in stop segment should be updated with a new idx, and 22 | in the label_stop_descendants step nodes that are already 23 | labeled with (potentially non-stop) index are given a new 24 | index. This would likely fail if a stop layer gets non-stop 25 | inputs of different widths. 26 | """ 27 | inp = tf.keras.Input((100, 4)) 28 | x = tf.keras.layers.Conv1D(10, 3, padding='same')(inp) 29 | y = tf.keras.layers.Dense(10)(x) 30 | x = tf.keras.layers.Add()([x,y]) 31 | x = tf.keras.layers.Flatten()(x) 32 | x = tf.keras.layers.Dense(1)(x) 33 | model = tf.keras.Model(inputs=inp, outputs=x) 34 | 35 | fast_ism_model = fastISM.FastISM( 36 | model, test_correctness=False) 37 | 38 | self.assertTrue(fast_ism_model.test_correctness()) 39 | 40 | --------------------------------------------------------------------------------