├── .gitignore ├── .travis.yml ├── .travis ├── deploy.sh └── docker-compose.yml ├── CONTRIBUTING.md ├── LICENSE ├── MAINTAINERS.md ├── README.md ├── bin └── depiction-models-download ├── blogs └── ibm_developer │ └── 01_depiction.ipynb ├── data ├── README.md ├── paccmann │ ├── README.md │ ├── gdsc.csv.gz │ ├── gdsc.smi │ └── gdsc_sensitivity.csv.gz └── single-cell │ ├── README.md │ ├── data.csv │ └── metadata.csv ├── depiction ├── __init__.py ├── core.py ├── interpreters │ ├── __init__.py │ ├── aix360 │ │ ├── __init__.py │ │ ├── rule_based_model.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── rule_based_model_test.py │ ├── alibi │ │ ├── __init__.py │ │ ├── contrastive │ │ │ ├── __init__.py │ │ │ ├── cem.py │ │ │ └── tests │ │ │ │ ├── __init__.py │ │ │ │ └── cem_test.py │ │ └── counterfactual │ │ │ ├── __init__.py │ │ │ ├── counterfactual.py │ │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── counterfactual_test.py │ ├── backprop │ │ ├── __init__.py │ │ ├── backpropeter.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── backpropeter_test.py │ ├── base │ │ ├── __init__.py │ │ ├── base_interpreter.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── base_interpreter_test.py │ └── u_wash │ │ ├── __init__.py │ │ ├── tests │ │ ├── __init__.py │ │ └── u_washer_test.py │ │ └── u_washer.py ├── models │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── binarized_model.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── base_model_test.py │ │ │ └── utils_test.py │ │ └── utils.py │ ├── examples │ │ ├── __init__.py │ │ ├── celltype │ │ │ ├── __init__.py │ │ │ ├── celltype.py │ │ │ └── tests │ │ │ │ ├── __init__.py │ │ │ │ └── celltype_test.py │ │ ├── deepbind │ │ │ ├── __init__.py │ │ │ ├── deepbind.py │ │ │ └── deepbind_cli.py │ │ └── paccmann │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ └── smiles.py │ ├── keras │ │ ├── __init__.py │ │ ├── application.py │ │ ├── core.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── application_test.py │ │ │ └── core_test.py │ 
├── kipoi │ │ ├── __init__.py │ │ ├── core.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── kipoi_test.py │ ├── max │ │ ├── __init__.py │ │ ├── breast_cancer_mitosis_detector.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── breast_cancer_mitosis_detector_test.py │ │ │ └── toxic_comment_classifier_test.py │ │ └── toxic_comment_classifier.py │ ├── torch │ │ ├── __init__.py │ │ ├── core.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── core_test.py │ │ │ └── torchvision_test.py │ │ └── torchvision.py │ └── uri │ │ ├── __init__.py │ │ ├── cache │ │ ├── __init__.py │ │ ├── cache_model.py │ │ ├── cos_model.py │ │ ├── file_system_model.py │ │ ├── http_model.py │ │ └── tests │ │ │ └── __init__.py │ │ ├── rest_api │ │ ├── __init__.py │ │ ├── max_model.py │ │ ├── rest_api_model.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── max_model_test.py │ │ │ └── rest_api_model_test.py │ │ ├── tests │ │ └── __init__.py │ │ └── uri_model.py └── tests │ ├── __init__.py │ └── core_test.py ├── docker ├── Dockerfile ├── docker-compose.yml └── docker-entrypoint.sh ├── environment.yml ├── examples ├── cem_mnist.py └── uwashers_imagenet.py ├── notebooks ├── Breast cancer IDC classification.ipynb ├── celltype_interpretability.ipynb ├── celltype_training.ipynb ├── deepbind.ipynb ├── kaggle_create_a_new_api_token.png ├── kaggle_go_to_your_account.png ├── model-zoo.ipynb └── paccmann.ipynb ├── requirements.txt ├── setup.py └── workshops ├── 20190909_BC2 └── README.md ├── 20191120_ODSC2019 ├── README.md ├── blog │ ├── README.md │ └── lime.png └── notebooks │ ├── celltype.ipynb │ ├── imagenet.ipynb │ └── mnist.ipynb ├── 20191121_PyPharma19 ├── README.md └── tutorial_colab.ipynb └── 20200125_AMLD2020 ├── README.md └── notebooks ├── breast_cancer_idc_classification.ipynb ├── celltype_interpretability.ipynb ├── deepbind.ipynb ├── kaggle_create_a_new_api_token.png ├── kaggle_go_to_your_account.png └── paccmann.ipynb /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | depiction/cache/ 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | services: 2 
| - docker 3 | 4 | env: 5 | - DOCKER_COMPOSE_VERSION=1.25.0 6 | 7 | before_install: 8 | # docker 9 | - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 10 | - sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" 11 | - sudo apt-get update 12 | - sudo apt-get -y -o Dpkg::Options::="--force-confnew" install docker-ce 13 | # docker-compose 14 | - sudo rm /usr/local/bin/docker-compose 15 | - curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-`uname -s`-`uname -m` > docker-compose 16 | - chmod +x docker-compose 17 | - sudo mv docker-compose /usr/local/bin 18 | script: 19 | - docker-compose -f docker/docker-compose.yml build 20 | - docker-compose -f .travis/docker-compose.yml up -d 21 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/tests 22 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/interpreters/ 23 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/base/ 24 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/keras 25 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/kipoi/ 26 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/max/ 27 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/torch 28 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/uri/ 29 | - docker exec -it depiction-test python3 -m unittest discover -v -t /build -p "*_test.py" /build/depiction/models/examples/celltype/ 30 | 31 | deploy: 32 | provider: script 33 | 
skip_cleanup: true 34 | script: sh .travis/deploy.sh 35 | on: 36 | branches: 37 | only: 38 | - master 39 | -------------------------------------------------------------------------------- /.travis/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo ${DOCKER_PASSWORD} | docker login -u ${DOCKER_USERNAME} --password-stdin 4 | docker tag drugilsberg/depiction:latest drugilsberg/depiction:${TRAVIS_COMMIT} 5 | docker push drugilsberg/depiction:${TRAVIS_COMMIT} 6 | docker push drugilsberg/depiction:latest -------------------------------------------------------------------------------- /.travis/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.2" 2 | services: 3 | depiction: 4 | container_name: depiction-test 5 | image: drugilsberg/depiction 6 | environment: 7 | - TEST_REST_API=max-toxic-comment-classifier 8 | - TEST_MAX_BASE=max-toxic-comment-classifier 9 | - TEST_MAX_TOXIC_COMMENT_CLASSIFIER=max-toxic-comment-classifier 10 | - TEST_MAX_BREAST_CANCER_MITOSIS_DETECTOR=max-breast-cancer-mitosis-detector 11 | max-toxic-comment-classifier: 12 | image: codait/max-toxic-comment-classifier 13 | container_name: max-toxic-comment-classifier-test 14 | max-breast-cancer-mitosis-detector: 15 | image: codait/max-breast-cancer-mitosis-detector 16 | container_name: max-breast-cancer-mitosis-detector-test 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing In General 2 | 3 | Our project welcomes external contributions. If you have an itch, please feel 4 | free to scratch it. 5 | 6 | To contribute code or documentation, please submit a [pull request](https://github.com/ibm/dl-interpretability-compbio/pulls). 
7 | 8 | A good way to familiarize yourself with the codebase and contribution process is 9 | to look for and tackle low-hanging fruit in the [issue tracker](https://github.com/ibm/dl-interpretability-compbio/issues). 10 | Before embarking on a more ambitious contribution, please quickly [get in touch](#communication) with us. 11 | 12 | **Note: We appreciate your effort, and want to avoid a situation where a contribution 13 | requires extensive rework (by you or by us), sits in backlog for a long time, or 14 | cannot be accepted at all!** 15 | 16 | ### Proposing new features 17 | 18 | If you would like to implement a new feature, please [raise an issue](https://github.com/ibm/dl-interpretability-compbio/issues) 19 | before sending a pull request so the feature can be discussed. This is to avoid 20 | you wasting your valuable time working on a feature that the project developers 21 | are not interested in accepting into the code base. 22 | 23 | ### Fixing bugs 24 | 25 | If you would like to fix a bug, please [raise an issue](https://github.com/ibm/dl-interpretability-compbio/issues) before sending a 26 | pull request so it can be tracked. 27 | 28 | ### Merge approval 29 | 30 | The project maintainers use LGTM (Looks Good To Me) in comments on the code 31 | review to indicate acceptance. A change requires LGTMs from two of the 32 | maintainers of each component affected. 33 | 34 | For a list of the maintainers, see the [MAINTAINERS.md](MAINTAINERS.md) page. 35 | 36 | ## Legal 37 | 38 | We have tried to make it as easy as possible to make contributions. This 39 | applies to how we handle the legal aspects of contribution. We use the 40 | same approach - the [Developer's Certificate of Origin 1.1 (DCO)](https://github.com/hyperledger/fabric/blob/master/docs/source/DCO1.1.txt) - that the Linux® Kernel [community](https://elinux.org/Developer_Certificate_Of_Origin) 41 | uses to manage code contributions. 
42 | 43 | We simply ask that when submitting a patch for review, the developer 44 | must include a sign-off statement in the commit message. 45 | 46 | Here is an example Signed-off-by line, which indicates that the 47 | submitter accepts the DCO: 48 | 49 | ```console 50 | Signed-off-by: Jane Doe 51 | ``` 52 | 53 | You can include this automatically when you commit a change to your 54 | local git repository using the following command: 55 | 56 | ```sh 57 | git commit -s 58 | ``` 59 | 60 | ## Communication 61 | 62 | Please feel free to connect with us via email, see the [MAINTAINERS.md](MAINTAINERS.md) page. 63 | 64 | ## Setup 65 | 66 | Set up a conda environment 67 | 68 | ```sh 69 | conda env create -f environment.yml 70 | ``` 71 | 72 | Activate it: 73 | 74 | ```sh 75 | conda activate depiction-env 76 | ``` 77 | 78 | Install the module in editable mode: 79 | 80 | ```sh 81 | pip install -e . 82 | ``` 83 | 84 | Optionally, install a `jupyter` playground: 85 | 86 | ```sh 87 | pip install jupyter 88 | ipython kernel install --user --name=depiction-development 89 | ``` 90 | 91 | ## Testing 92 | 93 | For tests we use the `unittest` module. 94 | You can run tests by typing: 95 | 96 | ```sh 97 | python -m unittest discover -p "*test.py" 98 | ``` 99 | 100 | ## Coding style guidelines 101 | 102 | We try to follow PEP8 styling. 
103 | -------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # MAINTAINERS 2 | 3 | - Matteo Manica - drugilsberg@gmail.com - [drugilsberg](https://github.com/drugilsberg) 4 | - An-phi Nguyen - nguyen.phineas@gmail.com - [phineasng](https://github.com/phineasng) 5 | - Joris Cadow - joriscadow@gmail.com - [C-nit](https://github.com/C-nit) 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # depiction 2 | 3 | [![Build Status](https://travis-ci.com/IBM/depiction.svg?branch=master)](https://travis-ci.com/IBM/depiction) 4 | 5 | A collection of tools and resources to interpret deep learning models in a framework-independent fashion. 6 | 7 | The core of the repo is a package, called `depiction`, with wrappers around models and methods for interpretable deep learning. 8 | 9 | **DISCLAIMER**: This repo is undergoing a refactoring. For the latest developments (e.g. as shown in ISMB/ECCB 21), please check the other branches, in particular `visualizations`. 10 | 11 | ## Docker setup 12 | 13 | ### Install docker 14 | 15 | Make sure to have a working [docker](https://www.docker.com/) installation. 16 | Installation instructions for different operating systems can be found on the [website](https://docs.docker.com/install/). 17 | 18 | ### Get `drugilsberg/depiction` image 19 | 20 | We built a [docker image](https://cloud.docker.com/repository/docker/drugilsberg/depiction) for `depiction` containing all models, data and dependencies needed to run the notebooks contained in the repo. 21 | Once the docker installation is complete the `depiction` image can be pulled right away: 22 | 23 | ```sh 24 | docker pull drugilsberg/depiction 25 | ``` 26 | 27 | **NOTE**: the image is quite large (~5.5GB) and this step might require some time. 
28 | 29 | ### Run `drugilsberg/depiction` image 30 | 31 | The image can be run to serve [jupyter](https://jupyter.org/) notebooks by typing: 32 | 33 | ```sh 34 | docker run -p 8899:8888 -it drugilsberg/depiction 35 | ``` 36 | 37 | At this point just connect to [http://localhost:8899/tree](http://localhost:8899/tree) to run the notebooks and experiment with `depiction`. 38 | 39 | #### Daemonization 40 | 41 | We recommend to run it as a daemon: 42 | 43 | ```sh 44 | docker run -d -p 8899:8888 -it drugilsberg/depiction 45 | ``` 46 | 47 | maybe mount your local notebooks directory to keep the changes locally 48 | 49 | ``` 50 | docker run --mount src=`pwd`/notebooks,target=/workspace/notebooks,type=bind -p 8899:8888 -it drugilsberg/depiction 51 | ``` 52 | 53 | and stopped using the container id: 54 | 55 | ```sh 56 | docker stop 57 | ``` 58 | 59 | ## Development setup 60 | 61 | Setup a conda environment 62 | 63 | ```sh 64 | conda env create -f environment.yml 65 | ``` 66 | 67 | Activate it: 68 | 69 | ```sh 70 | conda activate depiction-env 71 | ``` 72 | 73 | Install the module: 74 | 75 | ```sh 76 | pip install . 77 | ``` 78 | -------------------------------------------------------------------------------- /bin/depiction-models-download: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | """Download models wrapped in depiction.""" 3 | import argparse 4 | from depiction.core import DataType 5 | from depiction.models.examples.celltype import CellTyper 6 | from depiction.models.examples.deepbind import DeepBind 7 | from depiction.models.examples.deepbind.deepbind_cli import DeepBind as DB 8 | from depiction.models.examples.paccmann import PaccMann 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | '-c', '--cache_dir', type=str, 13 | help=( 14 | 'cache directory where to stored the models. ' 15 | 'Defaults to None. It behaves as tensorflow.keras.utils.get_file.' 
16 | ), 17 | default=None 18 | ) 19 | 20 | if __name__ == '__main__': 21 | # parse arguments 22 | args = parser.parse_args() 23 | # CellTyper 24 | _ = CellTyper(cache_dir=args.cache_dir) 25 | # DeepBind 26 | _ = DeepBind(model='DeepBind/Homo_sapiens/TF/D00328.003_SELEX_CTCF') 27 | _ = DeepBind(model='DeepBind/Homo_sapiens/TF/D00761.001_ChIP-seq_FOXA1') 28 | # DeepBind cli version for old workshop notebook support 29 | _ = DB(cache_dir=args.cache_dir) 30 | # PaccMann 31 | # NOTE: here data_type is not relevant 32 | _ = PaccMann(data_type=DataType.TEXT, cache_dir=args.cache_dir) 33 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data for the interpretability tutorial 2 | 3 | Data for the examples: 4 | 5 | - Compound structure, gene expression data and drug sensitivity data from [Yang et al.](https://academic.oup.com/nar/article/41/D1/D955/1059448) (`paccmann`) 6 | - Single cell mass cytometry from [Levine et al.](https://www.cell.com/cell/fulltext/S0092-8674(15)00637-6) (`single-cell`) 7 | -------------------------------------------------------------------------------- /data/paccmann/README.md: -------------------------------------------------------------------------------- 1 | # GDSC compounds 2 | 3 | Molecular structure in `.smi` format and gene expression in `.csv.gz` format for the compounds and cell lines considered in [Yang et al.](https://academic.oup.com/nar/article/41/D1/D955/1059448). 4 | 5 | - `gdsc.smi`, publicly available structures for 209 drugs used in the study. 6 | - `gdsc.csv.gz`, gene expression for 970 cell lines in 2128 genes selected via network propagation as described in [Oskooei et al.](https://arxiv.org/abs/1811.06802) and [Manica et al.](https://arxiv.org/abs/1904.11223). 
7 | - `gdsc_sensitivity.csv.gz`, drug sensitivity for 212539 pairs expressed as a boolean label: 1 effective (IC50 < 1 um), 0 non effective. 8 | -------------------------------------------------------------------------------- /data/paccmann/gdsc.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/data/paccmann/gdsc.csv.gz -------------------------------------------------------------------------------- /data/paccmann/gdsc_sensitivity.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/data/paccmann/gdsc_sensitivity.csv.gz -------------------------------------------------------------------------------- /data/single-cell/README.md: -------------------------------------------------------------------------------- 1 | # Single-cell 2 | 3 | Dataset from [Levine et al.](http://www.cell.com/cell/fulltext/S0092-8674(15)00637-6). 4 | The data consist of approximately 80K single cells, where the abundance of 13 cell surface markers was measured by mass cytometry. 5 | In this dataset the existing cell subpopulations are well-characterized by manual gating. 6 | The dataset is composed by: 7 | 8 | - `data.csv`, the makers expression with the subpopulation label. 9 | - `metadata.csv`, the mapping between the label and the actual cell type. 
10 | -------------------------------------------------------------------------------- /data/single-cell/metadata.csv: -------------------------------------------------------------------------------- 1 | label,cell type name 2 | 1,CD11b- Monocyte 3 | 2,CD11bhi Monocyte 4 | 3,CD11bmid Monocyte 5 | 4,Erythroblast 6 | 5,HSC 7 | 6,Immature B 8 | 7,Mature CD38lo B 9 | 8,Mature CD38mid B 10 | 9,Mature CD4+ T 11 | 10,Mature CD8+ T 12 | 11,Megakaryocyte 13 | 12,Myelocyte 14 | 13,NK 15 | 14,Naive CD4+ T 16 | 15,Naive CD8+ T 17 | 16,Plasma cell 18 | 17,Plasmacytoid DC 19 | 18,Platelet 20 | 19,Pre-B II 21 | 20,Pre-B I 22 | -------------------------------------------------------------------------------- /depiction/__init__.py: -------------------------------------------------------------------------------- 1 | """DEPICTION initialization module.""" 2 | 3 | name = 'depiction' 4 | __version__ = '0.0.1' 5 | -------------------------------------------------------------------------------- /depiction/core.py: -------------------------------------------------------------------------------- 1 | """Core utilities for depiction.""" 2 | from enum import Enum, Flag, auto 3 | 4 | 5 | class Task(Flag): 6 | """Enum indicating the task performed by a model.""" 7 | BINARY = auto() 8 | MULTICLASS = auto() 9 | REGRESSION = auto() 10 | CLASSIFICATION = BINARY | MULTICLASS 11 | 12 | def __lt__(self, other): 13 | res = (self.value & other.value) 14 | return (res == self.value) and (res != other.value) 15 | 16 | def __le__(self, other): 17 | return self.__lt__(other) or (self.value == other.value) 18 | 19 | def __gt__(self, other): 20 | return ((self.value 21 | | other.value) == self.value) and (self.value != other.value) 22 | 23 | def __ge__(self, other): 24 | return self.__gt__(other) or (self.value == other.value) 25 | 26 | @staticmethod 27 | def check_support(t, tasks_set): 28 | """ 29 | Given an iterable containing tasks, checks if 'self' <= to any of the 30 | tasks in the iterable. 
31 | 32 | Args: 33 | tasks_set (iterable): iterable containing tasks 34 | """ 35 | for task in tasks_set: 36 | if t <= task: 37 | return True 38 | return False 39 | 40 | 41 | class DataType(Enum): 42 | """Enum indicating the data type used by a model.""" 43 | TABULAR = 1 44 | TEXT = 2 45 | IMAGE = 3 46 | -------------------------------------------------------------------------------- /depiction/interpreters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/interpreters/__init__.py -------------------------------------------------------------------------------- /depiction/interpreters/aix360/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize AIX360 models.""" 2 | from .rule_based_model import RuleAIX360 # noqa -------------------------------------------------------------------------------- /depiction/interpreters/aix360/rule_based_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper around the rule based models implemented in the AIX360 framework 3 | 4 | References: 5 | - https://github.com/IBM/AIX360 6 | - Wei, D., Dash, S., Gao, T. & Gunluk, O.. (2019). Generalized Linear Rule Models. Proceedings of the 36th International Conference on Machine Learning, in PMLR 97:6687-6696 7 | - Dash, S., Gunluk, O., & Wei, D. (2018). Boolean decision rules via column generation. In Advances in Neural Information Processing Systems (pp. 4655-4665). 
8 | """ 9 | import pickle 10 | import numpy as np 11 | import pandas as pd 12 | from pandas import DataFrame 13 | from aix360.algorithms.rbm import BRCGExplainer, BooleanRuleCG 14 | from aix360.algorithms.rbm import ( 15 | GLRMExplainer, LogisticRuleRegression, LinearRuleRegression 16 | ) 17 | from aix360.algorithms.rbm import FeatureBinarizer 18 | 19 | from ...core import Task, DataType 20 | from ..base.base_interpreter import AnteHocInterpreter, ExplanationType 21 | 22 | 23 | class RuleAIX360(AnteHocInterpreter): 24 | _AVAILABLE_RULE_REGRESSORS = {'logistic', 'linear'} 25 | 26 | SUPPORTED_TASK = {Task.BINARY} 27 | SUPPORTED_DATATYPE = {DataType.TABULAR} 28 | 29 | AVAILABLE_INTERPRETERS = {'brcg'}.union( 30 | {'glrm_{}'.format(i) 31 | for i in _AVAILABLE_RULE_REGRESSORS} 32 | ) 33 | 34 | EXPLANATION_TYPE = ExplanationType.GLOBAL 35 | 36 | def __init__(self, explainer, X, model=None, y=None, regressor_params={}): 37 | """ 38 | Constructor. For a description of the missing arguments, 39 | please refer to the AnteHocInterpreter. 40 | 41 | Args: 42 | - explainer (str): name of the explainer to use. 43 | - X (np.ndarray or pd.DataFrame): data to explain. 44 | - model (depiction.models.base.BaseModel): a model to interpret. 45 | Defaults to None, a.k.a. ante-hoc. 46 | - y (np.ndarray): binary labels for X. 47 | Defaults to None, a.k.a. post-hoc. 
48 | - regressor_params (dict): parameters for the regressor.s 49 | """ 50 | is_post_hoc = y is None 51 | is_ante_hoc = model is None 52 | if is_ante_hoc and is_post_hoc: 53 | raise RuntimeError( 54 | 'Make sure you pass a model (post-hoc) or labels (ante-hoc)' 55 | ) 56 | if model is None: 57 | super(RuleAIX360, self).__init__( 58 | AnteHocInterpreter.UsageMode.ANTE_HOC, 59 | task_type=Task.BINARY, 60 | data_type=DataType.TABULAR 61 | ) 62 | else: 63 | super(RuleAIX360, self).__init__( 64 | AnteHocInterpreter.UsageMode.POST_HOC, model=model 65 | ) 66 | 67 | if 'glrm' in explainer: 68 | regressor = explainer.split('_')[1] 69 | if regressor == 'logistic': 70 | self.regressor = LogisticRuleRegression(**regressor_params) 71 | elif regressor == 'linear': 72 | self.regressor = LinearRuleRegression(**regressor_params) 73 | else: 74 | raise ValueError( 75 | "Regressor '{}' not supported! Available regressors: {}". 76 | format(regressor, self._AVAILABLE_RULE_REGRESSORS) 77 | ) 78 | self.explainer = GLRMExplainer(self.regressor) 79 | elif explainer == 'brcg': 80 | self.regressor = BooleanRuleCG(**regressor_params) 81 | self.explainer = BRCGExplainer(self.regressor) 82 | else: 83 | raise ValueError( 84 | "Interpreter '{}' not supported! Available interpreters: {}". 85 | format(explainer, self.AVAILABLE_INTERPRETERS) 86 | ) 87 | 88 | if isinstance(X, np.ndarray): 89 | X = pd.DataFrame(X) 90 | self.X = X 91 | self.y = y 92 | self.binarizer = FeatureBinarizer(negations=True) 93 | self.X_binarized = self.binarizer.fit_transform(self.X) 94 | self._fitted = False 95 | 96 | def _fit_antehoc(self, X, y): 97 | """ 98 | Fitting the rule based model (antehoc version). 99 | 100 | Args: 101 | X (pandas.DataFrame): model input data 102 | y (array): model output data 103 | """ 104 | self.explainer.fit(X, y) 105 | self._fitted = True 106 | 107 | def _fit_posthoc(self, X, preprocess_X=None, postprocess_y=None): 108 | """ 109 | Fitting the rule based model to posthoc interpret another model. 
110 | 111 | Args: 112 | X: input to the model to be interpreted. Type depends on the model. 113 | preprocess_X: function to create a pandas.DataFrame from the model input to feed to this rule-based model. 114 | postprocess_y: function to postprocess the model output to feed to this rule-based model. 115 | """ 116 | y = self._to_interpret.predict(X) 117 | 118 | processed_X = X 119 | processed_y = y 120 | 121 | if preprocess_X is not None: 122 | processed_X = preprocess_X(processed_X) 123 | 124 | if postprocess_y is not None: 125 | processed_y = postprocess_y(processed_y) 126 | 127 | self._fit_antehoc(processed_X, processed_y) 128 | 129 | def interpret(self, explanation_configs={}, path=None): 130 | """ 131 | Produce explanation. 132 | 133 | Args: 134 | explanation_configs (dict): keyword arguments for the explain 135 | function of the explainer. Refer to the AIX360 implementation 136 | for details. 137 | path (str): path where to save the explanation. If None, a notebook 138 | environment will be assumed, and the explanation will be 139 | visualized. 140 | 141 | Returns: 142 | pd.DataFrame or dict: the explanation. 143 | """ 144 | if not self._fitted: 145 | if self.usage_mode == self.UsageMode.ANTE_HOC: 146 | self._fit_antehoc(self.X_binarized, self.y) 147 | else: 148 | self._fit_posthoc(self.X, self.binarizer.transform) 149 | 150 | self.explanation = self.explainer.explain(**explanation_configs) 151 | if path is None: 152 | self._visualize_explanation(self.explanation) 153 | else: 154 | self._save_explanation(self.explanation, path) 155 | return self.explanation 156 | 157 | def _visualize_explanation(self, explanation): 158 | """ 159 | Helper function to visualize the explanation. 
160 | """ 161 | if isinstance(self.explainer, GLRMExplainer): 162 | with pd.option_context( 163 | 'display.max_rows', None, 'display.max_columns', None 164 | ): 165 | print(explanation) 166 | elif isinstance(self.explainer, BRCGExplainer): 167 | # from "https://github.com/IBM/AIX360/blob/master/examples/rbm/breast-cancer-br.ipynb" 168 | isCNF = 'Predict Y=1 if ANY of the following rules are satisfied, otherwise Y=0:' 169 | notCNF = 'Predict Y=0 if ANY of the following rules are satisfied, otherwise Y=1:' 170 | print(isCNF if explanation['isCNF'] else notCNF) 171 | print() 172 | for rule in explanation['rules']: 173 | print(f' - {rule}') 174 | 175 | def _save_explanation(self, explanation, path): 176 | if isinstance(explanation, DataFrame): 177 | explanation.to_pickle(path) 178 | else: 179 | with open(path, 'wb') as f: 180 | pickle.dump(explanation, f) 181 | 182 | def predict(self, X, **kwargs): 183 | self.explainer.predict(X, **kwargs) 184 | -------------------------------------------------------------------------------- /depiction/interpreters/aix360/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/interpreters/aix360/tests/__init__.py -------------------------------------------------------------------------------- /depiction/interpreters/aix360/tests/rule_based_model_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from random import choice 4 | from unittest import mock 5 | 6 | from aix360.algorithms.rbm import ( 7 | BooleanRuleCG, BRCGExplainer, GLRMExplainer, LinearRuleRegression, 8 | LogisticRuleRegression 9 | ) 10 | from pandas import DataFrame 11 | 12 | from depiction.core import DataType, Task 13 | from depiction.interpreters.aix360.rule_based_model import RuleAIX360 14 | from depiction.models.base.base_model import 
class DummyModel(BaseModel):
    """Minimal model producing random binary predictions for testing."""

    def predict(self, sample):
        """Return one random 0/1 label per row of ``sample``."""
        return np.array([choice([0, 1]) for _ in range(sample.shape[0])])


class RuleAIX360TestCase(unittest.TestCase):
    """Unit tests for the RuleAIX360 rule-based interpreter wrapper."""

    def setUp(self):
        # Small random tabular dataset with binary labels.
        self.X = np.random.randn(100, 10)
        self.y = (np.random.randn(100) > 0.).astype(int)

    def _build_posthoc_interpreter(self):
        """Build an interpreter that explains a (dummy) trained model."""
        model = DummyModel(
            choice(list(RuleAIX360.SUPPORTED_TASK)),
            choice(list(RuleAIX360.SUPPORTED_DATATYPE))
        )
        interpreter = RuleAIX360(
            choice(list(RuleAIX360.AVAILABLE_INTERPRETERS)),
            X=self.X,
            model=model
        )
        return interpreter

    def _build_antehoc_interpreter(self):
        """Build an interpreter fitted directly on the raw labels."""
        interpreter = RuleAIX360(
            choice(list(RuleAIX360.AVAILABLE_INTERPRETERS)), self.X, y=self.y
        )
        return interpreter

    def testConstructor(self):
        """Constructor validates model/interpreter and selects the explainer."""
        # test error for wrong model
        NOT_SUPPORTED_TASKS = [
            t for t in set(Task) for T in RuleAIX360.SUPPORTED_TASK
            if not (t <= T)
        ]
        NOT_SUPPORTED_TYPES = list(
            set(DataType).difference(RuleAIX360.SUPPORTED_DATATYPE)
        )

        wrong_model = DummyModel(
            choice(NOT_SUPPORTED_TASKS), choice(NOT_SUPPORTED_TYPES)
        )

        with self.assertRaises(ValueError):
            RuleAIX360(
                choice(list(RuleAIX360.AVAILABLE_INTERPRETERS)),
                X=self.X,
                model=wrong_model
            )

        # test error for not supported interpreter
        with self.assertRaises(ValueError):
            RuleAIX360('', X=self.X, y=self.y)

        # test error for not supported GLRM regressor
        with self.assertRaises(ValueError):
            RuleAIX360('glrm_bubu', X=self.X, y=self.y)

        # test correctly chosen glrm and regressor
        valid_glrm = [
            i for i in RuleAIX360.AVAILABLE_INTERPRETERS if 'glrm' in i
        ]
        interpreter = RuleAIX360(choice(valid_glrm), X=self.X, y=self.y)
        self.assertIsInstance(interpreter.explainer, GLRMExplainer)
        # isinstance accepts a tuple of classes; clearer than a chained `or`
        self.assertIsInstance(
            interpreter.regressor,
            (LogisticRuleRegression, LinearRuleRegression)
        )
        self.assertFalse(interpreter._fitted)

        # -- test correctness of ante-hoc model
        self.assertEqual(interpreter.usage_mode, RuleAIX360.UsageMode.ANTE_HOC)
        self.assertTrue(
            Task.check_support(interpreter.task, RuleAIX360.SUPPORTED_TASK)
        )
        self.assertIn(interpreter.data_type, RuleAIX360.SUPPORTED_DATATYPE)

        # test brcg model
        interpreter = RuleAIX360('brcg', X=self.X, y=self.y)
        self.assertIsInstance(interpreter.explainer, BRCGExplainer)
        self.assertIsInstance(interpreter.regressor, BooleanRuleCG)
        self.assertFalse(interpreter._fitted)

        # test with right model
        interpreter = self._build_posthoc_interpreter()
        self.assertEqual(interpreter.usage_mode, RuleAIX360.UsageMode.POST_HOC)
        self.assertFalse(interpreter._fitted)

    def testFit(self):
        """fit() dispatches to the ante-hoc or post-hoc routine."""
        # test fit antehoc called correctly
        interpreter = self._build_antehoc_interpreter()

        with mock.patch.object(
            interpreter, '_fit_antehoc'
        ) as mock_fit_antehoc:
            interpreter.fit(0, 0)
            mock_fit_antehoc.assert_called_once()

        # test fit posthoc called correctly
        interpreter = self._build_posthoc_interpreter()

        with mock.patch.object(
            interpreter, '_fit_posthoc'
        ) as mock_fit_posthoc:
            interpreter.fit(0, 0)
            mock_fit_posthoc.assert_called_once()

    def testFitAntehoc(self):
        """Ante-hoc fitting delegates to the wrapped explainer."""
        interpreter = self._build_antehoc_interpreter()

        with mock.patch.object(
            interpreter.explainer, 'fit'
        ) as mock_explainer_fit:
            interpreter.fit(0, 0)
            mock_explainer_fit.assert_called_once()

    def testFitPosthoc(self):
        """Post-hoc fitting predicts with the model, then fits ante-hoc."""
        interpreter = self._build_posthoc_interpreter()

        with mock.patch.object(
            interpreter._to_interpret, 'predict'
        ) as mock_predict:
            with mock.patch.object(
                interpreter, '_fit_antehoc'
            ) as mock_fit_antehoc:
                interpreter.fit(0)

            mock_predict.assert_called_once()
            mock_fit_antehoc.assert_called_once()

        # preprocessing callback is applied to the input before prediction
        with mock.patch.object(
            interpreter._to_interpret, 'predict'
        ) as mock_predict:
            with mock.patch.object(
                interpreter, '_fit_antehoc'
            ) as mock_fit_antehoc:
                preprocess = mock.MagicMock()

                interpreter.fit(0, preprocess)
                preprocess.assert_called_once()
                preprocess.assert_called_with(0)

        # postprocessing callback is applied to the model's predictions
        with mock.patch.object(
            interpreter._to_interpret, 'predict', return_value=2
        ) as mock_predict:
            with mock.patch.object(
                interpreter, '_fit_antehoc'
            ) as mock_fit_antehoc:
                postprocess = mock.MagicMock()

                interpreter.fit(0, postprocess_y=postprocess)
                postprocess.assert_called_once()
                postprocess.assert_called_with(2)

    def testInterpret(self):
        """interpret() explains, then visualizes or saves the explanation."""
        builder = choice(
            [self._build_posthoc_interpreter, self._build_antehoc_interpreter]
        )
        interpreter = builder()

        with mock.patch.object(
            interpreter.explainer, 'explain'
        ) as mock_explain:
            with mock.patch.object(
                interpreter, '_visualize_explanation'
            ) as mock_visualize:
                e = interpreter.interpret()

            mock_explain.assert_called_once()
            mock_visualize.assert_called_once()
            # Bug fix: assertTrue(a, b) treats b as a failure MESSAGE;
            # an equality check was intended here.
            self.assertEqual(e, interpreter.explanation)

        with mock.patch.object(
            interpreter.explainer, 'explain'
        ) as mock_explain:
            with mock.patch.object(
                interpreter, '_save_explanation'
            ) as mock_save:
                e = interpreter.interpret(path='')

            mock_explain.assert_called_once()
            mock_save.assert_called_once()
            self.assertEqual(e, interpreter.explanation)

    def testVisualize(self):
        """
        TODO(phineasng): think if it's possible or make sense to test this
        """
        pass

    def testSave(self):
        """Saving pickles DataFrames directly, everything else via pickle."""
        builder = choice(
            [self._build_posthoc_interpreter, self._build_antehoc_interpreter]
        )
        interpreter = builder()

        # test DataFrame
        df = DataFrame()
        with mock.patch.object(df, 'to_pickle') as mock_to_pickle:
            interpreter._save_explanation(df, path='')
            mock_to_pickle.assert_called_with('')

        # any other object goes through open() + pickle.dump()
        exp = object()
        module_name = 'depiction.interpreters.aix360.rule_based_model'
        with mock.patch('{}.open'.format(module_name)) as mock_open:
            with mock.patch('{}.pickle.dump'.format(module_name)) as mock_dump:
                interpreter._save_explanation(exp, path='')
                mock_open.assert_called_once()
                mock_open.assert_called_with('', 'wb')
                mock_dump.assert_called_once()

    def testPredict(self):
        """predict() delegates to the wrapped explainer."""
        builder = choice(
            [self._build_posthoc_interpreter, self._build_antehoc_interpreter]
        )
        interpreter = builder()

        with mock.patch.object(
            interpreter.explainer, 'predict'
        ) as mock_predict:
            interpreter.predict(0)
            mock_predict.assert_called_once()
            mock_predict.assert_called_with(0)


if __name__ == "__main__":
    unittest.main()
"""Contrastive Explainability Method (without monotonic attribute functions)


References:
    https://arxiv.org/abs/1802.07623
"""
from alibi.explainers import CEM as CEMImplementation

from ....core import DataType, Task
from ...base.base_interpreter import BaseInterpreter


class CEM(BaseInterpreter):
    """Contrastive Explainability Method.

    Wrapper for alibi's implementation of CEM, which solves the optimization
    problem for finding pertinent positives and negatives using tensorflow.
    """

    SUPPORTED_TASK = {Task.CLASSIFICATION}
    SUPPORTED_DATATYPE = {DataType.TABULAR, DataType.IMAGE}

    def __init__(
        self,
        model,
        mode,
        shape,
        kappa=0.,
        beta=.1,
        feature_range=(-1e10, 1e10),
        gamma=0.,
        ae_model=None,
        learning_rate_init=1e-2,
        max_iterations=1000,
        c_init=10.,
        c_steps=10,
        eps=(1e-3, 1e-3),
        clip=(-100., 100.),
        update_num_grad=1,
        no_info_val=None,
        write_dir=None,
        sess=None
    ):
        """Constructor.

        References:
            CEM implementation, parameter docstrings and defaults adapted from:
            https://github.com/SeldonIO/alibi/blob/92e8048ea2f4e4ef57b6874fa854b90de8ed9602/alibi/explainers/cem.py#L16 # noqa

            The major difference in the constructor signature is that model and
            ae_model should be instances of depiction's BaseModel
            instead of type Union[Callable, tf.keras.Model, 'keras.Model']

        Args:
            model (BaseModel): Instance implementing predict method returning
                class probabilities that is passed to the explainer
                implementation.
            mode (str): Find pertinent negatives ('PN') or
                pertinent positives ('PP').
            shape (tuple): Shape of input data starting with batch size of 1.
            kappa (float, optional): Confidence parameter for the attack loss
                term. Defaults to 0..
            beta (float, optional): Regularization constant for L1 loss term.
                Defaults to .1.
            feature_range (tuple, optional): Tuple with min and max ranges to
                allow for perturbed instances. Min and max ranges can be floats
                or numpy arrays with dimension (1x nb of features) for
                feature-wise ranges. Defaults to (-1e10, 1e10).
            gamma (float, optional): Regularization constant for optional
                auto-encoder loss term. Defaults to 0..
            ae_model (tf.keras.Model, 'keras.Model', optional): Auto-encoder
                model used for loss regularization. Only keras is supported.
                Defaults to None.
            learning_rate_init (float, optional): Initial learning rate of
                optimizer. Defaults to 1e-2.
            max_iterations (int, optional): Maximum number of iterations for
                finding a PN or PP. Defaults to 1000.
            c_init (float, optional): Initial value to scale the attack loss
                term. Defaults to 10..
            c_steps (int, optional): Number of iterations to adjust the
                constant scaling the attack loss term. Defaults to 10.
            eps (tuple, optional): If numerical gradients are used to compute
                `dL/dx = (dL/dp) * (dp/dx)`, then eps[0] is used to calculate
                `dL/dp` and eps[1] is used for `dp/dx`. eps[0] and eps[1] can
                be a combination of float values and numpy arrays. For eps[0],
                the array dimension should be (1x nb of prediction categories)
                and for eps[1] it should be (1x nb of features).
                Defaults to (1e-3, 1e-3).
            clip (tuple, optional): Tuple with min and max clip ranges for both
                the numerical gradients and the gradients obtained from the
                TensorFlow graph. Defaults to (-100., 100.).
            update_num_grad (int, optional): If numerical gradients are used,
                they will be updated every update_num_grad iterations.
                Defaults to 1.
            no_info_val (Union[float, np.ndarray], optional): Global or
                feature-wise value considered as containing no information.
                Defaults to None, in this case fit method needs to be called.
            write_dir (str, optional): Directory to write tensorboard files to.
                Defaults to None.
            sess (tf.compat.v1.Session, optional): Optional Tensorflow session
                that will be used if passed instead of creating or inferring
                one internally. Defaults to None.
        """
        super().__init__(model)

        self.explainer = CEMImplementation(
            model.predict, mode, shape, kappa, beta, feature_range, gamma,
            ae_model,
            learning_rate_init, max_iterations, c_init, c_steps, eps, clip,
            update_num_grad, no_info_val, write_dir, sess
        )

    def interpret(self, X, Y=None, verbose=False):
        """Explain instance and return PP or PN with metadata.

        Args:
            X (np.ndarray): Instances to attack.
            Y (np.ndarray, optional): Labels for X.
            verbose (bool, optional): Print intermediate results of
                optimization.

        Returns:
            dict: the PP or PN with additional metadata
        """
        return self.explainer.explain(X, Y, verbose)

    def fit(self, train_data, no_info_type='median'):
        """Get 'no information' values from the training data.

        Args:
            train_data (np.ndarray): Representative sample from the training
                data.
            no_info_type (str, optional): 'median' or 'mean' value by feature
                supported. Defaults to 'median'.
        """
        # Bug fix: the arguments were previously dropped entirely
        # (self.explainer.fit() was called with no arguments), so the
        # explainer never saw the training data or the requested statistic.
        self.explainer.fit(train_data, no_info_type=no_info_type)
24 | 25 | Reference: 26 | https://github.com/SeldonIO/alibi/blob/92e8048ea2f4e4ef57b6874fa854b90de8ed9602/alibi/explainers/tests/test_cem.py#L9 # noqa 27 | """ 28 | 29 | def setUp(self): 30 | dataset = load_iris() 31 | 32 | # scale dataset 33 | dataset.data = (dataset.data - 34 | dataset.data.mean(axis=0)) / dataset.data.std(axis=0) 35 | 36 | # define train and test set 37 | self.X, self.Y = dataset.data, dataset.target 38 | 39 | # fit random forest to training data 40 | np.random.seed(0) 41 | clf = LogisticRegression(solver='liblinear') 42 | clf.fit(self.X, self.Y) 43 | 44 | # define Model 45 | self.depiction_model = SKLearnModel(clf) 46 | 47 | def testInterpretation(self): 48 | """Matching test to source implementation.""" 49 | 50 | # instance to be explained 51 | idx = 0 52 | X_expl = np.expand_dims(self.X[idx], axis=0) 53 | 54 | # test explainer initialization 55 | shape = (1, 4) # seems first entry (batch_size) must be 1 56 | feature_range = ( 57 | self.X.min(axis=0).reshape(shape) - .1, 58 | self.X.max(axis=0).reshape(shape) + .1 59 | ) 60 | 61 | def test_mode(mode): 62 | interpreter = CEM( 63 | self.depiction_model, 64 | mode, 65 | shape, 66 | feature_range=feature_range, 67 | max_iterations=10, 68 | no_info_val=-1. 
"""Counterfactual explanation method based on Wachter et al. (2017)


References:
    https://arxiv.org/ftp/arxiv/papers/1711/1711.00399.pdf
"""
from alibi.explainers import CounterFactual

from ....core import DataType, Task
from ...base.base_interpreter import BaseInterpreter


class Counterfactual(BaseInterpreter):
    """Counterfactual explanation.

    Wrapper for alibi's implementation of counterfactual explanation.
    """

    SUPPORTED_TASK = {Task.CLASSIFICATION}
    SUPPORTED_DATATYPE = {DataType.TABULAR, DataType.IMAGE}

    def __init__(
        self,
        model,
        shape,
        distance_fn='l1',
        target_proba=1.0,
        target_class='other',
        max_iter=1000,
        early_stop=50,
        lam_init=1e-1,
        max_lam_steps=10,
        tol=0.05,
        learning_rate_init=0.1,
        feature_range=(-1e10, 1e10),
        eps=0.01,  # feature-wise epsilons
        init='identity',
        decay=True,
        write_dir=None,
        debug=False,
        sess=None
    ):
        """Constructor.

        References:
            Counterfactual explanation implementation, parameter docstrings and defaults adapted from:
            https://github.com/SeldonIO/alibi/blob/14804f07457da881a5f70ccff2dcbfed2378b860/alibi/explainers/counterfactual.py#L85 # noqa

            The major difference in the constructor signature is that model
            should be an instance of depiction's BaseModel
            instead of type Union[Callable, tf.keras.Model, 'keras.Model']

        Args:
            model (BaseModel): Instance implementing predict method returning
                class probabilities that is passed to the explainer
                implementation.
            shape (tuple): Shape of input data starting with batch size.
            distance_fn (str, optional): Distance function to use in the loss term. Defaults to 'l1'.
            target_proba (float, optional): Target probability for the counterfactual to reach. Defaults to 1.0.
            target_class (Union[str, int], optional): Target class for the counterfactual to reach, one of 'other',
                'same' or an integer denoting desired class membership for the counterfactual instance. Defaults to 'other'.
            max_iter (int, optional): Maximum number of iterations to run the gradient descent for (inner loop).
                Defaults to 1000.
            early_stop (int, optional): Number of steps after which to terminate gradient descent if all or none of found
                instances are solutions. Defaults to 50.
            lam_init (float, optional): Initial regularization constant for the prediction part of the Wachter loss.
                Defaults to 1e-1.
            max_lam_steps (int, optional): Maximum number of times to adjust the regularization constant (outer loop)
                before terminating the search. Defaults to 10.
            tol (float, optional): Tolerance for the counterfactual target probability. Defaults to 0.05.
            learning_rate_init (float, optional): Initial learning rate for each outer loop of lambda. Defaults to 0.1.
            feature_range (Union[Tuple, str], optional): Tuple with min and max ranges to allow for perturbed instances.
                Min and max ranges can be floats or numpy arrays with dimension (1 x nb of features)
                for feature-wise ranges. Defaults to (-1e10, 1e10).
            eps (Union[float, np.ndarray], optional): Gradient step sizes used in calculating numerical gradients,
                defaults to a single value for all features, but can be passed an array for
                feature-wise step sizes. Defaults to 0.01.
            init (str): Initialization method for the search of counterfactuals, currently must be 'identity'.
            decay (bool, optional): Flag to decay learning rate to zero for each outer loop over lambda.
                Defaults to True.
            write_dir (str, optional): Directory to write Tensorboard files to. Defaults to None.
            debug (bool, optional): Flag to write Tensorboard summaries for debugging. Defaults to False.
            sess (tf.compat.v1.Session, optional): Optional Tensorflow session that will be used if passed
                instead of creating or inferring one internally. Defaults to None.
        """
        super().__init__(model)

        self.explainer = CounterFactual(
            model.predict, shape, distance_fn, target_proba, target_class,
            max_iter, early_stop, lam_init, max_lam_steps, tol,
            learning_rate_init, feature_range, eps, init, decay, write_dir,
            debug, sess
        )

    def interpret(self, X):
        """
        Explain an instance and return the counterfactual with metadata.

        Args:
            X (np.ndarray): Instance to be explained.

        Returns:
            dict: a dictionary containing the counterfactual
            with additional metadata.
        """
        return self.explainer.explain(X)

    def fit(self, X=None, y=None):
        """
        No-op training hook kept for interface compatibility.

        The search is unsupervised, so the forwarded alibi fit call does not
        learn anything from the data; it exists so all interpreters share the
        same fit/interpret protocol.

        Args:
            X (np.ndarray): training data. Defaults to None.
            y (np.ndarray): optional labels. Defaults to None.
        """
        self.explainer.fit(X, y)
class SKLearnModel(BaseModel):
    """Thin depiction wrapper around a fitted sklearn classifier."""

    def __init__(self, clf):
        super().__init__(Task.CLASSIFICATION, DataType.TABULAR)
        self.clf = clf

    def predict(self, X):
        """Return class probabilities for ``X``."""
        return self.clf.predict_proba(X)


class CounterfactualTestCase(unittest.TestCase):
    """Matching test for implementation source.

    Bug fix: this class was named ``CEMTestCase`` and its docstring pointed
    at the CEM test file — a copy-paste leftover from cem_test.py.

    Reference:
        https://github.com/SeldonIO/alibi/blob/master/alibi/explainers/tests/test_counterfactual.py # noqa
    """

    def setUp(self):
        dataset = load_iris()

        # scale dataset
        dataset.data = (dataset.data -
                        dataset.data.mean(axis=0)) / dataset.data.std(axis=0)

        # define train and test set
        self.X, self.Y = dataset.data, dataset.target

        # fit a logistic regression to the data (seeded for determinism)
        np.random.seed(0)
        clf = LogisticRegression(solver='liblinear')
        clf.fit(self.X, self.Y)

        # define Model
        self.depiction_model = SKLearnModel(clf)

    def testInterpretation(self):
        """Matching test to source implementation."""

        # instance to be explained
        idx = 0
        X_to_interpret = np.expand_dims(self.X[idx], axis=0)

        # test explainer initialization
        shape = (1, 4)  # seems first entry (batch_size) must be 1

        interpreter = Counterfactual(self.depiction_model, shape)
        explanation = interpreter.interpret(X_to_interpret)

        counterfactual = interpreter.explainer
        self.assertEqual(
            counterfactual.return_dict['meta']['name'],
            counterfactual.__class__.__name__
        )
        self.assertEqual(explanation['cf']['X'].shape, X_to_interpret.shape)
        self.assertEqual(len(explanation['all']), counterfactual.max_lam_steps)


if __name__ == "__main__":
    unittest.main()
"""Backpropagation-like methods for interpretability

Wrapper around:
    - (pytorch) Captum [1]
    - (keras) DeepExplain [2]

References:
    [1] https://captum.ai/
    [2] https://arxiv.org/abs/1711.06104
"""
from captum.attr import IntegratedGradients, Saliency, DeepLift,\
    DeepLiftShap, GradientShap, InputXGradient
from tensorflow.keras import backend as K
from deepexplain.tensorflow import DeepExplain
from tensorflow.keras.models import Model
from deepexplain.tensorflow.methods import attribution_methods
from copy import deepcopy
import warnings
from captum.attr import visualization as viz
import numpy as np
from matplotlib.colors import Normalize

from ...core import DataType, Task
from ..base.base_interpreter import BaseInterpreter
from depiction.models.torch.core import TorchModel
from depiction.models.keras.core import KerasModel


def _preprocess_att_methods_keras():
    """Return a copy of DeepExplain's attribution methods without 'deeplift'."""
    methods = deepcopy(attribution_methods)
    methods.pop('deeplift')
    return methods


class BackPropeter(BaseInterpreter):
    """Backpropagation-like Explainability Method

    Wrapper for Captum (torch models) and DeepExplain (keras models)
    implementations.
    """
    SUPPORTED_TASK = {Task.CLASSIFICATION}
    SUPPORTED_DATATYPE = {DataType.TABULAR, DataType.IMAGE, DataType.TEXT}

    # Attribution methods available per model backend.
    METHODS = {
        'torch': {
            'integrated_grads': IntegratedGradients,
            'saliency': Saliency,
            'deeplift': DeepLift,
            'deeplift_shap': DeepLiftShap,
            'gradient_shap': GradientShap,
            'inputxgrad': InputXGradient
        },
        'keras': _preprocess_att_methods_keras()
    }

    @classmethod
    def _check_supported_method(cls, model_type, method):
        """Raise ValueError if ``method`` is unknown for ``model_type``.

        Fix: the first parameter of this classmethod was named ``self``;
        renamed to the conventional ``cls``.
        """
        if method not in cls.METHODS[model_type]:
            raise ValueError('Method {} not supported! At the moment we only support: {}.'.format(
                method, cls.METHODS[model_type].keys()))

    def __init__(self, model, method, **method_kwargs):
        """
        Constructor for backpropagation-like methods.

        Reference:
            https://captum.ai/api/attribution.html

        Args:
            model (TorchModel or KerasModel): model to explain
            method (str): method to use
            method_kwargs: keyword args to pass on to the explainer constructor.
                Please refer to the specific algorithm (following the above link)
                to see and understand the available arguments.

        Raises:
            ValueError: if the model type or the method is not supported.
        """
        super(BackPropeter, self).__init__(model)

        self._model = model
        self._method = method

        if isinstance(self._model, TorchModel):
            self._check_supported_method('torch', method)
            self._explainer = self.METHODS['torch'][method](self._model._model, **method_kwargs)
        elif isinstance(self._model, KerasModel):
            # DeepExplain builds its explainer lazily inside interpret().
            self._check_supported_method('keras', method)
        else:
            raise ValueError('Model not supported! At the moment we only support {}.'
                             '\nPlease check again in the future!'.format(self.METHODS.keys()))

    def interpret(self, samples, target_layer=-1, show_in_notebook=False,
                  explanation_configs=None,
                  vis_configs=None):
        """Compute attributions for the given samples.

        If pyTorch (captum) is used, the convergence delta is NOT returned
        by default.

        Args:
            samples (tensor or tuple of tensors): samples to explain.
            target_layer (int): for KerasModel, specify the target layer.
                Currently unused in both branches; kept for interface
                stability.
                Following example in: https://github.com/marcoancona/DeepExplain/blob/master/examples/mint_cnn_keras.ipynb
            show_in_notebook (bool): visualize the attributions (torch only).
            explanation_configs (dict, optional): optional arguments to pass
                to the explainer for attribution. Defaults to None (empty).
            vis_configs (dict, optional): optional arguments for the
                visualization. Defaults to None (empty).

        Returns:
            tensor (or tuple of tensors) containing attributions
        """
        # Bug fix: these were mutable default arguments ({}) that the code
        # below mutates, leaking configuration across calls. Copying also
        # avoids mutating dicts owned by the caller.
        explanation_configs = dict(explanation_configs) if explanation_configs else {}
        vis_configs = dict(vis_configs) if vis_configs else {}

        if isinstance(self._model, TorchModel):
            if self._explainer.has_convergence_delta() and 'return_convergence_delta' not in explanation_configs:
                explanation_configs['return_convergence_delta'] = False
            explanation = self._explainer.attribute(inputs=self._model._prepare_sample(samples), **explanation_configs)
            if show_in_notebook:
                # with a convergence delta, attribute() returns a tuple
                if explanation_configs.get('return_convergence_delta'):
                    exp = explanation[0]
                else:
                    exp = explanation
                # CHW -> HWC for image visualization of the first sample
                exp = np.transpose(exp.detach().numpy()[0], (1, 2, 0))
                normalizer = Normalize()
                if 'method' not in vis_configs:
                    vis_configs['method'] = 'masked_image'
                viz.visualize_image_attr(exp, normalizer(samples[0]), **vis_configs)

            return explanation
        else:
            with DeepExplain(session=K.get_session()) as de:
                input_tensor = self._model._model.inputs
                smpls = samples if isinstance(samples, list) else [samples]
                if self._method in {'occlusion', 'shapley_sampling'}:
                    warnings.warn('For perturbation methods, multiple inputs (modalities) are not supported.', UserWarning)
                    smpls = smpls[0]
                    input_tensor = input_tensor[0]

                model = Model(inputs=input_tensor, outputs=self._model._model.outputs)
                target_tensor = model(input_tensor)

                if show_in_notebook:
                    warnings.warn('Sorry! Visualization not implemented yet!', UserWarning)

                return de.explain(self._method, T=target_tensor, X=input_tensor, xs=smpls, **explanation_configs)
class BackPropeterTestCase(unittest.TestCase):
    """
    Tests for back-propagation like attribution methods.

    Covers method-name validation, construction for torch and keras models,
    and the interpret entry point for every registered attribution method.
    """

    def setUp(self):
        """Build one torch-backed and one keras-backed model fixture."""
        self._available_model_types = list(BackPropeter.METHODS.keys())
        # the data type is irrelevant for these tests, so pick one at random
        self._test_data_type = np.random.choice(list(DataType))
        self._test_task_type = Task.CLASSIFICATION

        torch_method_name = np.random.choice(
            list(BackPropeter.METHODS['torch'].keys())
        )
        self._torch = {
            'model': TorchModel(
                DummyTorchModel(), self._test_task_type, self._test_data_type
            ),
            'method_name': torch_method_name,
            'method_class': BackPropeter.METHODS['torch'][torch_method_name],
            'method_classname':
                BackPropeter.METHODS['torch'][torch_method_name].__name__
        }

        self._keras = {
            'model': KerasModel(
                self._build_keras_model(), self._test_task_type,
                self._test_data_type
            ),
            'method_name':
                np.random.choice(list(BackPropeter.METHODS['keras'].keys()))
        }

    def testMethodCheck(self):
        """An unknown algorithm name must be rejected for any model type."""
        model_type = np.random.choice(self._available_model_types)
        with self.assertRaises(ValueError):
            BackPropeter._check_supported_method(model_type, 'dummy_test_algo')

    def _build_keras_model(self):
        """
        Small dense classifier, from
        https://machinelearningmastery.com/multi-class-classification-tutorial-keras-deep-learning-library/
        """
        # create model
        model = Sequential()
        model.add(Dense(8, input_dim=INPUT_SZ, activation='relu'))
        model.add(Dense(OUTPUT_SZ, activation='softmax'))
        # Compile model
        model.compile(
            loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy']
        )
        return model

    def testConstructor(self):
        # test invalid model: models not wrapped in the torch/keras wrappers
        # cannot be interpreted
        with self.assertRaises(ValueError):
            BackPropeter(
                DummyModel(self._test_task_type, self._test_data_type),
                'doesnt matter'
            )

        # test torch model: the explainer class must be instantiated with the
        # unwrapped torch module
        with patch.object(
            BackPropeter, '_check_supported_method'
        ) as mock_check:
            class_constructor = (
                'depiction.interpreters.backprop.backpropeter.{}.__init__'.
                format(self._torch['method_classname'])
            )
            with patch(class_constructor, return_value=None) as mock_init:
                interpreter = BackPropeter(
                    self._torch['model'], self._torch['method_name']
                )
                self.assertIs(
                    type(interpreter._explainer), self._torch['method_class']
                )
                mock_check.assert_called_once_with(
                    'torch', self._torch['method_name']
                )
                mock_init.assert_called_once_with(self._torch['model']._model)

        # test keras model
        with patch.object(
            BackPropeter, '_check_supported_method'
        ) as mock_check:
            interpreter = BackPropeter(
                self._keras['model'], self._keras['method_name']
            )
            mock_check.assert_called_once_with(
                'keras', self._keras['method_name']
            )

    def testInterpret(self):
        batch_size = np.random.choice(10) + 1

        # test torch model: exercise every registered torch method
        for m in BackPropeter.METHODS['torch'].keys():
            # BUGFIX: previously the interpreter was always built with the
            # method sampled in setUp (self._torch['method_name']), so the
            # loop variable `m` was never used and only one method was
            # actually tested.
            interpreter = BackPropeter(self._torch['model'], m)
            args = inspect.signature(
                interpreter._explainer.attribute
            ).parameters.keys()

            x = np.random.rand(batch_size, INPUT_SZ)
            output = torch.tensor(self._torch['model'].predict(x))
            output = F.softmax(output, dim=1)
            prediction_score, pred_label_idx = torch.topk(output, 1, dim=1)
            pred_label_idx = pred_label_idx.squeeze()

            # -- -- without delta
            interpret_kwargs = {
                'baselines': torch.zeros(batch_size, INPUT_SZ),
                'target': pred_label_idx,
            }
            # only forward the kwargs this attribution method accepts
            allowed_kwargs = interpret_kwargs.keys() & set(args)
            res = interpreter.interpret(
                x,
                explanation_configs={
                    arg: interpret_kwargs[arg]
                    for arg in allowed_kwargs
                }
            )
            self.assertIsInstance(res, torch.Tensor)

            # -- -- with delta
            if interpreter._explainer.has_convergence_delta():
                interpret_kwargs['return_convergence_delta'] = True
                allowed_kwargs = interpret_kwargs.keys() & set(args)
                res = interpreter.interpret(
                    x,
                    explanation_configs={
                        arg: interpret_kwargs[arg]
                        for arg in allowed_kwargs
                    }
                )
                self.assertIsInstance(res, tuple)
                self.assertIsInstance(res[0], torch.Tensor)
                self.assertIsInstance(res[1], torch.Tensor)

        # test keras model: both list and array inputs must be accepted
        for m in BackPropeter.METHODS['keras'].keys():
            interpreter = BackPropeter(self._keras['model'], m)
            x = np.random.rand(batch_size, INPUT_SZ)
            res1 = interpreter.interpret([x])
            res2 = interpreter.interpret(x)
class BaseInterpreter(ABC):
    """Base class for all interpreters.

    Subclasses declare the tasks and data types they can handle via the
    SUPPORTED_TASK and SUPPORTED_DATATYPE class attributes.
    """
    SUPPORTED_TASK = {}
    SUPPORTED_DATATYPE = {}

    def __init__(self, model):
        """
        Constructor checking validity of the model.

        Args:
            model (BaseModel): model to be interpreted.

        Raises:
            TypeError: if the model is not wrapped in a BaseModel.
            ValueError: if the interpreter supports neither the model's task
                nor its data type.
        """
        if not isinstance(model, BaseModel):
            raise TypeError(
                'For safe use of this library, please wrap this model into a BaseModel!'
            )

        if not Task.check_support(model.task, self.SUPPORTED_TASK):
            raise ValueError(
                'Interpreter does not support the task of the provided model!'
            )

        if model.data_type not in self.SUPPORTED_DATATYPE:
            # BUGFIX: this branch previously raised the same message as the
            # task check above (copy-paste), misreporting the actual problem.
            # Callers/tests only assert on the exception type (ValueError).
            raise ValueError(
                'Interpreter does not support the data type of the provided model!'
            )

    @abstractmethod
    def interpret(self, *args, **kwarg):
        """
        Interface to interpret a model.
        """
        raise NotImplementedError


class AnteHocInterpreter(BaseInterpreter, TrainableModel):
    """Interpreter that is itself a trainable model.

    It can either be trained directly (ante-hoc) or fitted to mimic an
    existing model (post-hoc).
    """

    class UsageMode(Enum):
        """Enum indicating use modality since antehoc method could be used in a posthoc fashion."""
        ANTE_HOC = 1
        POST_HOC = 2

    def __init__(self, usage_mode, model=None, task_type=None, data_type=None):
        """
        Constructor. Checks consistency among arguments.

        Args:
            usage_mode (UsageMode): ANTE_HOC or POST_HOC.
            model (BaseModel): model to interpret; required in post-hoc mode.
            task_type (Task): task type; required in ante-hoc mode.
            data_type (DataType): data type; required in ante-hoc mode.

        Raises:
            ValueError: if the arguments required by the chosen mode are
                missing.
        """
        self.usage_mode = usage_mode
        if self.usage_mode == self.UsageMode.ANTE_HOC:
            if task_type is None or data_type is None:
                raise ValueError(
                    "If using this model in ante-hoc mode, please provide task and data types!"
                )
            TrainableModel.__init__(self, task_type, data_type)
        else:
            if model is None:
                raise ValueError(
                    "Please provide a model to post-hoc interpret!"
                )
            else:
                # validates task/data-type compatibility of the wrapped model
                BaseInterpreter.__init__(self, model)

            self._to_interpret = model
            # inherit the wrapped model's task and data type
            TrainableModel.__init__(self, model.task, model.data_type)

    def fit(self, *args, **kwargs):
        """Training routine. Implements the antehoc vs posthoc logic."""
        if self.usage_mode == self.UsageMode.ANTE_HOC:
            self._fit_antehoc(*args, **kwargs)
        else:
            self._fit_posthoc(*args, **kwargs)

    @abstractmethod
    def _fit_antehoc(self, *args, **kwargs):
        """Train the interpreter directly on data."""
        raise NotImplementedError

    @abstractmethod
    def _fit_posthoc(self, *args, **kwargs):
        """Train the interpreter to mimic the model under interpretation."""
        raise NotImplementedError
class AnteHocInterpreterTestCase(unittest.TestCase):
    """Tests for AnteHocInterpreter in both usage modes."""

    def testConstructor(self):
        """Constructor must validate the mode-specific required arguments."""
        # - antehoc mode
        for task_type in Task:
            for data_type in DataType:
                # -- expected inputs
                interpreter = ConcreteAnteHocInterpreter(
                    AnteHocInterpreter.UsageMode.ANTE_HOC,
                    task_type=task_type,
                    data_type=data_type
                )

        # -- missing task or data
        with self.assertRaises(ValueError):
            interpreter = ConcreteAnteHocInterpreter(
                AnteHocInterpreter.UsageMode.ANTE_HOC
            )

        # - posthoc mode
        # -- calling base interpreter constructor
        def dummy_init(self, model):
            return None

        with mock.patch(
            'depiction.interpreters.base.base_interpreter.BaseModel.__init__',
            side_effect=dummy_init
        ) as mock_par_constructor:
            try:
                interpreter = ConcreteAnteHocInterpreter(
                    AnteHocInterpreter.UsageMode.POST_HOC,
                    model=DummyModel(
                        choice(list(Task)), choice(list(DataType))
                    )
                )
            # BUGFIX: previously a bare `except:` which also swallows
            # SystemExit and KeyboardInterrupt. Only ordinary exceptions
            # (construction fails downstream of the mocked constructor)
            # should be ignored here; the assertion below is what matters.
            except Exception:
                pass

            mock_par_constructor.assert_called_once()

        # -- expected inputs
        for task_type in Task:
            for data_type in DataType:
                model = DummyModel(task_type, data_type)
                interpreter = ConcreteAnteHocInterpreter(
                    AnteHocInterpreter.UsageMode.POST_HOC, model=model
                )
                self.assertTrue(hasattr(interpreter, '_to_interpret'))
                self.assertEqual(interpreter._to_interpret.task, task_type)
                self.assertEqual(
                    interpreter._to_interpret.data_type, data_type
                )

        # -- missing model
        with self.assertRaises(ValueError):
            interpreter = ConcreteAnteHocInterpreter(
                AnteHocInterpreter.UsageMode.POST_HOC
            )

    def testFit(self):
        """fit() must dispatch to the trainer matching the usage mode."""
        # antehoc mode
        interpreter = ConcreteAnteHocInterpreter(
            AnteHocInterpreter.UsageMode.ANTE_HOC,
            task_type=choice(list(Task)),
            data_type=choice(list(DataType))
        )

        with mock.patch.object(interpreter, '_fit_antehoc') as mock_fit:
            interpreter.fit(0, 0)
            mock_fit.assert_called_with(0, 0)

        # posthoc mode
        model = DummyModel(choice(list(Task)), choice(list(DataType)))
        interpreter = ConcreteAnteHocInterpreter(
            AnteHocInterpreter.UsageMode.POST_HOC, model=model
        )

        with mock.patch.object(interpreter, '_fit_posthoc') as mock_fit:
            interpreter.fit(0, 0)
            mock_fit.assert_called_with(0, 0)
class UWasherTestCase(unittest.TestCase):
    """Tests for the UWasher LIME/Anchor wrapper."""

    def testConstructor(self):
        """Each (explainer key, data type) pair builds the right backend."""

        def noop_init(*args, **kwargs):
            return None

        def noop_fit(*args, **kwargs):
            return None

        def check_backend(explainer_key, explainer_cls, model, **kwargs):
            # patch the backend constructor so no real explainer is built
            target = (
                'depiction.interpreters.u_wash.u_washer.{}.__init__'.
                format(explainer_cls)
            )
            with mock.patch(
                target, side_effect=noop_init
            ) as mock_constructor:
                interpreter = UWasher(explainer_key, model, **kwargs)
                self.assertTrue(interpreter.model is model)
                mock_constructor.assert_called_once()

        task_type = Task.CLASSIFICATION

        text_model = DummyModel(task_type, DataType.TEXT)
        check_backend('lime', 'LimeTextExplainer', text_model)
        check_backend('anchors', 'AnchorText', text_model)

        tabular_model = DummyModel(task_type, DataType.TABULAR)
        check_backend('lime', 'LimeTabularExplainer', tabular_model)
        with mock.patch(
            'depiction.interpreters.u_wash.u_washer.AnchorTabularExplainer.fit',
            side_effect=noop_fit
        ) as mock_fit:
            # dummy data that will not be used (because of the mock) but is
            # required by the constructor
            check_backend(
                'anchors',
                'AnchorTabularExplainer',
                tabular_model,
                train_data=[],
                train_labels=[],
                validation_data=[],
                validation_labels=[]
            )
            mock_fit.assert_called_once()

        image_model = DummyModel(task_type, DataType.IMAGE)
        check_backend('lime', 'LimeImageExplainer', image_model)
        check_backend('anchors', 'AnchorImage', image_model)

    def testInterpret(self):
        """interpret() shows the explanation or saves it, based on `path`."""
        model = DummyModel(Task.CLASSIFICATION, DataType.TABULAR)

        class DummyExplanation:
            SHOW_IN_NOTEBOOK = False
            PATH = ''

            def show_in_notebook(self):
                self.SHOW_IN_NOTEBOOK = True

            def save_to_file(self, path):
                self.PATH = path

        def fake_explain(*args, **kwargs):
            return DummyExplanation()

        interpreter = UWasher(
            'lime', model, training_data=np.array([[0, 0], [1, 1]])
        )
        explanation_configs = {'dummy_config': 10}
        callback_kwargs = {}
        sample = [10, 15]

        # without a path the explanation is rendered in the notebook
        with mock.patch.object(
            interpreter.explainer,
            'explain_instance',
            side_effect=fake_explain
        ) as mock_explain:
            explanation = interpreter.interpret(
                sample, callback_kwargs, explanation_configs, path=None
            )
            mock_explain.assert_called_once()
            self.assertEqual(explanation.SHOW_IN_NOTEBOOK, True)
            self.assertEqual(explanation.PATH, '')

        # with a path the explanation is written to file instead
        with mock.patch.object(
            interpreter.explainer,
            'explain_instance',
            side_effect=fake_explain
        ) as mock_explain:
            out_path = 'tests'
            explanation = interpreter.interpret(
                sample,
                callback_kwargs,
                explanation_configs,
                path=out_path
            )
            mock_explain.assert_called_once()
            self.assertEqual(explanation.SHOW_IN_NOTEBOOK, False)
            self.assertEqual(explanation.PATH, out_path)
"""Abstract interface for models."""
from abc import ABC, abstractmethod

from ...core import Task, DataType


class BaseModel(ABC):
    """Abstract implementation of a model."""

    def __init__(self, task, data_type):
        """
        Initialize a Model.

        Args:
            task (depiction.core.Task): task type.
            data_type (depiction.core.DataType): data type.

        Raises:
            TypeError: if task or data_type are not the proper enum types.
        """
        is_valid = isinstance(task, Task) and isinstance(data_type, DataType)
        if not is_valid:
            raise TypeError("Inputs must be valid Task and DataType types!")

        self.task = task
        self.data_type = data_type

    def callback(self, *args, **kwargs):
        """
        Build a single-argument prediction function with the given extra
        arguments bound into the call.

        Args:
            args (list): positional arguments forwarded to predict.
            kwargs (dict): key-value arguments forwarded to predict.

        Returns:
            a function taking a sample an input and returning the prediction.
        """
        def _predict(sample):
            return self.predict(sample, *args, **kwargs)

        return _predict

    @abstractmethod
    def predict(self, sample, *args, **kwargs):
        """
        Run the model for inference on a given sample and with the provided
        parameters.

        Args:
            sample (object): an input sample for the model.
            args (list): list of arguments.
            kwargs (dict): list of key-value arguments.

        Returns:
            a prediction for the model on the given sample.
        """
        raise NotImplementedError

    def predict_many(self, samples, *args, **kwargs):
        """
        Lazily run the model for inference on the given samples and with the
        provided parameters.

        Args:
            samples (Iterable): input samples for the model.
            args (list): list of arguments.
            kwargs (dict): list of key-value arguments.

        Returns:
            a generator of predictions.
        """
        for single_sample in samples:
            yield self.predict(single_sample, *args, **kwargs)


class TrainableModel(BaseModel):
    """Interface for trainable models."""

    @abstractmethod
    def fit(self, *args, **kwargs):
        """Training routine, implemented by concrete subclasses."""
        raise NotImplementedError
35 | """ 36 | y = self.model.predict(sample, *args, **kwargs) 37 | return (np.argmax(y, axis=1) == self.label_index).astype(np.int) 38 | -------------------------------------------------------------------------------- /depiction/models/base/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/base/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/base/tests/base_model_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from random import choice 3 | from unittest import mock 4 | 5 | from depiction.core import DataType, Task 6 | from depiction.models.base.base_model import BaseModel 7 | 8 | 9 | class ConcreteTestModel(BaseModel): 10 | 11 | def __init__(self, task_type, data_type): 12 | super(ConcreteTestModel, self).__init__(task_type, data_type) 13 | 14 | def predict(self, sample, *, test_kwarg): 15 | return sample 16 | 17 | 18 | class BaseModelTestCase(unittest.TestCase): 19 | 20 | def testModelConstruction(self): 21 | # expected inputs 22 | for task_type in Task: 23 | for data_type in DataType: 24 | concrete_model = ConcreteTestModel(task_type, data_type) 25 | self.assertEqual(concrete_model.task, task_type) 26 | self.assertEqual(concrete_model.data_type, data_type) 27 | 28 | # unexpected inputs 29 | with self.assertRaises(TypeError): 30 | ConcreteTestModel('asad', 5) 31 | 32 | def testCallback(self): 33 | concrete_model = ConcreteTestModel( 34 | choice(list(Task)), choice(list(DataType)) 35 | ) 36 | 37 | with mock.patch.object(concrete_model, 'predict') as mock_predict: 38 | test_kwarg = {'test_kwarg': 'test'} 39 | callback = concrete_model.callback(**test_kwarg) 40 | test_sample = 10 41 | _ = callback(test_sample) 42 | mock_predict.assert_called_with(test_sample, **test_kwarg) 43 | 44 
| def testPredictMany(self): 45 | concrete_model = ConcreteTestModel( 46 | choice(list(Task)), choice(list(DataType)) 47 | ) 48 | 49 | with mock.patch.object(concrete_model, 'predict') as mock_predict: 50 | test_kwarg = {'test_kwarg': 'test'} 51 | test_samples = range(10) 52 | for res in concrete_model.predict_many(test_samples, **test_kwarg): 53 | pass 54 | for s in test_samples: 55 | mock_predict.assert_any_call(s, **test_kwarg) 56 | 57 | 58 | if __name__ == "__main__": 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /depiction/models/base/tests/utils_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock 3 | 4 | from depiction.models.base.utils import MODELS_SUBDIR, get_model_file 5 | 6 | 7 | class BaseUtilsTestCase(unittest.TestCase): 8 | 9 | def testGetModelFile(self): 10 | fname = 'test_file.h5' 11 | origin = 'test_url' 12 | cache_dir = 'cache_path' 13 | 14 | with mock.patch( 15 | 'depiction.models.base.utils.get_file' 16 | ) as mock_get_file: 17 | get_model_file(fname, origin, cache_dir) 18 | 19 | mock_get_file.assert_called_with( 20 | fname, origin, cache_subdir=MODELS_SUBDIR, cache_dir=cache_dir 21 | ) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /depiction/models/base/utils.py: -------------------------------------------------------------------------------- 1 | """Generic util functions for models.""" 2 | from tensorflow.keras.utils import get_file 3 | 4 | MODELS_SUBDIR = 'models' 5 | 6 | 7 | def get_model_file(filename, origin, cache_dir): 8 | """ 9 | Downloads a file from a URL if it not already in the cache. 
10 | """ 11 | return get_file( 12 | filename, origin, cache_subdir=MODELS_SUBDIR, cache_dir=cache_dir 13 | ) 14 | -------------------------------------------------------------------------------- /depiction/models/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/examples/__init__.py -------------------------------------------------------------------------------- /depiction/models/examples/celltype/__init__.py: -------------------------------------------------------------------------------- 1 | from .celltype import CellTyper # noqa -------------------------------------------------------------------------------- /depiction/models/examples/celltype/celltype.py: -------------------------------------------------------------------------------- 1 | """CellTyper model.""" 2 | from tensorflow import keras 3 | from tensorflow.keras.utils import to_categorical 4 | 5 | from ...uri.cache.http_model import HTTPModel 6 | from ....core import Task, DataType 7 | 8 | 9 | def one_hot_encoding(classes): 10 | return to_categorical(classes)[:, 1:] # remove category 0 11 | 12 | 13 | def one_hot_decoding(labels): 14 | return labels.argmax(axis=1) + 1 15 | 16 | 17 | class CellTyper(HTTPModel): 18 | """Classifier of single cells.""" 19 | celltype_names = { 20 | 1: 'CD11b- Monocyte', 21 | 2: 'CD11bhi Monocyte', 22 | 3: 'CD11bmid Monocyte', 23 | 4: 'Erythroblast', 24 | 5: 'HSC', 25 | 6: 'Immature B', 26 | 7: 'Mature CD38lo B', 27 | 8: 'Mature CD38mid B', 28 | 9: 'Mature CD4+ T', 29 | 10: 'Mature CD8+ T', 30 | 11: 'Megakaryocyte', 31 | 12: 'Myelocyte', 32 | 13: 'NK', 33 | 14: 'Naive CD4+ T', 34 | 15: 'Naive CD8+ T', 35 | 16: 'Plasma cell', 36 | 17: 'Plasmacytoid DC', 37 | 18: 'Platelet', 38 | 19: 'Pre-B II', 39 | 20: 'Pre-B I' 40 | } 41 | 42 | def __init__( 43 | self, 44 | filename='celltype_model.h5', 45 | 
origin='https://ibm.box.com/shared/static/5uhttlduaund89tpti4y0ptipr2dcj0h.h5', # noqa 46 | cache_dir=None 47 | ): 48 | """Initialize the CellTyper.""" 49 | super().__init__( 50 | uri=origin, 51 | task=Task.CLASSIFICATION, 52 | data_type=DataType.TABULAR, 53 | cache_dir=cache_dir, 54 | filename=filename 55 | ) 56 | self.model = keras.models.load_model(self.model_path) 57 | 58 | def predict(self, sample, *args, **kwargs): 59 | """ 60 | Run the model for inference on a given sample and with the provided 61 | parameters. 62 | 63 | Args: 64 | sample (object): an input sample for the model. 65 | args (list): list of arguments. 66 | kwargs (dict): list of key-value arguments. 67 | 68 | Returns: 69 | a prediction for the model on the given sample. 70 | """ 71 | return self.model.predict( 72 | sample, batch_size=None, verbose=0, steps=None, callbacks=None 73 | ) 74 | 75 | @staticmethod 76 | def logits_to_celltype(predictions): 77 | return [ 78 | CellTyper.celltype_names[category] 79 | for category in one_hot_decoding(predictions) 80 | ] 81 | -------------------------------------------------------------------------------- /depiction/models/examples/celltype/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/examples/celltype/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/examples/celltype/tests/celltype_test.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import tempfile 3 | import unittest 4 | from pathlib import Path 5 | 6 | import pandas as pd 7 | 8 | from depiction.models.examples.celltype.celltype import CellTyper 9 | 10 | 11 | class CellTyperTestCase(unittest.TestCase): 12 | """Test celltype classifier.""" 13 | 14 | def setUp(self): 15 | """Prepare data to predict.""" 16 | filepath = 
Path(__file__).resolve( 17 | ).parents[5] / 'data' / 'single-cell' / 'data.csv' 18 | data_df = pd.read_csv(filepath) 19 | self.data = data_df.drop('category', axis=1).values 20 | self.tmp_dir = tempfile.mkdtemp() 21 | 22 | def test_prediction(self): 23 | typer = CellTyper(cache_dir=self.tmp_dir) 24 | predictions = typer.predict(self.data) 25 | self.assertEqual( 26 | predictions.shape, 27 | (self.data.shape[0], len(CellTyper.celltype_names)) 28 | ) 29 | 30 | CellTyper.logits_to_celltype(predictions) 31 | self.assertEqual( 32 | predictions.shape, 33 | (self.data.shape[0], len(CellTyper.celltype_names)) 34 | ) 35 | 36 | def tearDown(self): 37 | """Tear down the tests.""" 38 | shutil.rmtree(self.tmp_dir) 39 | 40 | 41 | if __name__ == "__main__": 42 | unittest.main() 43 | -------------------------------------------------------------------------------- /depiction/models/examples/deepbind/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepbind import DeepBind # noqa -------------------------------------------------------------------------------- /depiction/models/examples/deepbind/deepbind.py: -------------------------------------------------------------------------------- 1 | """Wrapper for pretrained Deepbind via Kipoi.""" 2 | import numpy as np 3 | from spacy.tokens import Doc 4 | from spacy.vocab import Vocab 5 | from spacy.language import Language 6 | from concise.preprocessing.sequence import encodeDNA, encodeRNA 7 | 8 | from ....core import Task, DataType 9 | from ...kipoi.core import KipoiModel 10 | 11 | DEEPBIND_CLASSES = ['NotBinding', 'Binding'] 12 | ALPHABET = {'TF': ['T', 'C', 'G', 'A', 'N'], 'RBP': ['U', 'C', 'G', 'A', 'N']} 13 | ONE_HOT_ENCODER = {'TF': encodeDNA, 'RBP': encodeRNA} 14 | 15 | 16 | def create_sequence_language(alphabet): 17 | """Anchor accepts a spacy language for sampling the neighborhood.""" 18 | vocab = Vocab(strings=alphabet) 19 | 20 | def make_doc(sequence): 21 | sequence = 
sequence.replace(' ', '') 22 | if len(sequence) == 0: 23 | words = np.random.choice(alphabet) 24 | else: 25 | words = list(sequence) 26 | return Doc(vocab, words=words, spaces=[False] * len(words)) 27 | 28 | return Language(vocab, make_doc) 29 | 30 | 31 | def create_DNA_language(): 32 | return create_sequence_language(alphabet=ALPHABET['TF']) 33 | 34 | 35 | def create_RNA_language(): 36 | return create_sequence_language(alphabet=ALPHABET['RBF']) 37 | 38 | 39 | def character_correction(sequences_list, min_length, null_character='N'): 40 | """ 41 | Some perturbation based interpretability methods (e.g. lime) 42 | might introduce null characters which are not viable input. 43 | These are by default replaced with 'N' (for any character). 44 | 45 | The sequence is padded to min_length characters. 46 | """ 47 | return [ 48 | s.replace('\x00', null_character).ljust(min_length, null_character) 49 | for s in sequences_list 50 | ] 51 | 52 | 53 | def preprocessing_function( 54 | nucleotide_sequence, sequence_type, min_length=35, null_character='N' 55 | ): 56 | """One-hot-encode the sequence and allow passing single string.""" 57 | 58 | if isinstance(nucleotide_sequence, str): 59 | sequences_list = [nucleotide_sequence] 60 | else: 61 | if not hasattr(nucleotide_sequence, '__iter__'): 62 | raise IOError( 63 | f'Expected a str or iterable, got {type(nucleotide_sequence)}.' 64 | ) 65 | sequences_list = nucleotide_sequence 66 | 67 | return ONE_HOT_ENCODER[sequence_type]( 68 | character_correction(sequences_list, min_length, null_character) 69 | ) 70 | 71 | 72 | def sigmoid(x): 73 | return 1 / (1 + np.exp(-x)) 74 | 75 | 76 | def postprocessing_function(binding_score, use_labels=True): 77 | """Instead of a score, interpreters expect labels or probabilities.""" 78 | if use_labels: 79 | return binding_score > 0 # binding_probs > 0.5 80 | else: 81 | # not a score, but probability in [0,1] 82 | binding_probs = np.expand_dims(sigmoid(binding_score), axis=1) 83 | return np.hstack([1. 
- binding_probs, binding_probs]) 84 | 85 | 86 | class DeepBind(KipoiModel): 87 | """Deepbind wrapper via kipoi.""" 88 | 89 | def __init__(self, model, use_labels=True, min_length=0): 90 | """ 91 | Constructor. 92 | 93 | Args: 94 | model (string): kipoi model name. 95 | use_labels (bool): if False, use probabilites instead of label. 96 | min_length (int): minimal lenght of sequence used for eventual 97 | padding with null_character ('N'). Some deepbind models fail 98 | with too short sequences, in that case increase min_length. 99 | 100 | On top of the kipoi model prediction, the predict method of this class 101 | will preprocess a string sequence to one hot encoding using the 102 | the input documentation to determine `sequence_type`. 103 | It will also return not a binding score but either a classification 104 | label or 'NotBinding','Binding' probabilities expected by interpreters. 105 | """ 106 | super().__init__( 107 | model=model, 108 | task=Task.CLASSIFICATION, 109 | data_type=DataType.TEXT, 110 | source='kipoi', 111 | with_dataloader=False, 112 | preprocessing_function=preprocessing_function, 113 | preprocessing_kwargs={}, 114 | postprocessing_function=postprocessing_function, 115 | postprocessing_kwargs={}, 116 | ) 117 | # kwargs 118 | self.use_labels = use_labels 119 | # self.model.schema.inputs.doc is always "DNA Sequence", use name 120 | self.sequence_type = model.split('/')[2] # 'TF' or 'RBP' 121 | self.min_length = min_length 122 | 123 | def predict(self, sample): 124 | self.preprocessing_kwargs['sequence_type'] = self.sequence_type 125 | self.preprocessing_kwargs['min_length'] = self.min_length 126 | self.postprocessing_kwargs['use_labels'] = self.use_labels 127 | return super().predict(sample) 128 | -------------------------------------------------------------------------------- /depiction/models/examples/deepbind/deepbind_cli.py: -------------------------------------------------------------------------------- 1 | """Wrapper for pretrained Deepbind 
def sigmoid(x):
    """Map a real-valued score into (0, 1)."""
    return 1 / (1 + np.exp(-x))


def process_deepbind_stdout(deepbind_stdout):
    """
    Process the output assuming that there is only one input and one factor,
    i.e. the output has this format:

    <header>\n
    <score>\n
    ...

    Args:
        deepbind_stdout (str or bytes): raw standard output of the deepbind
            executable; the first line is a header, each following line holds
            one binding score.

    Returns:
        np.ndarray: column vector (n, 1) of probabilities of binding, as
            sigmoid(binding score).
    """
    # NOTE: the deprecated alias np.float was removed in NumPy 1.24;
    # the builtin float is the correct dtype argument.
    scores = np.array(deepbind_stdout.splitlines()[1:]).astype(float)
    return np.expand_dims(sigmoid(scores), axis=1)


def deepbind(factor_id, sequence_fpath, exec_path):
    """
    Run the deepbind executable on a sequence file and parse its output.

    Args:
        factor_id (str): ID of the factor model to score against.
        sequence_fpath (str): path to a .seq file with one sequence per line.
        exec_path (str): path to the deepbind executable.

    Returns:
        np.ndarray: column vector of binding probabilities.
    """
    process = subprocess.run(
        [exec_path, factor_id, sequence_fpath], stdout=PIPE, stderr=PIPE
    )
    return process_deepbind_stdout(process.stdout)
55 | """ 56 | return [ 57 | s.replace('\x00', '') if len(s.replace('\x00', '')) > 0 58 | else np.random.choice(DNA_ALPHABET) for s in sequences_list 59 | ] 60 | 61 | 62 | def deepbind_on_sequences( 63 | factor_id, sequences_list, exec_path, tmp_folder=None 64 | ): 65 | tmp_file = tempfile.mkstemp(dir=tmp_folder, suffix=SEQ_FILE_EXTENSION)[1] 66 | 67 | with open(tmp_file, 'w') as tmp_fh: 68 | tmp_fh.write( 69 | '\n'.join(character_correction(sequences_list)) 70 | ) 71 | 72 | return deepbind(factor_id, tmp_file, exec_path) 73 | 74 | 75 | def create_DNA_language(): 76 | accepted_values = DNA_ALPHABET 77 | vocab = Vocab(strings=accepted_values) 78 | 79 | def make_doc(sequence): 80 | sequence = sequence.replace(' ', '') 81 | if len(sequence) == 0: 82 | words = np.random.choice(accepted_values) 83 | else: 84 | words = list(sequence) 85 | return Doc(vocab, words=words, spaces=[False] * len(words)) 86 | 87 | return Language(vocab, make_doc) 88 | 89 | 90 | class DeepBind(HTTPModel): 91 | """Deepbind wrapper.""" 92 | 93 | def __init__( 94 | self, 95 | tf_factor_id='D00328.003', 96 | use_labels=True, 97 | filename='deepbind.tgz', 98 | origin='https://ibm.box.com/shared/static/ns9e7666kfjwvlmyk6mrh4n6sqjmzagm.tgz', # noqa 99 | cache_dir=None 100 | ): 101 | """ 102 | Constructor. 103 | 104 | Args: 105 | tf_factor_id (str): ID of the transcription factor to classify 106 | against. 107 | use_labels (bool): if False, use logits insted of labels. 108 | filename (str): where to store the downloaded zip containing the 109 | model. 110 | origin (str): link where to download the model from. 
111 | """ 112 | super().__init__( 113 | uri=origin, 114 | task=Task.CLASSIFICATION, 115 | data_type=DataType.TEXT, 116 | cache_dir=cache_dir, 117 | filename=filename 118 | ) 119 | self.tf_factor_id = tf_factor_id 120 | self.use_labels = use_labels 121 | # make sure the model is present 122 | self.save_dir = os.path.dirname(self.model_path) 123 | self.model_dir = os.path.join(self.save_dir, 'deepbind') 124 | if not os.path.exists(self.model_dir): 125 | with tarfile.open(self.model_path, 'r:gz') as model_tar: 126 | model_tar.extractall(self.save_dir) 127 | self.exec_path = os.path.join(self.model_dir, 'deepbind') 128 | 129 | def predict(self, sample, *args, **kwargs): 130 | """ 131 | Run the model for inference on a given sample and with the provided 132 | parameters. 133 | 134 | Args: 135 | sample (object): an input sample for the model. 136 | args (list): list of arguments. 137 | kwargs (dict): list of key-value arguments. 138 | 139 | Returns: 140 | a prediction for the model on the given sample. 141 | """ 142 | if not isinstance(sample, list): 143 | sample = [sample] 144 | binding_probs = deepbind_on_sequences( 145 | self.tf_factor_id, sample, self.exec_path 146 | ) 147 | if self.use_labels: 148 | return binding_probs.flatten() > 0.5 149 | else: 150 | return np.hstack([1. 
# Regex splitting a SMILES string into chemically meaningful tokens
# (bracket atoms, two-letter elements, bonds, branches, ring indices, ...).
ATOM_REGEX = re.compile(
    r'(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|'
    r'-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])'
)
# Tunables, overridable via environment variables.
MAXIMUM_NUMBER_OF_RINGS = int(
    os.environ.get('PACCMANN_MAXIMUM_NUMBER_OF_RINGS', 9)
)
MAX_LENGTH = int(os.environ.get('PACCMANN_MAX_LENGTH', 155))
PADDING_ATOM = os.environ.get('PACCMANN_PADDING_ATOM', '')
# Tokens that do not correspond to an atom: ring indices plus branch/bond
# characters.
# NOTE(review): range(1, MAXIMUM_NUMBER_OF_RINGS) excludes the highest ring
# index ('9' with the default of 9) even though ATOM_MAPPING contains it --
# confirm whether the upper bound should be inclusive.
NON_ATOM_CHARACTERS = set(
    [str(index)
     for index in range(1, MAXIMUM_NUMBER_OF_RINGS)] + ['(', ')', '#', '=']
)
# Token -> integer index used by the model; index 0 doubles as padding.
ATOM_MAPPING = {
    '2': 1,
    '7': 2,
    'O': 3,
    '[O]': 4,
    '#': 5,
    '(': 6,
    'P': 7,
    'Cl': 8,
    'C': 9,
    'N': 10,
    'Br': 11,
    'F': 12,
    ')': 13,
    '=': 14,
    '9': 15,
    '4': 16,
    '1': 17,
    '6': 18,
    'I': 19,
    '[N+]': 20,
    '[NH]': 21,
    '.': 22,
    'S': 23,
    '[O-]': 24,
    '3': 25,
    '8': 26,
    '5': 27,
    PADDING_ATOM: 0
}
# Inverse mapping used to recover tokens from indices.
REVERSED_ATOM_MAPPING = {index: atom for atom, index in ATOM_MAPPING.items()}
# Rendering configuration for the attention-highlighted molecule svg.
CMAP = cm.Oranges
COLOR_NORMALIZERS = {
    'linear': mpl.colors.Normalize,
    'logarithmic': mpl.colors.LogNorm
}
ATOM_RADII = float(os.environ.get('PACCMANN_ATOM_RADII', .5))
SVG_WIDTH = int(os.environ.get('PACCMANN_SVG_WIDTH', 400))
SVG_HEIGHT = int(os.environ.get('PACCMANN_SVG_HEIGHT', 200))
COLOR_NORMALIZATION = os.environ.get(
    'PACCMANN_COLOR_NORMALIZATION', 'logarithmic'
)


def process_smiles(smiles):
    """
    Process a SMILES.

    SMILES string is processed to generate a zero-padded
    sequence: tokens are truncated to MAX_LENGTH and the index
    sequence is left-padded with the padding index 0.

    Args:
        smiles (str): a SMILES representing a molecule.
    Returns:
        a list of token indices.
    """
    tokens = [token for token in ATOM_REGEX.split(smiles)
              if token][:MAX_LENGTH]
    return (
        [0] * (MAX_LENGTH - len(tokens)) +
        [ATOM_MAPPING.get(token, 0) for token in tokens]
    )


def get_atoms(smiles):
    """
    Process a SMILES.

    SMILES string is processed to generate a sequence
    of atoms (padding tokens included).

    Args:
        smiles (str): a SMILES representing a molecule.
    Returns:
        a list of atoms.
    """
    tokens = process_smiles(smiles)
    return [REVERSED_ATOM_MAPPING[token] for token in tokens]
119 | """ 120 | to_keep = [ 121 | index for index, atom in enumerate(atoms) if atom != PADDING_ATOM 122 | ] 123 | return ( 124 | list(itemgetter(*to_keep)(atoms)), 125 | list(itemgetter(*to_keep)(smiles_attention)) 126 | ) 127 | 128 | 129 | def _get_index_and_colors(values, objects, predicate, color_mapper): 130 | """ 131 | Get index and RGB colors from a color map using a rule. 132 | 133 | The predicate acts on a tuple of (value, object). 134 | 135 | Args: 136 | values (Iterable): floats representing a color. 137 | objects (Iterable): objects associated to the colors. 138 | predicate (Callable): a predicate to filter objects. 139 | color_mapper (cm.ScalarMappable): a mapper from floats to RGBA. 140 | 141 | Returns: 142 | Iterables of indices and RGBA colors. 143 | """ 144 | indices = [] 145 | colors = {} 146 | for index, value in enumerate( 147 | map( 148 | lambda t: t[0], 149 | filter(lambda t: predicate(t), zip(values, objects)) 150 | ) 151 | ): 152 | indices.append(index) 153 | colors[index] = color_mapper.to_rgba(value) 154 | return indices, colors 155 | 156 | 157 | def smiles_attention_to_svg(smiles_attention, atoms, molecule): 158 | """ 159 | Generate an svg of the molecule highlighiting SMILES attention. 160 | 161 | Args: 162 | smiles_attention (Iterable): an iterable of floating point values. 163 | atoms (Iterable): an iterable of atoms. 164 | molecule (rdkit.Chem.Mol): a molecule. 165 | Returns: 166 | the svg of the molecule with highlighted atoms and bonds. 
167 | """ 168 | # remove padding 169 | logger.debug('SMILES attention:\n{}'.format(smiles_attention)) 170 | logger.debug( 171 | 'SMILES attention range: [{},{}]'.format( 172 | min(smiles_attention), max(smiles_attention) 173 | ) 174 | ) 175 | atoms, smiles_attention = remove_padding_from_atoms_and_smiles_attention( 176 | atoms, smiles_attention 177 | ) 178 | logger.debug( 179 | 'atoms and SMILES after removal:\n{}\n{}'.format( 180 | atoms, smiles_attention 181 | ) 182 | ) 183 | logger.debug( 184 | 'SMILES attention after padding removal:\n{}'.format(smiles_attention) 185 | ) 186 | logger.debug( 187 | 'SMILES attention range after padding removal: [{},{}]'.format( 188 | min(smiles_attention), max(smiles_attention) 189 | ) 190 | ) 191 | # define a color map 192 | normalize = COLOR_NORMALIZERS.get(COLOR_NORMALIZATION, mpl.colors.LogNorm)( 193 | vmin=min(smiles_attention), vmax=2 * max(smiles_attention) 194 | ) 195 | color_mapper = cm.ScalarMappable(norm=normalize, cmap=CMAP) 196 | # get atom colors 197 | highlight_atoms, highlight_atom_colors = _get_index_and_colors( 198 | smiles_attention, atoms, lambda t: t[1] not in NON_ATOM_CHARACTERS, 199 | color_mapper 200 | ) 201 | logger.debug('Atom colors:\n{}'.format(highlight_atom_colors)) 202 | # get bond colors 203 | highlight_bonds, highlight_bond_colors = _get_index_and_colors( 204 | smiles_attention, atoms, lambda t: t[1] in NON_ATOM_CHARACTERS, 205 | color_mapper 206 | ) 207 | logger.debug('Bond colors:\n{}'.format(highlight_bond_colors)) 208 | # add coordinates 209 | Chem.rdDepictor.Compute2DCoords(molecule) 210 | # draw the molecule 211 | drawer = rdMolDraw2D.MolDraw2DSVG(SVG_WIDTH, SVG_HEIGHT) 212 | drawer.DrawMolecule( 213 | molecule, 214 | highlightAtoms=highlight_atoms, 215 | highlightAtomColors=highlight_atom_colors, 216 | highlightBonds=highlight_bonds, 217 | highlightBondColors=highlight_bond_colors, 218 | highlightAtomRadii={index: ATOM_RADII 219 | for index in highlight_atoms} 220 | ) 221 | 
drawer.FinishDrawing() 222 | # return the drawn molecule 223 | return drawer.GetDrawingText().replace('\n', ' ') 224 | 225 | 226 | def get_smiles_language(): 227 | """ 228 | Get SMILES language. 229 | 230 | Returns: 231 | a spacy.language.Language representing SMILES. 232 | """ 233 | valid_values = list( 234 | filter(lambda k: k != PADDING_ATOM, ATOM_MAPPING.keys()) 235 | ) 236 | vocabulary = Vocab(strings=valid_values) 237 | 238 | def make_doc(smiles): 239 | """ 240 | Make a SMILES document. 241 | 242 | Args: 243 | smiles (str): a SMILES representing a molecule. 244 | Returns: 245 | a spacy.tokens.Doc representing the molecule. 246 | """ 247 | if len(smiles) == 0: 248 | tokens = np.random.choice(valid_values) 249 | else: 250 | tokens = [token for token in ATOM_REGEX.split(smiles) 251 | if token][:MAX_LENGTH] 252 | return Doc(vocabulary, words=tokens, spaces=[False] * len(tokens)) 253 | 254 | return Language(vocabulary, make_doc) 255 | -------------------------------------------------------------------------------- /depiction/models/keras/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize keras models.""" 2 | from .core import KerasModel # noqa 3 | from .application import KerasApplicationModel # noqa -------------------------------------------------------------------------------- /depiction/models/keras/application.py: -------------------------------------------------------------------------------- 1 | """Core module for keras applications.""" 2 | from .core import KerasModel 3 | 4 | 5 | def identity(sample, *args, **kwargs): 6 | """ 7 | Apply identity. 8 | 9 | Args: 10 | sample (np.ndarray): an input sample for the model. 11 | 12 | Returns: 13 | np.ndarray: output of preprocessing function representing 14 | the sample. 
15 | """ 16 | return sample 17 | 18 | 19 | class KerasApplicationModel(KerasModel): 20 | """Keras application wrapper.""" 21 | 22 | def __init__( 23 | self, 24 | model, 25 | task, 26 | data_type, 27 | preprocessing_function=identity, 28 | *args, 29 | **kwargs 30 | ): 31 | """ 32 | Initialize a KerasApplicationModel. 33 | 34 | Args: 35 | model (keras): model to wrap. 36 | task (depiction.core.Task): task type. 37 | data_type (depiction.core.DataType): data type. 38 | preprocessing_function (callable): function to preprocess samples. 39 | *args (list): arguments passed to preprocessing function. 40 | **kwargs (dict): keyword arguments passed to preprocessing 41 | function. 42 | """ 43 | super().__init__(model=model, task=task, data_type=data_type) 44 | self.preprocessing_function = preprocessing_function 45 | self.preprocessing_args = args 46 | self.preprocessing_kwargs = kwargs 47 | 48 | def _prepare_sample(self, sample): 49 | """ 50 | Prepare sample for the model. 51 | 52 | Args: 53 | sample (np.ndarray): an input sample for the model. 54 | 55 | Returns: 56 | output of preprocessing function representing the sample. 57 | """ 58 | return self.preprocessing_function( 59 | sample, *self.preprocessing_args, **self.preprocessing_kwargs 60 | ) 61 | -------------------------------------------------------------------------------- /depiction/models/keras/core.py: -------------------------------------------------------------------------------- 1 | """Core module for Keras models.""" 2 | import copy 3 | 4 | from ..base.base_model import BaseModel 5 | 6 | 7 | class KerasModel(BaseModel): 8 | """Keras model wrapper.""" 9 | 10 | def __init__(self, model, task, data_type): 11 | """ 12 | Initialize a Model. 13 | 14 | Args: 15 | model (torch.nn.Module): model to wrap. 16 | task (depiction.core.Task): task type. 17 | data_type (depiction.core.DataType): data type. 
18 | """ 19 | super().__init__(task=task, data_type=data_type) 20 | self._model = model 21 | self._predict_kwargs = { 22 | 'batch_size': None, 23 | 'verbose': 0, 24 | 'steps': None, 25 | 'callbacks': None 26 | } 27 | 28 | def _prepare_sample(self, sample): 29 | """ 30 | Prepare sample for the model. 31 | 32 | Args: 33 | sample (np.ndarray): an input sample for the model. 34 | 35 | Returns: 36 | np.ndarray: a numpy array representing the prepared sample. 37 | """ 38 | return sample 39 | 40 | def predict(self, sample, *args, **kwargs): 41 | """ 42 | Run the model for inference on a given sample and with the provided 43 | parameters. 44 | 45 | Args: 46 | sample (np.ndarray): an input sample for the model. 47 | args (list): list of arguments for prediction. 48 | kwargs (dict): list of key-value arguments for prediction. 49 | 50 | Returns: 51 | np.ndarray: a prediction for the model on the given sample. 52 | """ 53 | predict_kwargs = copy.deepcopy(self._predict_kwargs) 54 | predict_kwargs.update(**kwargs) 55 | return self._model.predict( 56 | self._prepare_sample(sample), *args, **predict_kwargs 57 | ) 58 | -------------------------------------------------------------------------------- /depiction/models/keras/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/keras/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/keras/tests/application_test.py: -------------------------------------------------------------------------------- 1 | """Test KerasApplicationModel.""" 2 | import os 3 | import unittest 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from depiction.core import DataType, Task 9 | from depiction.models.keras.application import KerasApplicationModel 10 | 11 | 12 | def user_preprocessing(img_path, preprocess_input, 
def user_preprocessing(img_path, preprocess_input, target_size):
    """Mimic sample preparation from Keras application documentation."""
    img = keras.preprocessing.image.load_img(img_path, target_size=target_size)
    x = keras.preprocessing.image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    return preprocess_input(x)


class KerasApplicationTestCase(unittest.TestCase):
    """Test Keras Applications."""

    def setUp(self):
        # downloads (and caches) a small test image
        self.img_path = keras.utils.get_file(
            'elephant.jpg',
            'https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/Zoorashia_elephant.jpg/120px-Zoorashia_elephant.jpg'  # noqa
        )

    def test_predict_shape(self):
        """Test passing no preprocessing function."""
        model = KerasApplicationModel(
            model=keras.applications.MobileNetV2(),
            task=Task.CLASSIFICATION,
            data_type=DataType.IMAGE
        )
        image = np.random.randn(1, 224, 224, 3)
        self.assertEqual(model.predict(image).shape, (1, 1000))

    def test_predict_with_preprocessing(self):
        """Test passing of preprocessing function and arguments."""
        # default keras workflow, kept in exact lockstep with
        # user_preprocessing above so outputs must match bit-for-bit
        img = keras.preprocessing.image.load_img(
            self.img_path, target_size=(224, 224)
        )
        x = keras.preprocessing.image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        sample = keras.applications.mobilenet_v2.preprocess_input(x)
        keras_output = keras.applications.MobileNetV2().predict(sample)

        # wrapped application
        application_model = KerasApplicationModel(
            model=keras.applications.MobileNetV2(),
            task=Task.CLASSIFICATION,
            data_type=DataType.IMAGE,
            preprocessing_function=user_preprocessing,
            # kwargs passed to preprocessing_function
            preprocess_input=keras.applications.mobilenet_v2.preprocess_input,
            target_size=(224, 224)
        )
        depiction_output = application_model.predict(self.img_path)

        self.assertIs(np.array_equal(keras_output, depiction_output), True)

    def tearDown(self):
        # remove the cached test image
        os.remove(self.img_path)
class KerasModelTestCase(unittest.TestCase):
    """Test KerasModel."""

    def test_prediction(self):
        """A wrapped MobileNetV2 yields 1000 class scores per sample."""
        wrapped = KerasModel(
            model=keras.applications.MobileNetV2(),
            task=Task.CLASSIFICATION,
            data_type=DataType.IMAGE
        )
        batch = np.random.randn(1, 224, 224, 3)
        prediction = wrapped.predict(batch)
        self.assertEqual(prediction.shape, (1, 1000))
class KipoiModel(BaseModel):
    """To use Kipoi models via its Python API.

    See https://github.com/kipoi/kipoi/blob/master/notebooks/python-api.ipynb.

    Take care that Kipoi models might define additional dependencies."""

    def __init__(
        self, model, task, data_type, source='kipoi', with_dataloader=False,
        preprocessing_function=identity,
        postprocessing_function=identity,
        preprocessing_kwargs=None,
        postprocessing_kwargs=None
    ):
        """
        Initialize a KipoiModel via `kipoi.get_model`.

        Args:
            model (string): kipoi model name.
            task (depiction.core.Task): task type.
            data_type (depiction.core.DataType): data type.
            source (str): kipoi model source name. Defaults to 'kipoi'.
            with_dataloader (bool): if True, the kipoi models' default
                dataloader is loaded to `model.default_dataloader` and the
                pipeline at `model.pipeline` enabled. Defaults to False.
            preprocessing_function (callable): function to preprocess samples.
            postprocessing_function (callable): function to postprocess output
                of kipois `predict_on_batch`.
            preprocessing_kwargs (dict): keyword arguments passed to
                preprocessing function. Defaults to None, i.e. no kwargs.
            postprocessing_kwargs (dict): keyword arguments passed to
                postprocessing function. Defaults to None, i.e. no kwargs.

        The processing functions default to the identity function.
        """
        super().__init__(task=task, data_type=data_type)
        self.preprocessing_function = preprocessing_function
        # FIX: None sentinels replace mutable default arguments ({}); the
        # deepcopy additionally keeps instances from sharing caller dicts.
        self.preprocessing_kwargs = copy.deepcopy(preprocessing_kwargs or {})
        self.postprocessing_function = postprocessing_function
        self.postprocessing_kwargs = copy.deepcopy(postprocessing_kwargs or {})
        self.model = kipoi.get_model(
            model, source=source, with_dataloader=with_dataloader
        )

    def _prepare_sample(self, sample):
        """
        Prepare sample for the model.

        Args:
            sample (np.ndarray): an input sample for the model.

        Returns:
            output of preprocessing function representing the sample.
        """
        return self.preprocessing_function(
            sample, **self.preprocessing_kwargs
        )

    def predict(self, sample):
        """
        Run the model for inference on a given sample. The sample is
        preprocessed and output postprocessed.

        Args:
            sample (np.ndarray): an input sample for the model.

        Returns:
            np.ndarray: a prediction for the model on the given sample.
        """
        return self.postprocessing_function(
            self.model.predict_on_batch(self._prepare_sample(sample)),
            **self.postprocessing_kwargs
        )
class KipoiModelTestCase(unittest.TestCase):
    """Test KipoiModel.

    Kipoi model page:
    http://kipoi.org/models/DeepBind/Homo_sapiens/TF/D00817.001_ChIP-seq_TBP/
    """

    def test_prediction(self):
        # the model returns a raw binding score; a positive score means
        # predicted binding
        model = KipoiModel(
            'DeepBind/Homo_sapiens/TF/D00817.001_ChIP-seq_TBP',
            Task.CLASSIFICATION, DataType.TEXT,
            preprocessing_function=preprocessing_function
        )
        sequence = 'ATGGGCCAGCACACAGACCAGCACGTTGCCCAGGAGCTCGCTATAAAAGGGCGTGGGAGGAAGATAAGAGGTATGAACATGATTAGCAAAAGGGCCTAGCT'  # noqa
        # contains the TATA box:                    ~~~~~~~
        self.assertTrue((model.predict(sequence) > 0)[0])  # shape is (1,)
class BreastCancerMitosisDetector(MAXModel):
    """MAX Breast Cancer Mitosis Detector Keras Model."""

    def __init__(self, uri):
        """
        Initialize MAX Breast Cancer Mitosis Detector Keras Model.

        Args:
            uri (str): URI to access the model.
        """
        super().__init__(uri=uri, task=Task.BINARY, data_type=DataType.IMAGE)
        self.labels = ['non mitotic', 'mitotic']

    def _process_prediction(self, prediction):
        """
        Process json prediction response.

        Args:
            prediction (dict): json prediction response.

        Returns:
            np.ndarray: one row per prediction holding the class
            probabilities [P(non mitotic), P(mitotic)].
        """
        rows = []
        for a_prediction in prediction['predictions']:
            probability = a_prediction['probability']
            rows.append([1.0 - probability, probability])
        return np.array(rows)

    def _predict(self, sample, *args, **kwargs):
        """
        Run the model for inference on a given sample and with the provided
        parameters.

        Args:
            sample (object): an input sample for the model.
            args (list): list of arguments.
            kwargs (dict): list of key-value arguments.

        Returns:
            dict: a prediction for the model on the given sample.
        """
        # NOTE: serialize the image into an in-memory png and rewind the
        # buffer before posting it to the service
        png_buffer = BytesIO()
        imageio.imwrite(png_buffer, sample, format='png')
        png_buffer.seek(0, 0)
        return self._request(method='post', files={'image': png_buffer})
class ToxicCommentClassifierTestCase(unittest.TestCase):
    """Test MAX toxic comment classifier."""

    def test_prediction(self):
        # the service host is configurable via environment variable to allow
        # running against a docker-compose deployment
        toxic = ToxicCommentClassifier(
            uri='http://{}:5000'.format(
                os.environ.
                get('TEST_MAX_TOXIC_COMMENT_CLASSIFIER', 'localhost')
            )
        )
        texts = ['This movie sucks.', 'I really liked the play.']
        # one row per text, one column per label
        self.assertEqual(
            toxic.predict(texts).shape, (len(texts), len(toxic.labels))
        )
17 | """ 18 | super().__init__( 19 | uri=uri, task=Task.MULTICLASS, data_type=DataType.TEXT 20 | ) 21 | self.labels = sorted( 22 | self._request(method='get', 23 | endpoint=self.labels_endpoint)['labels'].keys() 24 | ) 25 | 26 | def _process_prediction(self, prediction): 27 | """ 28 | Process json prediction response. 29 | 30 | Args: 31 | prediction (dict): json prediction response. 32 | 33 | Returns: 34 | np.ndarray: numpy array representing the prediction. 35 | """ 36 | return np.array( 37 | [ 38 | [result['predictions'][label] for label in self.labels] 39 | for result in prediction['results'] # API changed with new level nesting 40 | ] 41 | ) 42 | 43 | def _predict(self, sample, *args, **kwargs): 44 | """ 45 | Run the model for inference on a given sample and with the provided 46 | parameters. 47 | 48 | Args: 49 | sample (object): an input sample for the model. 50 | args (list): list of arguments. 51 | kwargs (dict): list of key-value arguments. 52 | 53 | Returns: 54 | dict: a prediction for the model on the given sample. 
55 | """ 56 | texts = [sample] if isinstance(sample, 57 | str) else [text for text in sample] 58 | return self._request(method='post', json={'text': texts}) 59 | -------------------------------------------------------------------------------- /depiction/models/torch/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize torch models.""" 2 | from .core import TorchModel # noqa 3 | from .torchvision import TorchVisionModel # noqa 4 | -------------------------------------------------------------------------------- /depiction/models/torch/core.py: -------------------------------------------------------------------------------- 1 | """Core module for PyTorch models.""" 2 | import torch 3 | 4 | from ..base.base_model import BaseModel 5 | 6 | 7 | class TorchModel(BaseModel): 8 | """PyTorch model wrapper.""" 9 | 10 | def __init__(self, model, task, data_type, double=False): 11 | """ 12 | Initialize a TorchModel. 13 | 14 | Args: 15 | model (torch.nn.Module): model to wrap. 16 | task (depiction.core.Task): task type. 17 | data_type (depiction.core.DataType): data type. 18 | """ 19 | super().__init__(task=task, data_type=data_type) 20 | self._model = model 21 | self._double = double 22 | 23 | def _prepare_sample(self, sample): 24 | """ 25 | Prepare sample for the model. 26 | 27 | Args: 28 | sample (np.ndarray): an input sample for the model. 29 | 30 | Returns: 31 | torch.tensor: a tensor representing the sample. 32 | """ 33 | if self._double: 34 | return torch.from_numpy(sample).double() 35 | return torch.from_numpy(sample).float() 36 | 37 | def predict(self, sample, *args, **kwargs): 38 | """ 39 | Run the model for inference on a given sample and with the provided 40 | parameters. 41 | 42 | Args: 43 | sample (np.ndarray): an input sample for the model. 44 | args (list): list of arguments. 45 | kwargs (dict): list of key-value arguments. 46 | 47 | Returns: 48 | np.ndarray: a prediction for the model on the given sample. 
49 | """ 50 | if self._double: 51 | self._model = self._model.double().eval() 52 | return self._model(self._prepare_sample(sample).double(), **kwargs).detach().numpy() 53 | self._model = self._model.float().eval() 54 | return self._model(self._prepare_sample(sample).float(), **kwargs).detach().numpy() 55 | -------------------------------------------------------------------------------- /depiction/models/torch/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/torch/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/torch/tests/core_test.py: -------------------------------------------------------------------------------- 1 | """Test TorchModel.""" 2 | import unittest 3 | 4 | import numpy as np 5 | import torchvision.models as models 6 | 7 | from depiction.core import DataType, Task 8 | from depiction.models.torch.core import TorchModel 9 | 10 | 11 | class TorchModelTestCase(unittest.TestCase): 12 | """Test TorchModel.""" 13 | 14 | def test_prediction(self): 15 | model = TorchModel( 16 | model=models.mobilenet_v2(pretrained=True), 17 | task=Task.CLASSIFICATION, 18 | data_type=DataType.IMAGE 19 | ) 20 | image = np.random.randn(1, 3, 224, 224) 21 | self.assertEqual(model.predict(image).shape, (1, 1000)) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /depiction/models/torch/tests/torchvision_test.py: -------------------------------------------------------------------------------- 1 | """Test TorchVisionModel.""" 2 | import unittest 3 | 4 | import numpy as np 5 | import torchvision.models as models 6 | 7 | from depiction.core import DataType, Task 8 | from depiction.models.torch.torchvision import TorchVisionModel 9 | 10 | 11 | class 
TorchVisionModelTestCase(unittest.TestCase): 12 | """Test TorchVisionModel.""" 13 | 14 | def test_prediction(self): 15 | model = TorchVisionModel( 16 | model=models.mobilenet_v2(pretrained=True), 17 | task=Task.CLASSIFICATION, 18 | data_type=DataType.IMAGE 19 | ) 20 | image = np.random.randn(1, 3, 224, 224) 21 | self.assertEqual(model.predict(image).shape, (1, 1000)) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /depiction/models/torch/torchvision.py: -------------------------------------------------------------------------------- 1 | """Core module for torchvision models.""" 2 | import torch 3 | import torchvision.transforms as transforms 4 | 5 | from .core import TorchModel 6 | 7 | # NOTE: From https://pytorch.org/docs/stable/torchvision/models.html. 8 | NORMALIZE = transforms.Normalize( 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 10 | ) 11 | 12 | 13 | class TorchVisionModel(TorchModel): 14 | """torchvision model wrapper.""" 15 | 16 | def __init__(self, model, task, data_type): 17 | """ 18 | Initialize a TorchVisionModel. 19 | 20 | Args: 21 | model (torch.nn.Module): model to wrap. 22 | task (depiction.core.Task): task type. 23 | data_type (depiction.core.DataType): data type. 24 | """ 25 | super().__init__(model=model, task=task, data_type=data_type) 26 | 27 | def _prepare_sample(self, sample): 28 | """ 29 | Prepare sample for the model. 30 | 31 | Args: 32 | sample (np.ndarray): an input sample for the model. 33 | 34 | Returns: 35 | torch.tensor: a tensor representing the sample. 
36 | """ 37 | return torch.stack( 38 | [ 39 | NORMALIZE(example) for example in 40 | torch.unbind(TorchModel._prepare_sample(self, sample), dim=0) 41 | ], 42 | axis=0 43 | ) 44 | -------------------------------------------------------------------------------- /depiction/models/uri/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize uri module.""" 2 | from .cache import FileSystemModel # noqa 3 | from .cache import HTTPModel # noqa 4 | from .cache import COSModel # noqa 5 | from .rest_api import RESTAPIModel # noqa 6 | from .rest_api import MAXModel # noqa 7 | -------------------------------------------------------------------------------- /depiction/models/uri/cache/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize cache module.""" 2 | from .file_system_model import FileSystemModel # noqa 3 | from .http_model import HTTPModel # noqa 4 | from .cos_model import COSModel # noqa -------------------------------------------------------------------------------- /depiction/models/uri/cache/cache_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for URI models.""" 2 | import os 3 | from abc import abstractclassmethod 4 | 5 | from ..uri_model import URIModel 6 | from ...base.utils import MODELS_SUBDIR 7 | 8 | 9 | class CacheModel(URIModel): 10 | """Abstract implementation of a cached URI model.""" 11 | 12 | def __init__(self, uri, task, data_type, cache_dir, filename=None): 13 | """ 14 | Initialize a CacheModel. 15 | 16 | Args: 17 | uri (str): URI to access the model. 18 | task (depiction.core.Task): task type. 19 | data_type (depiction.core.DataType): data type. 20 | cache_dir (str): cache directory. 21 | filename (str): name of the model file when cached. 22 | Defaults to None, a.k.a. inferring the name from 23 | uri. 
24 | """ 25 | super().__init__(uri=uri, task=task, data_type=data_type) 26 | self.models_subdir = MODELS_SUBDIR 27 | if filename is None: 28 | filename = os.path.basename(self.uri) 29 | self.cache_dir = cache_dir 30 | self.filename = filename 31 | self.model_path = self._get_model_file(self.filename, self.cache_dir) 32 | 33 | @abstractclassmethod 34 | def _get_model_file(self, filename, cache_dir): 35 | """ 36 | Cache model file. 37 | 38 | Args: 39 | filename (str): name of the file. 40 | cache_dir (str): cache directory. 41 | 42 | Returns: 43 | str: path to the model file. 44 | """ 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /depiction/models/uri/cache/cos_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface from Cloud Object Storage (COS) models.""" 2 | import os 3 | import copy 4 | from minio import Minio 5 | 6 | from .cache_model import CacheModel 7 | 8 | 9 | class COSModel(CacheModel): 10 | """ 11 | Abstract implementation of a model cached from Cloud Object Storage (COS). 12 | """ 13 | 14 | def __init__( 15 | self, uri, task, data_type, cache_dir, filename=None, **kwargs 16 | ): 17 | """ 18 | Initialize a COSModel. 19 | 20 | Args: 21 | uri (str): URI to access the model. 22 | task (depiction.core.Task): task type. 23 | data_type (depiction.core.DataType): data type. 24 | cache_dir (str): cache directory. 25 | filename (str): name of the model file when cached. 26 | Defaults to None, a.k.a. inferring the name from 27 | uri. 28 | kwargs (dict): key-value arguments for the Minio client. 29 | """ 30 | # NOTE: check the validity of the uri provided. 
31 | self.remote_coordinates = COSModel.parse_cos_uri(uri) 32 | self.minio_kwargs = copy.deepcopy(kwargs) 33 | # NOTE: we make sure there are no duplicated parameters 34 | _ = self.minio_kwargs.pop('access_key') 35 | _ = self.minio_kwargs.pop('secret_key') 36 | super().__init__( 37 | uri=uri, 38 | task=task, 39 | data_type=data_type, 40 | cache_dir=cache_dir, 41 | filename=filename 42 | ) 43 | 44 | def _get_model_file(self, filename, cache_dir): 45 | """ 46 | Cache model file. 47 | 48 | Args: 49 | filename (str): name of the file. 50 | cache_dir (str): cache directory. 51 | 52 | Returns: 53 | str: path to the model file. 54 | """ 55 | client = Minio( 56 | ( 57 | f'{self.remote_coordinates["host"]}:' 58 | f'{self.remote_coordinates["port"]}' 59 | ), 60 | access_key=self.remote_coordinates['access_key'], 61 | secret_key=self.remote_coordinates['secret_key'], 62 | **self.minio_kwargs 63 | ) 64 | filepath = os.path.join(cache_dir, self.models_subdir, filename) 65 | _ = client.fget_object( 66 | self.remote_coordinates['bucket'], 67 | self.remote_coordinates['filepath'], 68 | filepath, 69 | request_headers=self.minio_kwargs.get('request_headers', None) 70 | ) 71 | return filepath 72 | 73 | @staticmethod 74 | def parse_cos_uri(uri): 75 | """ 76 | Parse COS remote connection. 77 | 78 | Args: 79 | uri (str): cos uri. 80 | 81 | Returns: 82 | dict: a remote connection dictionary. 
83 | """ 84 | if not uri.startswith('s3://'): 85 | raise RuntimeError('Invalid S3 URI: {}'.format(uri)) 86 | tokenized = uri[5:].split('/') 87 | authorization = tokenized[0] 88 | bucket = tokenized[1] 89 | filepath = '/'.join(tokenized[2:]) 90 | keys, host = authorization.split('@') 91 | access_key, secret_key = keys.split(':') 92 | splitted_host = host.split(':') 93 | if len(splitted_host) > 1: 94 | host, port = splitted_host 95 | else: 96 | host, port = splitted_host[0], None 97 | return { 98 | 'secret_key': secret_key, 99 | 'access_key': access_key, 100 | 'host': host, 101 | 'port': int(port) if port else port, 102 | 'bucket': bucket, 103 | 'filepath': filepath 104 | } 105 | -------------------------------------------------------------------------------- /depiction/models/uri/cache/file_system_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for file system models.""" 2 | from .cache_model import CacheModel 3 | 4 | 5 | class FileSystemModel(CacheModel): 6 | """Abstract implementation of a model stored on file system.""" 7 | 8 | def __init__(self, uri, task, data_type): 9 | """ 10 | Initialize a FileSystemModel. 11 | 12 | Args: 13 | uri (str): URI to access the model. 14 | task (depiction.core.Task): task type. 15 | data_type (depiction.core.DataType): data type. 16 | """ 17 | super().__init__( 18 | uri=uri, 19 | task=task, 20 | data_type=data_type, 21 | cache_dir=None, 22 | filename=None 23 | ) 24 | 25 | def _get_model_file(self, filename, cache_dir): 26 | """ 27 | Cache model file. 28 | 29 | Args: 30 | filename (str): name of the file. 31 | cache_dir (str): cache directory. 32 | 33 | Returns: 34 | str: path to the model file. 
35 | """ 36 | return self.uri 37 | -------------------------------------------------------------------------------- /depiction/models/uri/cache/http_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for HTTP models.""" 2 | from .cache_model import CacheModel 3 | from ...base.utils import get_model_file 4 | 5 | 6 | class HTTPModel(CacheModel): 7 | """Abstract implementation of a model cached from HTTP.""" 8 | 9 | def __init__(self, uri, task, data_type, cache_dir, filename=None): 10 | """ 11 | Initialize a HTTPModel. 12 | 13 | Args: 14 | uri (str): URI to access the model. 15 | task (depiction.core.Task): task type. 16 | data_type (depiction.core.DataType): data type. 17 | cache_dir (str): cache directory. 18 | filename (str): name of the model file when cached. 19 | Defaults to None, a.k.a. inferring the name from 20 | uri. 21 | """ 22 | super().__init__( 23 | uri=uri, 24 | task=task, 25 | data_type=data_type, 26 | cache_dir=cache_dir, 27 | filename=filename 28 | ) 29 | 30 | def _get_model_file(self, filename, cache_dir): 31 | """ 32 | Cache model file. 33 | 34 | Args: 35 | filename (str): name of the file. 36 | cache_dir (str): cache directory. 37 | 38 | Returns: 39 | str: path to the model file. 
40 | """ 41 | return get_model_file(filename, self.uri, cache_dir) 42 | -------------------------------------------------------------------------------- /depiction/models/uri/cache/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/uri/cache/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/uri/rest_api/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize rest_api module.""" 2 | from .rest_api_model import RESTAPIModel # noqa 3 | from .max_model import MAXModel # noqa -------------------------------------------------------------------------------- /depiction/models/uri/rest_api/max_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for MAX models.""" 2 | import os 3 | 4 | from .rest_api_model import RESTAPIModel 5 | 6 | 7 | class MAXModel(RESTAPIModel): 8 | """ 9 | Abstract implementation of a MAX model. 10 | 11 | For a complete model list see here: 12 | https://developer.ibm.com/exchanges/models/all/. 13 | """ 14 | 15 | def __init__(self, uri, task, data_type): 16 | """ 17 | Initialize a MAX model. 18 | 19 | Args: 20 | uri (str): URI to access the model. 21 | task (depiction.core.Task): task type. 22 | data_type (depiction.core.DataType): data type. 
23 | """ 24 | self.base_endpoint = 'model' 25 | super().__init__( 26 | endpoint=os.path.join(self.base_endpoint, 'predict'), 27 | uri=uri, 28 | task=task, 29 | data_type=data_type 30 | ) 31 | self.metadata_endpoint = os.path.join(self.base_endpoint, 'metadata') 32 | self.labels_endpoint = os.path.join(self.base_endpoint, 'labels') 33 | self.metadata = self._request( 34 | method='get', endpoint=self.metadata_endpoint 35 | ) 36 | -------------------------------------------------------------------------------- /depiction/models/uri/rest_api/rest_api_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for REST API models.""" 2 | import os 3 | import requests 4 | from abc import abstractmethod 5 | 6 | from ..uri_model import URIModel 7 | 8 | 9 | class RESTAPIModel(URIModel): 10 | """Abstract implementation of a REST API model.""" 11 | 12 | def __init__(self, endpoint, uri, task, data_type): 13 | """ 14 | Initialize a REST API model. 15 | 16 | Args: 17 | endpoint (str): endpoint for prediction. 18 | uri (str): URI to access the model. 19 | task (depiction.core.Task): task type. 20 | data_type (depiction.core.DataType): data type. 21 | """ 22 | super().__init__(uri=uri, task=task, data_type=data_type) 23 | self.endpoint = endpoint 24 | 25 | def _request(self, method, endpoint=None, **kwargs): 26 | """ 27 | Perform a request to self.uri. 28 | 29 | Args: 30 | method (str): request method. 31 | endpoint (str): request endpoint. 32 | Defaults to None, a.k.a. use self.endpoint. 33 | kwargs (dict): key-value arguments for requests.request. 34 | 35 | Returns: 36 | dict: response dictionary. 
37 | """ 38 | response = requests.request( 39 | method=method, 40 | url=os.path.join( 41 | self.uri, endpoint if endpoint else self.endpoint 42 | ), 43 | **kwargs 44 | ) 45 | response.raise_for_status() 46 | return response.json() 47 | 48 | @abstractmethod 49 | def _process_prediction(self, prediction): 50 | """ 51 | Process json prediction response. 52 | 53 | Args: 54 | prediction (dict): json prediction response. 55 | 56 | Returns: 57 | np.ndarray: numpy array representing the prediction. 58 | """ 59 | raise NotImplementedError 60 | 61 | @abstractmethod 62 | def _predict(self, sample, *args, **kwargs): 63 | """ 64 | Run the model for inference on a given sample and with the provided 65 | parameters. 66 | 67 | Args: 68 | sample (object): an input sample for the model. 69 | args (list): list of arguments. 70 | kwargs (dict): list of key-value arguments. 71 | 72 | Returns: 73 | a prediction for the model on the given sample. 74 | """ 75 | raise NotImplementedError 76 | 77 | def predict(self, sample, *args, **kwargs): 78 | """ 79 | Run the model for inference on a given sample and with the provided 80 | parameters. 81 | 82 | Args: 83 | sample (object): an input sample for the model. 84 | args (list): list of arguments. 85 | kwargs (dict): list of key-value arguments. 86 | 87 | Returns: 88 | a prediction for the model on the given sample. 
89 | """ 90 | return self._process_prediction(self._predict(sample, *args, **kwargs)) -------------------------------------------------------------------------------- /depiction/models/uri/rest_api/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/uri/rest_api/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/uri/rest_api/tests/max_model_test.py: -------------------------------------------------------------------------------- 1 | """Test MAX model.""" 2 | import os 3 | import unittest 4 | from random import choice 5 | 6 | from depiction.core import Task, DataType 7 | from depiction.models.uri.rest_api.max_model import MAXModel 8 | 9 | 10 | class ConcreteTestModel(MAXModel): 11 | 12 | def __init__(self, uri, task_type, data_type): 13 | super(ConcreteTestModel, self).__init__(uri, task_type, data_type) 14 | 15 | def _process_prediction(self, prediction): 16 | return prediction 17 | 18 | def _predict(self, sample, *args, **kwargs): 19 | return sample 20 | 21 | 22 | class MAXModelTestCase(unittest.TestCase): 23 | """Test MAX model.""" 24 | 25 | def test_initialization(self): 26 | model = ConcreteTestModel( 27 | uri='http://{}:5000'.format( 28 | os.environ.get('TEST_MAX_BASE', 'localhost') 29 | ), 30 | task_type=choice(list(Task)), 31 | data_type=choice(list(DataType)) 32 | ) 33 | self.assertTrue(isinstance(model.metadata, dict)) 34 | self.assertEqual(model.metadata_endpoint, 'model/metadata') 35 | self.assertEqual(model.labels_endpoint, 'model/labels') 36 | self.assertEqual(model.endpoint, 'model/predict') 37 | 38 | 39 | if __name__ == "__main__": 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /depiction/models/uri/rest_api/tests/rest_api_model_test.py: 
-------------------------------------------------------------------------------- 1 | """Test REST API model.""" 2 | import os 3 | import unittest 4 | from random import choice 5 | 6 | from depiction.core import DataType, Task 7 | from depiction.models.uri.rest_api.rest_api_model import RESTAPIModel 8 | 9 | 10 | class ConcreteTestModel(RESTAPIModel): 11 | 12 | def __init__(self, endpoint, uri, task_type, data_type): 13 | super(ConcreteTestModel, 14 | self).__init__(endpoint, uri, task_type, data_type) 15 | 16 | def _process_prediction(self, prediction): 17 | return prediction 18 | 19 | def _predict(self, sample, *args, **kwargs): 20 | return sample 21 | 22 | 23 | class RESTAPIModelTestCase(unittest.TestCase): 24 | """Test REST API model.""" 25 | 26 | def test_initialization(self): 27 | model = ConcreteTestModel( 28 | endpoint='predict', 29 | uri='http://{}:5000'.format( 30 | os.environ.get('TEST_REST_API', 'localhost') 31 | ), 32 | task_type=choice(list(Task)), 33 | data_type=choice(list(DataType)) 34 | ) 35 | self.assertTrue( 36 | model._request(method='get', endpoint='model/metadata'), dict 37 | ) 38 | self.assertTrue( 39 | model._request(method='get', endpoint='model/labels'), dict 40 | ) 41 | self.assertTrue( 42 | model._request( 43 | method='post', 44 | endpoint='model/predict', 45 | json={'text': ['a test.', 'another test.']} 46 | ), dict 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /depiction/models/uri/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/models/uri/tests/__init__.py -------------------------------------------------------------------------------- /depiction/models/uri/uri_model.py: -------------------------------------------------------------------------------- 1 | """Abstract interface for 
URI models.""" 2 | from ..base.base_model import BaseModel 3 | 4 | 5 | class URIModel(BaseModel): 6 | """Abstract implementation of a URI model.""" 7 | 8 | def __init__(self, uri, task, data_type): 9 | """ 10 | Initialize a URIModel. 11 | 12 | Args: 13 | uri (str): URI to access the model. 14 | task (depiction.core.Task): task type. 15 | data_type (depiction.core.DataType): data type. 16 | """ 17 | super().__init__(task=task, data_type=data_type) 18 | self.uri = uri 19 | -------------------------------------------------------------------------------- /depiction/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/depiction/tests/__init__.py -------------------------------------------------------------------------------- /depiction/tests/core_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from depiction.core import Task 4 | 5 | 6 | class TaskTestCase(unittest.TestCase): 7 | 8 | def testLtOperator(self): 9 | self.assertLess(Task.BINARY, Task.CLASSIFICATION) 10 | self.assertLess(Task.MULTICLASS, Task.CLASSIFICATION) 11 | 12 | self.assertFalse(Task.CLASSIFICATION < Task.BINARY) 13 | self.assertFalse(Task.CLASSIFICATION < Task.MULTICLASS) 14 | 15 | self.assertFalse(Task.CLASSIFICATION < Task.REGRESSION) 16 | self.assertFalse(Task.REGRESSION < Task.CLASSIFICATION) 17 | 18 | self.assertFalse(Task.BINARY < Task.BINARY) 19 | 20 | def testLeOperator(self): 21 | self.assertLessEqual(Task.BINARY, Task.CLASSIFICATION) 22 | self.assertLessEqual(Task.MULTICLASS, Task.CLASSIFICATION) 23 | self.assertLessEqual(Task.BINARY, Task.BINARY) 24 | 25 | self.assertFalse(Task.CLASSIFICATION <= Task.REGRESSION) 26 | self.assertFalse(Task.REGRESSION <= Task.CLASSIFICATION) 27 | 28 | def testGtOperator(self): 29 | self.assertGreater(Task.CLASSIFICATION, Task.BINARY) 30 | 
self.assertGreater(Task.CLASSIFICATION, Task.MULTICLASS) 31 | 32 | self.assertFalse(Task.BINARY > Task.CLASSIFICATION) 33 | self.assertFalse(Task.MULTICLASS > Task.CLASSIFICATION) 34 | 35 | self.assertFalse(Task.CLASSIFICATION > Task.REGRESSION) 36 | self.assertFalse(Task.REGRESSION > Task.CLASSIFICATION) 37 | 38 | self.assertFalse(Task.BINARY > Task.BINARY) 39 | 40 | def testGeOperator(self): 41 | self.assertGreaterEqual(Task.CLASSIFICATION, Task.BINARY) 42 | self.assertGreaterEqual(Task.CLASSIFICATION, Task.MULTICLASS) 43 | self.assertGreaterEqual(Task.BINARY, Task.BINARY) 44 | 45 | self.assertFalse(Task.CLASSIFICATION >= Task.REGRESSION) 46 | self.assertFalse(Task.REGRESSION >= Task.CLASSIFICATION) 47 | 48 | def testCheckSupport(self): 49 | supported = [Task.CLASSIFICATION, Task.MULTICLASS] 50 | 51 | self.assertTrue(Task.check_support(Task.BINARY, supported)) 52 | self.assertTrue(Task.check_support(Task.CLASSIFICATION, supported)) 53 | self.assertFalse(Task.check_support(Task.REGRESSION, supported)) 54 | 55 | 56 | if __name__ == "__main__": 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3 2 | # labels 3 | LABEL maintainer="Matteo Manica , An-phi Nguyen , Joris Cadow " 4 | # needed settings 5 | ENV PATH /opt/conda/bin:$PATH 6 | ENV LANG C 7 | # install system dependencies 8 | RUN apt-get update \ 9 | && apt-get install -y --no-install-recommends gcc g++ python3-dev libxrender-dev\ 10 | && rm -rf /var/lib/apt/lists/* 11 | # install rdkit 12 | RUN conda config --add channels https://conda.anaconda.org/rdkit 13 | RUN conda install -y rdkit==2019.03.1 python=3.7 14 | # install pip dependencies 15 | WORKDIR /build 16 | COPY requirements.txt /build/ 17 | RUN pip install --no-cache-dir -r requirements.txt 18 | # install depiction 19 | WORKDIR /build/depiction 20 | COPY depiction 
/build/depiction 21 | COPY bin /build/bin 22 | COPY setup.py /build/ 23 | RUN pip install --no-cache-dir /build 24 | # install jupyter 25 | RUN pip install --no-cache-dir jupyter==1.0.0 26 | # setup data for tests 27 | WORKDIR /build/data 28 | COPY data /build/data 29 | # setup the workspace 30 | WORKDIR /workspace 31 | COPY data /workspace/data 32 | COPY notebooks /workspace/notebooks 33 | # expose the right port 34 | EXPOSE 8888 35 | # setup the entrypoint 36 | COPY docker/docker-entrypoint.sh /usr/local/bin/ 37 | RUN chmod a+x /usr/local/bin/docker-entrypoint.sh 38 | # entrypoint to startup the notebook 39 | CMD ["docker-entrypoint.sh"] 40 | -------------------------------------------------------------------------------- /docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.2" 2 | services: 3 | depiction: 4 | build: 5 | context: .. 6 | dockerfile: docker/Dockerfile 7 | container_name: depiction 8 | image: drugilsberg/depiction 9 | ports: 10 | - "8899:8888" 11 | environment: 12 | - JUPYTER_TOKEN=depiction-token 13 | -------------------------------------------------------------------------------- /docker/docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "Start jupyter-notebook on port 8888" 4 | 5 | jupyter notebook --ip="0.0.0.0" --port=8888 \ 6 | --no-browser --allow-root \ 7 | --NotebookApp.token="${JUPYTER_TOKEN}" --notebook-dir="/workspace" 8 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: depiction-env 2 | channels: 3 | - https://conda.anaconda.org/rdkit 4 | dependencies: 5 | - python>=3.6 6 | - pip 7 | - pip: 8 | - spacy==2.1.4 9 | - pandas==0.24.2 10 | - scikit-learn==0.21.2 11 | - tensorflow==1.14.0 12 | - torch>=1.0.0 13 | - torchvision>=0.2.1 14 | - seaborn==0.9.0 
15 | - lime==0.1.1.32 16 | - requests==2.22.0 17 | - minio==5.0.4 18 | - imageio==2.6.1 19 | - aix360==0.1.0 20 | - alibi==0.3.2 21 | - kipoi==0.6.24 22 | - concise==0.6.8 23 | - rdkit-pypi==2021.03.2 24 | - anchor_custom @ git+https://github.com/phineasng/anchor@6419bbdfb46b3f1c0d7738e521623856b92bf5af#egg=anchor_custom-0.0.0.5 25 | - paccmann @ git+https://github.com/drugilsberg/paccmann@77eaedb860b66c32d541398c32caa34e93b7c443#egg=paccmann-0.1 26 | -------------------------------------------------------------------------------- /examples/cem_mnist.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Contrastive Explanations Method (CEM) applied to MNIST 3 | 4 | # %% [markdown] 5 | # The Contrastive Explanation Method (CEM) can generate black box model explanations in terms of pertinent positives (PP) and pertinent negatives (PN). For PP, it finds what should be minimally and sufficiently present (e.g. important pixels in an image) to justify its classification. PN on the other hand identify what should be minimally and necessarily absent from the explained instance in order to maintain the original prediction. 6 | # 7 | # The original paper where the algorithm is based on can be found on [arXiv](https://arxiv.org/pdf/1802.07623.pdf). 8 | # Depiction wraps an implementation by the alibi package and follows their [example](https://docs.seldon.io/projects/alibi/en/stable/examples/cem_mnist.html) heavily. 
9 | # %% 10 | 11 | import tempfile 12 | import random 13 | # import pandas as pd 14 | import numpy as np 15 | 16 | import matplotlib 17 | from IPython import get_ipython 18 | # get_ipython().run_line_magic('matplotlib', 'inline') 19 | import matplotlib.pyplot as plt 20 | import seaborn as sns 21 | # import ipywidgets as widgets 22 | # from ipywidgets import interact, interact_manual 23 | import tensorflow as tf 24 | tf.logging.set_verbosity(tf.logging.ERROR) # suppress deprecation messages 25 | import tensorflow.keras as keras 26 | from tensorflow.keras import backend as K 27 | from tensorflow.keras.models import load_model 28 | from tensorflow.keras.utils import to_categorical 29 | 30 | from depiction.models.base.utils import get_model_file 31 | from depiction.models.uri.cache.http_model import HTTPModel 32 | from depiction.core import Task, DataType 33 | from depiction.interpreters.alibi.contrastive.cem import CEM 34 | 35 | # %% [markdown] 36 | # ## Load MNIST data 37 | 38 | # %% 39 | (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() 40 | print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape) 41 | plt.gray() 42 | # plt.imshow(x_test[4]); 43 | 44 | # %% [markdown] 45 | # Models were trained as shown in [alibis example](https://docs.seldon.io/projects/alibi/en/stable/examples/cem_mnist.html) 46 | # which we follow but using depiction with the benefit of using its other interpreters 47 | # %% 48 | 49 | 50 | class MNISTClassifier(HTTPModel): 51 | 52 | def __init__( 53 | self, 54 | filename='mninst_cnn.h5', 55 | origin='https://ibm.box.com/shared/static/k1x70cmr01fahob5ub7y2r82jqv3r75b.h5', # noqa 56 | cache_dir=tempfile.mkdtemp() 57 | ): 58 | """Initialize the CellTyper.""" 59 | super().__init__( 60 | uri=origin, 61 | task=Task.CLASSIFICATION, 62 | data_type=DataType.IMAGE, 63 | cache_dir=cache_dir, 64 | filename=filename 65 | ) 66 | self.model = keras.models.load_model(self.model_path) 67 | 68 | def predict(self, sample): 69 | 
return self.model.predict(sample) 70 | 71 | 72 | # %% 73 | # A model from depiction for interpretation 74 | cnn = MNISTClassifier() 75 | cnn.model.summary() 76 | # %% 77 | # CEM accepts an optional keras autoencoder to find better a pertinent 78 | # negative/positive 79 | 80 | ae = keras.models.load_model( 81 | get_model_file( 82 | filename='mninst_ae.h5', 83 | origin= 84 | 'https://ibm.box.com/shared/static/psogbwnx1cz0s8w6z2fdswj25yd7icpi.h5', # noqa 85 | cache_dir=cnn.cache_dir 86 | ) 87 | ) 88 | 89 | ae.summary() 90 | 91 | # %% [markdown] 92 | # The models were trained and expext processed data 93 | # so we scale, reshape and categorize 94 | 95 | # %% 96 | 97 | 98 | def transform(x): 99 | """Move to -0.5, 0.5 range and add channel dimension.""" 100 | return np.expand_dims(x.astype('float32') / 255 - 0.5, axis=-1) 101 | 102 | 103 | def transform_sample(x): 104 | return np.expand_dims(transform(x), axis=0) 105 | 106 | 107 | def inverse_transform(data): 108 | return (data.squeeze() + 0.5) * 255 109 | 110 | 111 | def show_image(x): 112 | return plt.imshow(x.squeeze()) 113 | 114 | 115 | # %% [markdown] 116 | # Compare original with decoded images 117 | 118 | # %% 119 | score = cnn.model.evaluate( 120 | transform(x_test), to_categorical(y_test), verbose=0 121 | ) 122 | print('Test accuracy: ', score[1]) 123 | 124 | # %% ------------------------------------------ 125 | 126 | decoded_imgs = ae.predict(transform(x_test)) 127 | n = 5 128 | plt.figure(figsize=(20, 4)) 129 | for i in range(1, n + 1): 130 | # display original 131 | ax = plt.subplot(2, n, i) 132 | # show_image(transform(x_test[i])) 133 | ax.get_xaxis().set_visible(False) 134 | ax.get_yaxis().set_visible(False) 135 | # display reconstruction 136 | ax = plt.subplot(2, n, i + n) 137 | # show_image(transform(decoded_imgs[i])) 138 | ax.get_xaxis().set_visible(False) 139 | ax.get_yaxis().set_visible(False) 140 | 141 | # plt.show() 142 | 143 | # %% [markdown] 144 | # ## Generate contrastive explanation with 
pertinent negative 145 | # %% [markdown] 146 | # Explained instance: 147 | 148 | # %% 149 | idx = 15 150 | X = transform_sample(x_test[idx]) 151 | 152 | # %% 153 | # show_image(X) 154 | 155 | # %% [markdown] 156 | # Model prediction: 157 | 158 | # %% 159 | cnn.predict(X).argmax(), cnn.predict(X).max() 160 | 161 | # %% [markdown] 162 | # CEM parameters: 163 | 164 | # %% 165 | mode = 'PN' # 'PN' (pertinent negative) or 'PP' (pertinent positive) 166 | shape = X.shape # instance shape, batchsize must be 1 167 | assert shape[0] == 1 168 | kappa = 0. # minimum difference needed between the prediction probability for the perturbed instance on the 169 | # class predicted by the original instance and the max probability on the other classes 170 | # in order for the first loss term to be minimized 171 | beta = .1 # weight of the L1 loss term 172 | gamma = 100 # weight of the optional auto-encoder loss term 173 | c_init = 1. # initial weight c of the loss term encouraging to predict a different class (PN) or 174 | # the same class (PP) for the perturbed instance compared to the original instance to be explained 175 | c_steps = 10 # nb of updates for c 176 | max_iterations = 1000 # nb of iterations per value of c 177 | feature_range = ( 178 | x_train.min(), x_train.max() 179 | ) # feature range for the perturbed instance 180 | clip = (-1000., 1000.) # gradient clipping 181 | lr = 1e-2 # initial learning rate 182 | no_info_val = -1. 
# a value, float or feature-wise, which can be seen as containing no info to make a prediction 183 | # perturbations towards this value means removing features, and away means adding features 184 | # for our MNIST images, the background (-0.5) is the least informative, 185 | # so positive/negative perturbations imply adding/removing features 186 | 187 | # %% [markdown] 188 | # Generate pertinent negative: 189 | 190 | # %% 191 | # initialize CEM explainer and explain instance 192 | cem = CEM( 193 | cnn, 194 | mode, 195 | shape, 196 | kappa=kappa, 197 | beta=beta, 198 | feature_range=feature_range, 199 | gamma=gamma, 200 | ae_model=ae, 201 | max_iterations=max_iterations, 202 | c_init=c_init, 203 | c_steps=c_steps, 204 | learning_rate_init=lr, 205 | clip=clip, 206 | no_info_val=no_info_val 207 | ) 208 | 209 | explanation = cem.interpret(X, verbose=True) 210 | 211 | # %% [markdown] 212 | # Pertinent negative: 213 | 214 | # %% 215 | print('Pertinent negative prediction: {}'.format(explanation[mode + '_pred'])) 216 | # show_image(explanation[mode]); 217 | 218 | # %% [markdown] 219 | # ## Generate pertinent positive 220 | 221 | # %% 222 | mode = 'PP' 223 | 224 | # %% 225 | # initialize CEM explainer and explain instance 226 | cem = CEM( 227 | cnn, 228 | mode, 229 | shape, 230 | kappa=kappa, 231 | beta=beta, 232 | feature_range=feature_range, 233 | gamma=gamma, 234 | ae_model=ae, 235 | max_iterations=max_iterations, 236 | c_init=c_init, 237 | c_steps=c_steps, 238 | learning_rate_init=lr, 239 | clip=clip, 240 | no_info_val=no_info_val 241 | ) 242 | 243 | explanation = cem.interpret(X, verbose=True) 244 | 245 | # %% [markdown] 246 | # Pertinent positive: 247 | 248 | # %% 249 | print('Pertinent positive prediction: {}'.format(explanation[mode + '_pred'])) 250 | # show_image(explanation[mode]); 251 | 252 | # %% 253 | # to delete the downloaded files before your next reboot 254 | # import os 255 | # os.remove(cnn.model_path) 256 | # os.remove(os.path.join(cnn.cache_dir, 
'mninst_ae.h5')) 257 | -------------------------------------------------------------------------------- /examples/uwashers_imagenet.py: -------------------------------------------------------------------------------- 1 | """UWahsers for images.""" 2 | # %% 3 | import json 4 | import numpy as np 5 | import keras_applications 6 | from tensorflow import keras 7 | 8 | from depiction.core import DataType, Task 9 | from depiction.models.keras import KerasApplicationModel 10 | from depiction.interpreters.u_wash import UWasher 11 | 12 | 13 | # %% 14 | # general utils 15 | def image_preprocessing(image_path, preprocess_input, target_size): 16 | """ 17 | Read and preprocess an image from disk. 18 | 19 | Args: 20 | image_path (str): path to the image. 21 | preprocess_input (funciton): a preprocessing function. 22 | target_size (tuple): image target size. 23 | 24 | Returns: 25 | np.ndarray: the preprocessed image. 26 | """ 27 | image = keras.preprocessing.image.load_img( 28 | image_path, target_size=target_size 29 | ) 30 | x = keras.preprocessing.image.img_to_array(image) 31 | x = np.expand_dims(x, axis=0) 32 | return preprocess_input(x) 33 | 34 | 35 | def get_imagenet_labels(): 36 | """ 37 | Get ImamgeNet labels. 38 | 39 | Returns: 40 | list: list of labels. 
41 | """ 42 | labels_filepath = keras.utils.get_file( 43 | 'imagenet_class_index.json', 44 | keras_applications.imagenet_utils.CLASS_INDEX_PATH 45 | ) 46 | with open(labels_filepath) as fp: 47 | labels_json = json.load(fp) 48 | labels = [None] * len(labels_json) 49 | for index, (_, label) in labels_json.items(): 50 | labels[int(index)] = label 51 | return labels 52 | 53 | 54 | # %% 55 | labels = get_imagenet_labels() 56 | 57 | # %% 58 | # instantiate the model 59 | model = KerasApplicationModel( 60 | keras.applications.MobileNetV2(), Task.CLASSIFICATION, DataType.IMAGE 61 | ) 62 | 63 | #%% 64 | image_path = keras.utils.get_file( 65 | 'elephant.jpg', 66 | 'https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/Zoorashia_elephant.jpg/120px-Zoorashia_elephant.jpg' # noqa 67 | ) 68 | image = image_preprocessing( 69 | image_path, 70 | keras.applications.mobilenet_v2.preprocess_input, 71 | target_size=(224, 224) 72 | ) 73 | 74 | # LIME 75 | # %% 76 | interpreter = UWasher('lime', model, class_names=labels) 77 | 78 | # %% 79 | explanation = interpreter.interpret(image.squeeze()) 80 | 81 | # Anchors 82 | # %% 83 | interpreter = UWasher('anchors', model) 84 | 85 | # %% 86 | explanation = interpreter.interpret(image.squeeze()) 87 | -------------------------------------------------------------------------------- /notebooks/celltype_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Training of a super simple model for celltype classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import tensorflow as tf\n", 17 | "!which python\n", 18 | "!python --version\n", 19 | "print(tf.VERSION)\n", 20 | "print(tf.keras.__version__)\n", 21 | "!pwd # start jupyter under notebooks/ for correct relative paths" 22 | ] 23 | }, 24 | { 25 | 
"cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import datetime\n", 31 | "import inspect\n", 32 | "import pandas as pd\n", 33 | "import numpy as np\n", 34 | "import seaborn as sns\n", 35 | "from tensorflow.keras import layers\n", 36 | "from tensorflow.keras.utils import to_categorical\n", 37 | "from sklearn.model_selection import train_test_split\n", 38 | "from sklearn.preprocessing import MinMaxScaler\n", 39 | "from depiction.models.examples.celltype.celltype import one_hot_encoding, one_hot_decoding" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## a look at the data\n", 47 | "labels are categories 1-20, here's the associated celltype:" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "meta_series = pd.read_csv('../data/single-cell/metadata.csv', index_col=0)\n", 57 | "meta_series" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "There are 13 unbalanced classes, and over 80k samples" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "data_df = pd.read_csv('../data/single-cell/data.csv')\n", 74 | "data_df.groupby('category').count()['CD45']" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "data_df.sample(n=10)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "print(inspect.getsource(one_hot_encoding)) # from keras, but taking care of 1 indexed classes\n", 93 | "print(inspect.getsource(one_hot_decoding))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 
102 | "classes = data_df['category'].values\n", 103 | "labels = one_hot_encoding(classes)\n", 104 | "\n", 105 | "#scale the data from 0 to 1\n", 106 | "min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)\n", 107 | "data = min_max_scaler.fit_transform(data_df.drop('category', axis=1).values)\n", 108 | "data.shape" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "one_hot_decoding(labels)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "data_train, data_test, labels_train, labels_test = train_test_split(\n", 127 | " data, labels, test_size=0.33, random_state=42, stratify=data_df.category)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "labels" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "batchsize = 32" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "dataset = tf.data.Dataset.from_tensor_slices((data_train, labels_train))\n", 155 | "dataset = dataset.shuffle(2 * batchsize).batch(batchsize)\n", 156 | "dataset = dataset.repeat()\n", 157 | "\n", 158 | "testset = tf.data.Dataset.from_tensor_slices((data_test, labels_test))\n", 159 | "testset = testset.batch(batchsize)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## I don't know how a simpler network would look like" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "scrolled": false 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "model = tf.keras.Sequential()\n", 178 | "# Add a 
softmax layer with output units per celltype:\n", 179 | "model.add(layers.Dense(\n", 180 | " len(meta_series), activation='softmax',\n", 181 | " batch_input_shape=tf.data.get_output_shapes(dataset)[0]\n", 182 | "))" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "model.summary()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "model.compile(optimizer=tf.keras.optimizers.Adam(0.001),\n", 201 | " loss='categorical_crossentropy',\n", 202 | " metrics=[tf.keras.metrics.categorical_accuracy])" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "scrolled": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "# evaluation on testset on every epoch\n", 214 | "# log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", 215 | "# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", 216 | "model.fit(\n", 217 | " dataset,\n", 218 | " epochs=20, steps_per_epoch=np.ceil(data_train.shape[0]/batchsize),\n", 219 | " validation_data=testset, # callbacks=[tensorboard_callback]\n", 220 | ")" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "## Is such a simple model interpretable?" 
228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# Save entire model to a HDF5 file\n", 237 | "model.save('./celltype_model.h5')" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "# tensorboard --logdir logs/fit" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# To recreate the exact same model, including weights and optimizer.\n", 256 | "# model = tf.keras.models.load_model('../data/models/celltype_dnn_model.h5')" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "# What is the effect of increasing model complexity? \n", 264 | "Play around by adding some layers, train and save the model under some name to use with the other notebook." 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": {}, 270 | "source": [ 271 | "![title](https://i.kym-cdn.com/photos/images/newsfeed/000/531/557/a88.jpg)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "model = tf.keras.Sequential()\n", 281 | "# Adds a densely-connected layers with 64 units to the model:\n", 282 | "model.add(layers.Dense(64, activation='relu', batch_input_shape=tf.data.get_output_shapes(dataset)[0])) # \n", 283 | "# ...\n", 284 | "# do whatever you want\n", 285 | "# model.add(layers.Dense(64, activation='relu'))\n", 286 | "# model.add(layers.Dropout(0.5))\n", 287 | "# ...\n", 288 | "# Add a softmax layer with output units per celltype:\n", 289 | "model.add(layers.Dense(len(meta_series), activation='softmax'))" 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3", 296 | "language": "python", 297 
| "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.6.5" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 2 314 | } 315 | -------------------------------------------------------------------------------- /notebooks/deepbind.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Having fun with DeepBind" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import warnings; warnings.filterwarnings('ignore', category=FutureWarning)\n", 17 | "import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR) # suppress deprecation messages\n", 18 | "from depiction.models.examples.deepbind.deepbind import DeepBind, create_DNA_language\n", 19 | "from depiction.interpreters.u_wash.u_washer import UWasher\n", 20 | "from ipywidgets import interact" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Setup task" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "class_names = ['NOT BINDING', 'BINDING']\n", 37 | "classifier = DeepBind(model='DeepBind/Homo_sapiens/TF/D00328.003_SELEX_CTCF', min_length=35)\n", 38 | "# this class has task (classification) and data_type (text) and some processing defined for your convenience" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "http://kipoi.org/models/DeepBind/Homo_sapiens/TF/D00328.003_SELEX_CTCF/" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | 
"metadata": {}, 51 | "source": [ 52 | "# Interpreter parametrization" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# a LIME text interpreter\n", 62 | "lime_explanation_configs = {\n", 63 | " 'labels': (1,),\n", 64 | "}\n", 65 | "lime_params = {\n", 66 | " 'class_names': class_names,\n", 67 | " 'split_expression': list,\n", 68 | " 'bow': False,\n", 69 | " 'char_level': True\n", 70 | "}\n", 71 | "\n", 72 | "# an Anchor text intepreter\n", 73 | "anchors_explanation_configs = {\n", 74 | " 'use_proba': False,\n", 75 | " 'batch_size': 100\n", 76 | "}\n", 77 | "anchors_params = {\n", 78 | " 'class_names': class_names,\n", 79 | " 'nlp': create_DNA_language(),\n", 80 | " 'unk_token': 'N',\n", 81 | " 'sep_token': '',\n", 82 | " 'use_unk_distribution': True\n", 83 | "}" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### Wrapper for the interactive widget" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "class InteractiveWrapper:\n", 100 | " def __init__(self, classifier):\n", 101 | " self.classifier = classifier\n", 102 | " self.lime_explainer = UWasher(\"lime\", self.classifier, **lime_params)\n", 103 | " self.anchor_explainer = UWasher(\"anchors\", self.classifier, **anchors_params)\n", 104 | "\n", 105 | " def callback(self, sequence):\n", 106 | " # LIME\n", 107 | " self.classifier.use_labels = False\n", 108 | " self.lime_explainer.interpret(sequence, explanation_configs=lime_explanation_configs)\n", 109 | " # Anchors \n", 110 | " self.classifier.use_labels = True\n", 111 | " self.anchor_explainer.interpret(sequence, explanation_configs=anchors_explanation_configs)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "# Let's interpret" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 
123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "wrapper = InteractiveWrapper(classifier)\n", 128 | "\n", 129 | "interact(\n", 130 | " wrapper.callback,\n", 131 | " sequence=[\n", 132 | " 'AGGCTAGCTAGGGGCGCCC', 'AGGCTAGCTAGGGGCGCTT', 'AGGGTAGCTAGGGGCGCTT',\n", 133 | " 'AGGGTAGCTGGGGGCGCTT', 'AGGCTAGGTGGGGGCGCTT', 'AGGCTCGGTGGGGGCGCTT',\n", 134 | " 'AGGCTCGGTAGGGGGCGATT'\n", 135 | " ]\n", 136 | ")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "CTCF binding motif\n", 144 | "![CTCF binding motif](https://media.springernature.com/full/springer-static/image/art%3A10.1186%2Fgb-2009-10-11-r131/MediaObjects/13059_2009_Article_2281_Fig2_HTML.jpg?as=webp)\n", 145 | "from Essien, Kobby, et al. \"CTCF binding site classes exhibit distinct evolutionary, genomic, epigenomic and transcriptomic features.\" Genome biology 10.11 (2009): R131." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "classifier_foxa1 = DeepBind('DeepBind/Homo_sapiens/TF/D00761.001_ChIP-seq_FOXA1', min_length=40)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "wrapper_foxa1 = InteractiveWrapper(classifier_foxa1)\n", 164 | "\n", 165 | "interact(wrapper_foxa1.callback, sequence='TGTGTGTGTG')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "FOXA1 binding motif\n", 173 | "![FOXA1 binding motif](https://ismara.unibas.ch/supp/dataset1_IBM_v2/ismara_report/logos/FOXA1.png)\n", 174 | "from https://ismara.unibas.ch/supp/dataset1_IBM_v2/ismara_report/pages/FOXA1.html" 175 | ] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | 
"codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.7.3" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 2 199 | } 200 | -------------------------------------------------------------------------------- /notebooks/kaggle_create_a_new_api_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/notebooks/kaggle_create_a_new_api_token.png -------------------------------------------------------------------------------- /notebooks/kaggle_go_to_your_account.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/notebooks/kaggle_go_to_your_account.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | spacy>=2.2.0 2 | pandas==1.0.5 3 | scikit-learn==0.21.2 4 | tensorflow==1.14.0 5 | torch>=1.2.0 6 | torchvision>=0.4.2 7 | seaborn==0.9.0 8 | lime==0.1.1.32 9 | requests==2.22.0 10 | minio==5.0.4 11 | imageio==2.6.1 12 | aix360==0.1.0 13 | alibi==0.3.2 14 | kipoi==0.6.24 15 | concise==0.6.8 16 | captum==0.1.0 17 | kaggle==1.5.6 18 | wget==3.2 19 | Pillow==8.1.1 20 | h5py<3.0.0 21 | rdkit-pypi==2021.03.2 22 | anchor_custom @ git+https://github.com/phineasng/anchor@6419bbdfb46b3f1c0d7738e521623856b92bf5af#egg=anchor_custom-0.0.0.5 23 | paccmann @ git+https://github.com/drugilsberg/paccmann@77eaedb860b66c32d541398c32caa34e93b7c443#egg=paccmann-0.1 24 | deepexplain @ git+https://github.com/marcoancona/DeepExplain.git#egg=deepexplain 25 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Package installer.""" 2 | import os 3 | from setuptools import setup 4 | from setuptools import find_packages 5 | 6 | LONG_DESCRIPTION = '' 7 | if os.path.exists('README.md'): 8 | with open('README.md') as fp: 9 | LONG_DESCRIPTION = fp.read() 10 | 11 | REQUIREMENTS = [] 12 | if os.path.exists('requirements.txt'): 13 | with open('requirements.txt') as fp: 14 | REQUIREMENTS = [ 15 | line.strip() 16 | for line in fp 17 | ] 18 | 19 | setup( 20 | name='depiction', 21 | version='0.0.1', 22 | description='DEPICTION, a package for deep learning interpretability.', 23 | long_description=LONG_DESCRIPTION, 24 | long_description_content_type='text/markdown', 25 | author='Matteo Manica, An-phi Nguyen, Joris Cadow', 26 | author_email=( 27 | 'drugilsberg@gmail.com, nguyen.phineas@gmail.com, joriscadow@gmail.com' 28 | ), 29 | url='https://github.com/IBM/dl-interpretability-compbio', 30 | license='Apache License 2.0', 31 | install_requires=REQUIREMENTS, 32 | classifiers=[ 33 | 'Intended Audience :: Developers', 34 | 'Intended Audience :: Science/Research', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Programming Language :: Python :: 3', 37 | 'Programming Language :: Python :: 3.6', 38 | 'Programming Language :: Python :: 3.7', 39 | 'Topic :: Software Development :: Libraries :: Python Modules' 40 | ], 41 | packages=find_packages(), 42 | scripts=['bin/depiction-models-download'] 43 | ) 44 | -------------------------------------------------------------------------------- /workshops/20190909_BC2/README.md: -------------------------------------------------------------------------------- 1 | 2 | # T4: Interpretability for deep learning models in computational biology 3 | 4 | ## Location 5 | 6 | ### Venue 7 | 8 | [University of Basel](https://www.unibas.ch/de), Kollegienhaus, Petersplatz 1, CH-4001, 
Basel 9 | 10 | ### Room 11 | 12 | Hörsaal 117 13 | 14 | ### Map 15 | 16 | [How to get there](https://www.google.ch/maps/place/Petersplatz+1,+4051+Basel/@47.5584029,7.5825258,17.67z/data=!4m13!1m7!3m6!1s0x4791b9a96c44bba1:0xe0a7bc8b66787bdb!2sPetersplatz+1,+4051+Basel!3b1!8m2!3d47.5586129!4d7.5827926!3m4!1s0x4791b9a96c44bba1:0xe0a7bc8b66787bdb!8m2!3d47.5586129!4d7.5827926) 17 | 18 | ## Requirements 19 | 20 | This course is designed for everyone who would like to learn the basics of interpretability techniques for deep learning. The tutorial will provide a brief introduction to key concepts in deep learning, before exploring recent developments in the field of interpretability. **Participants who want to participate in the hands-on exercises should bring a laptop and follow the steps to install the docker image for `depiction` prior the day of the tutorial. Instructions can be found [here](https://github.com/IBM/dl-interpretability-compbio/blob/master/README.md).** 21 | 22 | ## Organisers and tutors 23 | 24 | - María Rodríguez Martínez, IBM Research Zürich 25 | - Joris Cadow, IBM Research Zürich 26 | - An-Phi Nguyen, IBM Research Zürich 27 | 28 | ## Schedule 29 | 30 | | Time | Title | Speaker | 31 | |-------------|--------------------------------------|--------------------------| 32 | | 09:00-10:00 | Introduction to deep learning | María Rodríguez Martínez | 33 | | 10:00-10:30 | Model Zoo I | Joris Cadow | 34 | | 10:30-10:45 | Coffee break | N/A | 35 | | 10:45-11:15 | Model Zoo II | Joris Cadow | 36 | | 11:15-12:00 | Interpretability in deep learning | An-phi Nguyen | 37 | | 12:00-13:00 | Lunch | N/A | 38 | | 13:00-13:30 | Introduction to depiction | An-phi Nguyen | 39 | | 13:30-14:00 | Interpret DeepBind (genomics) | An-phi Nguyen | 40 | | 14:00-14:45 | depiction DIY I | Joris Cadow | 41 | | 14:45-15:00 | Break | N/A | 42 | | 15:00-15:30 | depiction DIY II | Joris Cadow | 43 | | 15:30-16:00 | Interpret PaccMann (drug sensitvity) | An-phi Nguyen | 44 | 
-------------------------------------------------------------------------------- /workshops/20191120_ODSC2019/README.md: -------------------------------------------------------------------------------- 1 | # Opening The Black Box — Interpretability In Deep Learning 2 | 3 | ## Schedule 4 | 5 | | Time | Title | Speaker | 6 | |-------------|--------------------------|---------------| 7 | | 14:00-14:45 | What is interpretability | Matteo Manica | 8 | | 14:45-15:30 | Introducing depiction | Matteo Manica | 9 | | 15:30-16:00 | Coffee break | N/A | 10 | | 16:00-16:45 | Explaining images | Matteo Manica | 11 | | 16:45-17:30 | Explaining tables | Joris Cadow | 12 | 13 | ## Slides 14 | 15 | The slides can be found on box: [https://ibm.box.com/v/odsc-2019-tutorial](https://ibm.box.com/v/odsc-2019-tutorial) 16 | 17 | ## Running the notebooks 18 | 19 | The notebooks can be run either with a conda environment or using docker. 20 | Either way, follow the general [README.md](../../README.md) and see the respective section below for details concerning the workshop. 21 | 22 | ### Conda 23 | 24 | Depending from where you start the `jupyter notebook` server you might have do minor adjustments to relative paths in notebooks from `workshops/20191120_ODSC2019/notebooks`. 25 | 26 | ### Docker 27 | 28 | With the docker setup we mount a different directory. 29 | ```docker run --mount src=`pwd`/workshops/20191120_ODSC2019/notebooks,target=/workspace/notebooks,type=bind -p 8899:8888 -it drugilsberg/depiction``` 30 | 31 | and start your browser at [http://localhost:8899/tree/notebooks](http://localhost:8899/tree/notebooks) 32 | -------------------------------------------------------------------------------- /workshops/20191120_ODSC2019/blog/README.md: -------------------------------------------------------------------------------- 1 | # Opening The Black Box — Interpretability In Deep Learning 2 | 3 | ## Why interpretability? 
4 | 5 | In the last decade, the application of deep neural networks to long-standing problems has brought a break-through in performance and prediction power. 6 | However, high accuracy, deriving from the increased model complexity, often comes at the price of loss of interpretability, i.e., many of these models behave as black-boxes and fail to provide explanations on their predictions. 7 | While in certain application fields this issue may play a secondary role, in high risk domains, e.g., health care, it is crucial to build trust in a model and being able to understand its behaviuor. 8 | 9 | ## What is interpretability? 10 | 11 | The definition of the verb *interpret* is "to explain or tell the meaning of : present in understandable terms" ([Merriam-Webster 2019](https://www.merriam-webster.com/dictionary/interpret)). 12 | Despite the apparent simplicity of this statement, the machine learning research community is struggling to agree upon a formal definition of the concept of interpretability/explainability. 13 | In the last years, in the room left by this lack of formalism, many methodologies have been proposed based on different "interpretations" (pun intended) of the above defintion. 14 | While the proliferation of this multitude of disparate algorithms has posed challenges on rigorously comparing them, it is nevertheless interesting and useful to apply these techniques to analyze the behaviour of deep learning models. 15 | 16 | ## What is this tutorial about? 17 | 18 | This tutorial focuses on illustrating some of the recent advancements in the field of interpretable deep learning. 19 | We will show common techniques that can be used to explain predictions on pretrained models and that can be used to shed light on their inner mechanisms. 20 | The tutorial is aimed to strike the right balance between theoretical input and practical exercises. 
21 | The session has been designed to provide the participants not only with the theory behind deep learning interpretability, but also to offer a set of frameworks and tools that they can easily reuse in their own projects. 22 | 23 | ### depiction: a framework for explanability 24 | 25 | The group of Cognitive Health Care and Life Sciences at IBM Research Zürich has opensourced a python toolbox, [depiction](https://github.com/IBM/dl-interpretability-compbio), with the aim of providing a framework to ease the application of explainability methods on custom models, especially for less experienced users. 26 | The module provide wrappers for multiple algorithms and is continously updated including the latest algorithms from [AIX360](https://github.com/IBM/AIX360.git). 27 | The core concept behind depiction is to allow users to seamlessly run state-of-art interpretability methods with minimal requirements in terms of programming skills. 28 | Below an example of how depiction can be used to analyze a pretrained model. 29 | 30 | ### A simple example 31 | 32 | Let's assume to have a fancy model for classification of tabular data pretrained in Keras and avaialble at a public url. 33 | Explaining its predictions with `depiction` is easy as implementing a lightweight wrapper of `depiction.models.uri.HTTPModel` where its `predict` method is overloaded. 
34 | 35 | ```python 36 | from depiction.core import Task, DataType 37 | from depiction.models.uri import HTTPModel 38 | 39 | 40 | class FancyModel(HTTPModel): 41 | """A fancy classifier.""" 42 | 43 | 44 | def __init__(self, 45 | filename='fancy_model.h5', 46 | origin='https://url/to/my/fancy_model.h5', 47 | cache_dir='/path/to/cache/models', 48 | *args, **kwargs): 49 | """Initialize the FancyModel.""" 50 | super().__init__( 51 | uri=origin, 52 | task=Task.CLASSIFICATION, 53 | data_type=DataType.TABULAR, 54 | cache_dir=cache_dir, 55 | filename=filename 56 | ) 57 | self.model = keras.models.load_model(self.model_path) 58 | 59 | def predict(self, sample, *args, **kwargs): 60 | """ 61 | Run the fancy model for inference on a given sample and with the provided 62 | parameters. 63 | 64 | Args: 65 | sample (object): an input sample for the model. 66 | args (list): list of arguments. 67 | kwargs (dict): list of key-value arguments. 68 | 69 | Returns: 70 | a prediction for the model on the given sample. 71 | """ 72 | return self.model.predict( 73 | sample, 74 | batch_size=None, verbose=0, 75 | steps=None, callbacks=None 76 | ) 77 | ``` 78 | 79 | Once `FancyModel` is implemented, using any of the `depiction.interpreters` available in the library, is as easy as typing: 80 | 81 | ```python 82 | fancy_model = FancyModel() 83 | # NOTE: interpreters are implemented inheriting from 84 | # depiction.interpreters.base.base_interpreter.BaseInterpreter 85 | # and they share a common interface. 86 | explanations = interpreter.interpret(example) 87 | ``` 88 | 89 | The explanations generated depend on the specific interpreter used. 
For example, in the case of explanations generated using

106 | LIME example 107 |

108 | 109 | ## Want to know more? 110 | 111 | If you found this blog post interesting and you want to know more about interpretability and depiction, come and join us at the tutorial ["Opening The Black Box — Interpretability In Deep Learning"](https://odsc.com/training/portfolio/opening-the-black-box-interpretability-in-deep-learning/) at [ODSC2019](https://odsc.com/london/) next November 20th in London. 112 | -------------------------------------------------------------------------------- /workshops/20191120_ODSC2019/blog/lime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/workshops/20191120_ODSC2019/blog/lime.png -------------------------------------------------------------------------------- /workshops/20191120_ODSC2019/notebooks/imagenet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Interpreting predictions on ImageNet" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# general imports\n", 17 | "import warnings; warnings.filterwarnings(\"ignore\", category=FutureWarning)\n", 18 | "import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR) # suppress deprecation messages\n", 19 | "import os\n", 20 | "import json\n", 21 | "import numpy as np\n", 22 | "import keras_applications\n", 23 | "from tensorflow import keras\n", 24 | "from ipywidgets import interact\n", 25 | "from matplotlib import pyplot as plt\n", 26 | "\n", 27 | "from depiction.core import DataType, Task" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# plotting\n", 37 | "plt.rcParams['figure.figsize'] = [20, 10]" 38 | ] 39 | }, 40 | { 41 | 
"cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# general utils\n", 47 | "def image_preprocessing(image_path, preprocess_input, target_size):\n", 48 | " \"\"\"\n", 49 | " Read and preprocess an image from disk.\n", 50 | "\n", 51 | " Args:\n", 52 | " image_path (str): path to the image.\n", 53 | " preprocess_input (funciton): a preprocessing function.\n", 54 | " target_size (tuple): image target size.\n", 55 | "\n", 56 | " Returns:\n", 57 | " np.ndarray: the preprocessed image.\n", 58 | " \"\"\"\n", 59 | " image = keras.preprocessing.image.load_img(\n", 60 | " image_path, target_size=target_size\n", 61 | " )\n", 62 | " x = keras.preprocessing.image.img_to_array(image)\n", 63 | " x = np.expand_dims(x, axis=0)\n", 64 | " return preprocess_input(x)\n", 65 | "\n", 66 | "\n", 67 | "def get_imagenet_labels():\n", 68 | " \"\"\"\n", 69 | " Get ImamgeNet labels.\n", 70 | "\n", 71 | " Returns:\n", 72 | " list: list of labels.\n", 73 | " \"\"\"\n", 74 | " labels_filepath = keras.utils.get_file(\n", 75 | " 'imagenet_class_index.json',\n", 76 | " keras_applications.imagenet_utils.CLASS_INDEX_PATH\n", 77 | " )\n", 78 | " with open(labels_filepath) as fp:\n", 79 | " labels_json = json.load(fp)\n", 80 | " labels = [None] * len(labels_json)\n", 81 | " for index, (_, label) in labels_json.items():\n", 82 | " labels[int(index)] = label\n", 83 | " return labels\n", 84 | "\n", 85 | "\n", 86 | "def show_image(x, title=None):\n", 87 | " \"\"\"\n", 88 | " Show an image.\n", 89 | "\n", 90 | " Args:\n", 91 | " x (np.ndarray): a 4D-array representing a batch with a\n", 92 | " single image.\n", 93 | " title (str): optional title.\n", 94 | " \"\"\"\n", 95 | " axes_image = plt.imshow(x.squeeze())\n", 96 | " axes_image.axes.set_xticks([], [])\n", 97 | " axes_image.axes.set_yticks([], [])\n", 98 | " if title is not None:\n", 99 | " axes_image.axes.set_title(title)\n", 100 | " return axes_image" 101 | ] 102 | }, 103 | { 104 | 
"cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## Instantiate a model to intepret" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "from depiction.models.keras import KerasApplicationModel\n", 117 | "# instantiate the model\n", 118 | "model = KerasApplicationModel(\n", 119 | " keras.applications.MobileNetV2(), Task.CLASSIFICATION, DataType.IMAGE\n", 120 | ")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Get data" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# get labels\n", 137 | "labels = get_imagenet_labels()\n", 138 | "examples = {}\n", 139 | "for filename, url in [\n", 140 | " ('elephant.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/Zoorashia_elephant.jpg/120px-Zoorashia_elephant.jpg'),\n", 141 | " ('dog.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Welsh_Springer_Spaniel.jpg/400px-Welsh_Springer_Spaniel.jpg'),\n", 142 | " ('cat.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Six_weeks_old_cat_%28aka%29.jpg/400px-Six_weeks_old_cat_%28aka%29.jpg'),\n", 143 | " ('cat-and-dog.jpg.', 'https://upload.wikimedia.org/wikipedia/commons/9/97/Greyhound_and_cat.jpg'),\n", 144 | " ('plush.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/51/Plush_bunny_with_headphones.jpg/320px-Plush_bunny_with_headphones.jpg')\n", 145 | "]:\n", 146 | " filepath = keras.utils.get_file(filename, url)\n", 147 | " examples[filename.split('.')[0]] = image_preprocessing(\n", 148 | " filepath,\n", 149 | " keras.applications.mobilenet_v2.preprocess_input,\n", 150 | " target_size=(224, 224)\n", 151 | " )\n", 152 | "interact(lambda key: show_image(examples[key], title=f'{key}'), key=examples.keys());" 153 | ] 154 | }, 155 | { 156 | "cell_type": 
"code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# pick an example\n", 162 | "image = examples['elephant']" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## LIME" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "from depiction.interpreters.u_wash import UWasher\n", 179 | "\n", 180 | "interpreter = UWasher('lime', model, class_names=labels)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "lines_to_next_cell": 0, 188 | "scrolled": false 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "explanation = interpreter.interpret(image)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "## Anchors" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "from depiction.interpreters.u_wash import UWasher\n", 209 | "\n", 210 | "interpreter = UWasher('anchors', model)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "explanation = interpreter.interpret(image)" 220 | ] 221 | } 222 | ], 223 | "metadata": { 224 | "jupytext": { 225 | "cell_metadata_filter": "-all", 226 | "main_language": "python", 227 | "notebook_metadata_filter": "-all" 228 | }, 229 | "kernelspec": { 230 | "display_name": "Python 3", 231 | "language": "python", 232 | "name": "python3" 233 | }, 234 | "language_info": { 235 | "codemirror_mode": { 236 | "name": "ipython", 237 | "version": 3 238 | }, 239 | "file_extension": ".py", 240 | "mimetype": "text/x-python", 241 | "name": "python", 242 | "nbconvert_exporter": "python", 243 | "pygments_lexer": "ipython3", 244 | "version": "3.7.4" 245 | } 
246 | }, 247 | "nbformat": 4, 248 | "nbformat_minor": 2 249 | } 250 | -------------------------------------------------------------------------------- /workshops/20191121_PyPharma19/README.md: -------------------------------------------------------------------------------- 1 | # Tutorial on interpretability in machine learning for Computational Biology 2 | 3 | ## Location 4 | 5 | ### Venue 6 | 7 | [University of Basel](https://www.unibas.ch/de), Biozentrum, [Klingelbergstrasse 70](https://goo.gl/maps/51YhnLf5YLDBEjMY8), CH-4056 Basel, Switzerland 8 | 9 | 10 | ### Room 11 | 12 | Seminarraum 104 13 | 14 | ## Requirements 15 | 16 | This course is designed for everyone who would like to learn the basics of interpretability techniques for machine learning. The tutorial will provide a brief introduction to key concepts and recent developments in the field of interpretability. **Participants who want to participate in the hands-on exercises should bring a laptop.** 17 | 18 | ### Setup 19 | 20 | We provide 3 ways to follow the exercises. 21 | 22 | 1. _Docker_. Instructions can be found [here](https://github.com/IBM/dl-interpretability-compbio#docker-setup). 23 | 2. _Conda environment_. Instructions can be found [here](https://github.com/IBM/dl-interpretability-compbio#development-setup). 24 | 3. [Notebook](https://github.com/IBM/dl-interpretability-compbio/blob/master/workshops/20191121_PyPharma19/tutorial_colab.ipynb) that should work out-of-the-box on [Google Colab](https://colab.research.google.com/) (with some caveats). 25 | 26 | We personally suggest the first setup (docker). 27 | 28 | For both the _docker_ and the _conda_ setup, **we strongly suggest to setup your machine prior to the tutorial (ideally the day before)** following the linked instructions. Should you incur into any issue, please do not hesitate to contact [Jannis](mailto:jab@zurich.ibm.com) or [myself](mailto:uye@zurich.ibm.com). 
29 | 30 | ## Organisers and tutors 31 | 32 | - [An-Phi Nguyen](https://researcher.watson.ibm.com/researcher/view.php?person=zurich-UYE), IBM Research Zürich 33 | - [Jannis Born](https://researcher.watson.ibm.com/researcher/view.php?person=zurich-JAB), IBM Research Zürich 34 | 35 | ## (Tentative) Schedule 36 | 37 | | Time | Title | Speaker | 38 | |-------------|--------------------------------------|--------------------------| 39 | | 13:00-14:30 | Interpretability in Machine Learning for Computational Biology | An-phi Nguyen | 40 | | 14:30-15:00 | Break | N/A | 41 | | 15:00-15:15 | Introduction to depiction | An-phi Nguyen, Jannis Born | 42 | | 15:15-15:45 | CellTyper | An-phi Nguyen, Jannis Born | 43 | | 15:45-16:15 | Interpreting DeepBind | An-phi Nguyen | 44 | | 16:15-17:00 | Interpret PaccMann (drug sensitvity) | Jannis Born | -------------------------------------------------------------------------------- /workshops/20200125_AMLD2020/README.md: -------------------------------------------------------------------------------- 1 | # Interpretability in machine learning for Computational Biology 2 | 3 | ## Requirements 4 | 5 | This course is designed for everyone who would like to learn the basics of interpretability techniques for machine learning. The tutorial will provide a brief introduction to key concepts and recent developments in the field of interpretability. **Participants should bring a laptop to follow the hands-on exercises.** 6 | 7 | ### Setup 8 | 9 | We provide two main ways to follow the exercises. 10 | 11 | #### docker 12 | 13 | Assuming you have the image [`drugilsberg/depiction`](https://hub.docker.com/r/drugilsberg/depiction) up-to-date just run: 14 | 15 | ```console 16 | docker run --mount src=`pwd`/workshops/20200125_AMLD2020/notebooks,target=/workspace/notebooks,type=bind -p 8899:8888 -it drugilsberg/depiction 17 | ``` 18 | 19 | Detailed setup instructions can be found [here](https://github.com/IBM/dl-interpretability-compbio#docker-setup). 
20 | 21 | #### conda 22 | 23 | Instructions can be found [here](https://github.com/IBM/dl-interpretability-compbio#development-setup). 24 | 25 | We personally suggest the first setup (docker). 26 | 27 | For both the _docker_ and the _conda_ setup, **we strongly suggest to setup your machine prior to the tutorial (ideally the day before)** following the linked instructions. Should you incur into any issue, please do not hesitate to contact [Matteo](mailto:tte@zurich.ibm.com) or [An-phi](mailto:uye@zurich.ibm.com). 28 | 29 | ## Organisers and tutors 30 | 31 | - [Dr. Matteo Manica](https://researcher.watson.ibm.com/researcher/view.php?person=zurich-TTE), IBM Research Zürich 32 | - [An-Phi Nguyen](https://researcher.watson.ibm.com/researcher/view.php?person=zurich-UYE), IBM Research Zürich 33 | 34 | ## Schedule 35 | 36 | | Time | Title | Speaker | 37 | |-------------|--------------------------------------|--------------------------| 38 | | 09:00-10:00 | Interpretability in Machine Learning | An-phi Nguyen | 39 | | 10:00-10:15 | Introduction to depiction | Matteo Manica | 40 | | 10:15-10:30 | (Exercise 1) Hands-on intro to depiction. CellTyper: linear models and interpretability | Matteo Manica | 41 | | 10:30-11:00 | Coffee break | N/A | 42 | | 11:00-11:30 | (Exercise 2) Breast cancer image classification | An-phi Nguyen | 43 | | 11:30-12:00 | (Exercise 3) Understanding transcription factors binding | An-phi Nguyen | 44 | | 12:00-12:30 | (Exercise 4) PaccMann: what to do for multimodal data | Matteo Manica | 45 | 46 | ## Slides 47 | 48 | The slides can be found on Box: https://ibm.box.com/v/amld-2020-depiction. 
49 | -------------------------------------------------------------------------------- /workshops/20200125_AMLD2020/notebooks/deepbind.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Having fun with DeepBind" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import warnings; warnings.filterwarnings('ignore', category=FutureWarning)\n", 17 | "import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR) # suppress deprecation messages\n", 18 | "from depiction.models.examples.deepbind.deepbind import DeepBind, create_DNA_language\n", 19 | "from depiction.interpreters.u_wash.u_washer import UWasher\n", 20 | "from ipywidgets import interact" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Setup task" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "class_names = ['NOT BINDING', 'BINDING']\n", 37 | "classifier = DeepBind(model='DeepBind/Homo_sapiens/TF/D00328.003_SELEX_CTCF', min_length=35)\n", 38 | "# this class has task (classification) and data_type (text) and some processing defined for your convenience" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "http://kipoi.org/models/DeepBind/Homo_sapiens/TF/D00328.003_SELEX_CTCF/" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "# Interpreter parametrization" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# a LIME text interpreter\n", 62 | "lime_explanation_configs = {\n", 63 | " 'labels': (1,),\n", 64 | "}\n", 65 | "lime_params = {\n", 66 | " 'class_names': 
class_names,\n", 67 | " 'split_expression': list,\n", 68 | " 'bow': False,\n", 69 | " 'char_level': True\n", 70 | "}\n", 71 | "\n", 72 | "# an Anchor text intepreter\n", 73 | "anchors_explanation_configs = {\n", 74 | " 'use_proba': False,\n", 75 | " 'batch_size': 100\n", 76 | "}\n", 77 | "anchors_params = {\n", 78 | " 'class_names': class_names,\n", 79 | " 'nlp': create_DNA_language(),\n", 80 | " 'unk_token': 'N',\n", 81 | " 'sep_token': '',\n", 82 | " 'use_unk_distribution': True\n", 83 | "}" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### Wrapper for the interactive widget" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "class InteractiveWrapper:\n", 100 | " def __init__(self, classifier):\n", 101 | " self.classifier = classifier\n", 102 | " self.lime_explainer = UWasher(\"lime\", self.classifier, **lime_params)\n", 103 | " self.anchor_explainer = UWasher(\"anchors\", self.classifier, **anchors_params)\n", 104 | "\n", 105 | " def callback(self, sequence):\n", 106 | " # LIME\n", 107 | " self.classifier.use_labels = False\n", 108 | " self.lime_explainer.interpret(sequence, explanation_configs=lime_explanation_configs)\n", 109 | " # Anchors \n", 110 | " self.classifier.use_labels = True\n", 111 | " self.anchor_explainer.interpret(sequence, explanation_configs=anchors_explanation_configs)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "# Let's interpret" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "wrapper = InteractiveWrapper(classifier)\n", 128 | "\n", 129 | "interact(\n", 130 | " wrapper.callback,\n", 131 | " sequence=[\n", 132 | " 'AGGCTAGCTAGGGGCGCCC', 'AGGCTAGCTAGGGGCGCTT', 'AGGGTAGCTAGGGGCGCTT',\n", 133 | " 'AGGGTAGCTGGGGGCGCTT', 'AGGCTAGGTGGGGGCGCTT', 
'AGGCTCGGTGGGGGCGCTT',\n", 134 | " 'AGGCTCGGTAGGGGGCGATT'\n", 135 | " ]\n", 136 | ")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "CTCF binding motif\n", 144 | "![CTCF binding motif](https://media.springernature.com/full/springer-static/image/art%3A10.1186%2Fgb-2009-10-11-r131/MediaObjects/13059_2009_Article_2281_Fig2_HTML.jpg?as=webp)\n", 145 | "from Essien, Kobby, et al. \"CTCF binding site classes exhibit distinct evolutionary, genomic, epigenomic and transcriptomic features.\" Genome biology 10.11 (2009): R131." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "classifier_foxa1 = DeepBind('DeepBind/Homo_sapiens/TF/D00761.001_ChIP-seq_FOXA1', min_length=40)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "wrapper_foxa1 = InteractiveWrapper(classifier_foxa1)\n", 164 | "\n", 165 | "interact(wrapper_foxa1.callback, sequence='TGTGTGTGTG')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "FOXA1 binding motif\n", 173 | "![FOXA1 binding motif](https://ismara.unibas.ch/supp/dataset1_IBM_v2/ismara_report/logos/FOXA1.png)\n", 174 | "from https://ismara.unibas.ch/supp/dataset1_IBM_v2/ismara_report/pages/FOXA1.html" 175 | ] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.7.3" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 2 199 | } 200 | 
-------------------------------------------------------------------------------- /workshops/20200125_AMLD2020/notebooks/kaggle_create_a_new_api_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/workshops/20200125_AMLD2020/notebooks/kaggle_create_a_new_api_token.png -------------------------------------------------------------------------------- /workshops/20200125_AMLD2020/notebooks/kaggle_go_to_your_account.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/depiction/3b13394f2dd9614736b4183b407a938a2c5924ac/workshops/20200125_AMLD2020/notebooks/kaggle_go_to_your_account.png --------------------------------------------------------------------------------