├── .circleci └── config.yml ├── .coveragerc ├── .gitignore ├── .readthedocs.yml ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── appveyor.yml ├── doc ├── Makefile ├── _static │ ├── css │ │ └── project-template.css │ ├── img │ │ ├── pye_logo_text.svg │ │ └── pye_logo_text_dark.svg │ └── js │ │ └── copybutton.js ├── _templates │ ├── class.rst │ ├── function.rst │ └── numpydoc_docstring.py ├── conf.py ├── index.rst ├── make.bat ├── modules │ ├── logic │ │ ├── metrics.rst │ │ ├── nn │ │ │ ├── entropy.rst │ │ │ ├── psi.rst │ │ │ └── utils.rst │ │ └── utils.rst │ └── nn │ │ ├── concepts.rst │ │ ├── functional │ │ ├── loss.rst │ │ └── prune.rst │ │ ├── logic.rst │ │ └── semantics.rst └── user_guide │ ├── authors.rst │ ├── contributing.rst │ ├── installation.rst │ ├── licence.rst │ ├── running_tests.rst │ ├── tutorial_cem.rst │ ├── tutorial_dcr.rst │ └── tutorial_lens.rst ├── environment.yml ├── experiments ├── bio │ └── tabula_muris.ipynb ├── data │ ├── load_datasets.py │ └── tabula_muris_comet │ │ ├── __init__.py │ │ ├── datamgr.py │ │ ├── dataset.py │ │ ├── feature_loader.py │ │ ├── map_GO.py │ │ └── preprocess.py ├── elens │ ├── L1_vs_entropy.py │ ├── blackbox │ │ ├── cub.py │ │ ├── mimic.py │ │ ├── mnist.py │ │ ├── vdem.ipynb │ │ └── vdem.py │ ├── celldiff.ipynb │ ├── cub.py │ ├── hyperparams │ │ ├── cub.py │ │ ├── mnist.py │ │ └── vdem.py │ ├── mimic.py │ ├── mnist.py │ ├── results_old │ │ ├── CUB │ │ │ └── explainer │ │ │ │ └── results_aware_cub.csv │ │ ├── MNIST │ │ │ └── explainer │ │ │ │ └── results_aware_mnist.csv │ │ └── vdem │ │ │ └── explainer │ │ │ └── results_aware_vdem.csv │ ├── summary_plots.ipynb │ ├── summary_to_latex.ipynb │ └── vdem.ipynb └── vlens │ └── prova.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_cem.py ├── test_dcr.py ├── test_logic_layer.py └── test_utils.py └── torch_explain ├── __init__.py ├── _version.py ├── datasets ├── __init__.py └── benchmarks.py ├── logic ├── __init__.py ├── 
metrics.py ├── nn │ ├── __init__.py │ ├── entropy.py │ ├── psi.py │ └── utils.py └── utils.py └── nn ├── __init__.py ├── concepts.py ├── functional ├── __init__.py ├── loss.py └── prune.py ├── logic.py └── semantics.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build: 5 | docker: 6 | - image: circleci/python:3.6.1 7 | working_directory: ~/repo 8 | steps: 9 | - checkout 10 | - run: 11 | name: install dependencies 12 | command: | 13 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 14 | chmod +x miniconda.sh && ./miniconda.sh -b -p ~/miniconda 15 | export PATH="~/miniconda/bin:$PATH" 16 | conda update --yes --quiet conda 17 | conda create -n testenv --yes --quiet python=3 18 | source activate testenv 19 | conda install --yes pip numpy scipy scikit-learn matplotlib sphinx sphinx_rtd_theme numpydoc pillow 20 | pip install sphinx-gallery 21 | pip install . 22 | cd doc 23 | make html 24 | - store_artifacts: 25 | path: doc/_build/html/ 26 | destination: doc 27 | - store_artifacts: 28 | path: ~/log.txt 29 | - run: ls -ltrh doc/_build/html 30 | filters: 31 | branches: 32 | ignore: gh-pages 33 | 34 | workflows: 35 | version: 2 36 | workflow: 37 | jobs: 38 | - build 39 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # Configuration for coverage.py 2 | 3 | [run] 4 | branch = True 5 | source = torch_explain 6 | include = */torch_explain/* 7 | omit = 8 | */setup.py 9 | 10 | [report] 11 | exclude_lines = 12 | pragma: no cover 13 | def __repr__ 14 | if self.debug: 15 | if settings.DEBUG 16 | raise AssertionError 17 | raise NotImplementedError 18 | if 0: 19 | if __name__ == .__main__.: 20 | if self.verbose: 21 | show_missing = True -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # scikit-learn specific 10 | doc/_build/ 11 | doc/auto_examples/ 12 | doc/modules/generated/ 13 | doc/datasets/generated/ 14 | 15 | # Distribution / packaging 16 | 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | 62 | # Sphinx documentation 63 | doc/_build/ 64 | doc/generated/ 65 | 66 | # PyBuilder 67 | target/ 68 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | formats: 2 | - none 3 | requirements_file: requirements.txt 4 | python: 5 | pip_install: true 6 | extra_requirements: 7 | - tests 8 | - docs 9 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | sudo: false 3 | 4 | language: python 5 | 6 | cache: 7 | directories: 8 | - $HOME/.cache/pip 9 | 10 | matrix: 11 | include: 12 | - env: PYTHON_VERSION="3.7" NUMPY_VERSION="*" SCIPY_VERSION="*" PANDAS_VERSION="*" 
PYTORCH_VERSION="*" SYMPY_VERSION="*" 13 | 14 | install: 15 | # install miniconda 16 | - deactivate 17 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 18 | - MINICONDA_PATH=/home/travis/miniconda 19 | - chmod +x miniconda.sh && ./miniconda.sh -b -p $MINICONDA_PATH 20 | - export PATH=$MINICONDA_PATH/bin:$PATH 21 | - conda update --yes conda 22 | # create the testing environment 23 | - conda create -n testenv --yes python=$PYTHON_VERSION pip 24 | - source activate testenv 25 | - conda install --yes pytorch torchvision torchaudio cpuonly -c pytorch 26 | - pip install sympy pandas scikit-learn tqdm cython nose seaborn matplotlib pytest pytest-cov codecov 27 | - pip install . 28 | 29 | script: 30 | - coverage run -m unittest discover 31 | 32 | after_success: 33 | - codecov 34 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | build: false 2 | 3 | environment: 4 | matrix: 5 | - PYTHON: "C:\\Miniconda3-x64" 6 | PYTHON_VERSION: "3.5.x" 7 | PYTHON_ARCH: "32" 8 | NUMPY_VERSION: "1.13.1" 9 | SCIPY_VERSION: "0.19.1" 10 | SKLEARN_VERSION: "0.19.1" 11 | 12 | - PYTHON: "C:\\Miniconda3-x64" 13 | PYTHON_VERSION: "3.6.x" 14 | PYTHON_ARCH: "64" 15 | NUMPY_VERSION: "*" 16 | SCIPY_VERSION: "*" 17 | SKLEARN_VERSION: "*" 18 | 19 | install: 20 | # Prepend miniconda installed Python to the PATH of this build 21 | # Add Library/bin directory to fix issue 22 | # https://github.com/conda/conda/issues/1753 23 | - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\Library\\bin;%PATH%" 24 | # install the dependencies 25 | - "conda install --yes -c conda-forge pip numpy==%NUMPY_VERSION% scipy==%SCIPY_VERSION% 
scikit-learn==%SKLEARN_VERSION%" 26 | - pip install codecov nose pytest pytest-cov 27 | - pip install . 28 | 29 | test_script: 30 | - mkdir for_test 31 | - cd for_test 32 | - pytest -v --cov=template --pyargs template 33 | 34 | after_test: 35 | - cp .coverage %APPVEYOR_BUILD_FOLDER% 36 | - cd %APPVEYOR_BUILD_FOLDER% 37 | - codecov 38 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | -rm -rf $(BUILDDIR)/* 51 | -rm -rf auto_examples/ 52 | -rm -rf generated/* 53 | -rm -rf modules/generated/* 54 | 55 | html: 56 | # These two lines make the build a bit more lengthy, and the 57 | # the embedding of images more robust 58 | rm -rf $(BUILDDIR)/html/_images 59 | #rm -rf _build/doctrees/ 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 
63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp." 89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/project-template" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 
117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 152 | make -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 
164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 185 | -------------------------------------------------------------------------------- /doc/_static/css/project-template.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .highlight a { 4 | text-decoration: underline; 5 | } 6 | 7 | .deprecated p { 8 | padding: 10px 7px 10px 10px; 9 | color: #b94a48; 10 | background-color: #F3E5E5; 11 | border: 1px solid #eed3d7; 12 | } 13 | 14 | .deprecated p span.versionmodified { 15 | font-weight: bold; 16 | } 17 | -------------------------------------------------------------------------------- /doc/_static/img/pye_logo_text.svg: -------------------------------------------------------------------------------- 1 |  7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/_static/img/pye_logo_text_dark.svg: -------------------------------------------------------------------------------- 1 |  7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/_static/js/copybutton.js: -------------------------------------------------------------------------------- 1 | 
$(document).ready(function() { 2 | /* Add a [>>>] button on the top-right corner of code samples to hide 3 | * the >>> and ... prompts and the output and thus make the code 4 | * copyable. */ 5 | var div = $('.highlight-python .highlight,' + 6 | '.highlight-python3 .highlight,' + 7 | '.highlight-pycon .highlight,' + 8 | '.highlight-default .highlight') 9 | var pre = div.find('pre'); 10 | 11 | // get the styles from the current theme 12 | pre.parent().parent().css('position', 'relative'); 13 | var hide_text = 'Hide the prompts and output'; 14 | var show_text = 'Show the prompts and output'; 15 | var border_width = pre.css('border-top-width'); 16 | var border_style = pre.css('border-top-style'); 17 | var border_color = pre.css('border-top-color'); 18 | var button_styles = { 19 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 20 | 'border-color': border_color, 'border-style': border_style, 21 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 22 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 23 | 'border-radius': '0 3px 0 0' 24 | } 25 | 26 | // create and add the button to all the code blocks that contain >>> 27 | div.each(function(index) { 28 | var jthis = $(this); 29 | if (jthis.find('.gp').length > 0) { 30 | var button = $('>>>'); 31 | button.css(button_styles) 32 | button.attr('title', hide_text); 33 | button.data('hidden', 'false'); 34 | jthis.prepend(button); 35 | } 36 | // tracebacks (.gt) contain bare text elements that need to be 37 | // wrapped in a span to work with .nextUntil() (see later) 38 | jthis.find('pre:has(.gt)').contents().filter(function() { 39 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 40 | }).wrap(''); 41 | }); 42 | 43 | // define the behavior of the button when it's clicked 44 | $('.copybutton').click(function(e){ 45 | e.preventDefault(); 46 | var button = $(this); 47 | if (button.data('hidden') === 'false') { 48 | // hide the code output 49 
| button.parent().find('.go, .gp, .gt').hide(); 50 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 51 | button.css('text-decoration', 'line-through'); 52 | button.attr('title', show_text); 53 | button.data('hidden', 'true'); 54 | } else { 55 | // show the code output 56 | button.parent().find('.go, .gp, .gt').show(); 57 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 58 | button.css('text-decoration', 'none'); 59 | button.attr('title', hide_text); 60 | button.data('hidden', 'false'); 61 | } 62 | }); 63 | }); 64 | -------------------------------------------------------------------------------- /doc/_templates/class.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | .. automethod:: __init__ 10 | {% endblock %} 11 | 12 | .. include:: {{module}}.{{objname}}.examples 13 | 14 | .. raw:: html 15 | 16 |
17 | -------------------------------------------------------------------------------- /doc/_templates/function.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | 8 | .. include:: {{module}}.{{objname}}.examples 9 | 10 | .. raw:: html 11 | 12 |
13 | -------------------------------------------------------------------------------- /doc/_templates/numpydoc_docstring.py: -------------------------------------------------------------------------------- 1 | {{index}} 2 | {{summary}} 3 | {{extended_summary}} 4 | {{parameters}} 5 | {{returns}} 6 | {{yields}} 7 | {{other_parameters}} 8 | {{attributes}} 9 | {{raises}} 10 | {{warns}} 11 | {{warnings}} 12 | {{see_also}} 13 | {{notes}} 14 | {{references}} 15 | {{examples}} 16 | {{methods}} 17 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | import datetime 17 | import os 18 | import sys 19 | sys.path.insert(0, os.path.abspath('../')) 20 | import torch_explain 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'pytorch_explain' 25 | author = 'Pietro Barbiero' 26 | copyright = '{}, {}'.format(datetime.datetime.now().year, author) 27 | 28 | version = torch_explain.__version__ 29 | release = torch_explain.__version__ 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | master_doc = 'index' 35 | 36 | # Add any Sphinx extension module names here, as strings. 
They can be 37 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 38 | # ones. 39 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.coverage', 'sphinx_rtd_theme'] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 48 | 49 | 50 | # -- Options for HTML output ------------------------------------------------- 51 | 52 | # The theme to use for HTML and HTML Help pages. See the documentation for 53 | # a list of builtin themes. 54 | # 55 | # html_theme = 'alabaster' 56 | html_theme = "sphinx_rtd_theme" 57 | html_logo = './_static/img/pye_logo_text.svg' 58 | 59 | html_theme_options = { 60 | 'canonical_url': 'https://pytorch_explain.readthedocs.io/en/latest/', 61 | 'logo_only': True, 62 | 'display_version': True, 63 | 'prev_next_buttons_location': 'bottom', 64 | 'style_external_links': False, 65 | # Toc options 66 | 'collapse_navigation': False, 67 | 'sticky_navigation': True, 68 | 'navigation_depth': 4, 69 | 'includehidden': True, 70 | 'titles_only': False, 71 | } 72 | 73 | 74 | # Add any paths that contain custom static files (such as style sheets) here, 75 | # relative to this directory. They are copied after the builtin static files, 76 | # so a file named "default.css" will overwrite the builtin "default.css". 
77 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | PYTORCH EXPLAIN DOCUMENTATION 2 | =============================== 3 | 4 | 5 | `PyTorch, Explain!` is an extension library for PyTorch to develop 6 | explainable deep learning models going beyond the current accuracy-interpretability trade-off. 7 | 8 | The library includes a set of tools to develop: 9 | 10 | * Deep Concept Reasoner (Deep CoRe): an interpretable concept-based model going 11 | **beyond the current accuracy-interpretability trade-off**; 12 | * Concept Embedding Models (CEMs): a class of concept-based models going 13 | **beyond the current accuracy-explainability trade-off**; 14 | * Logic Explained Networks (LENs): a class of concept-based models generating 15 | accurate compound logic explanations for their predictions 16 | **without the need for a post-hoc explainer**. 17 | 18 | 19 | Quick start 20 | ----------- 21 | 22 | You can install ``torch_explain`` along with all its dependencies from 23 | `PyPI `__: 24 | 25 | .. code:: bash 26 | 27 | pip install torch-explain 28 | 29 | 30 | Source 31 | ------ 32 | 33 | The source code and minimal working examples can be found on 34 | `GitHub `__. 35 | 36 | 37 | .. toctree:: 38 | :caption: User Guide 39 | :maxdepth: 2 40 | 41 | user_guide/installation 42 | user_guide/tutorial_lens 43 | user_guide/tutorial_cem 44 | user_guide/tutorial_dcr 45 | user_guide/contributing 46 | user_guide/running_tests 47 | 48 | .. toctree:: 49 | :caption: API Reference 50 | :maxdepth: 2 51 | 52 | modules/logic/nn/entropy 53 | modules/logic/nn/psi 54 | modules/logic/nn/utils 55 | modules/logic/metrics 56 | modules/logic/utils 57 | modules/nn/logic 58 | modules/nn/functional/loss 59 | modules/nn/functional/prune 60 | modules/nn/concepts 61 | modules/nn/semantics 62 | 63 | 64 | .. 
toctree:: 65 | :caption: Copyright 66 | :maxdepth: 1 67 | 68 | user_guide/authors 69 | user_guide/licence 70 | 71 | 72 | Indices and tables 73 | ~~~~~~~~~~~~~~~~~~ 74 | 75 | * :ref:`genindex` 76 | * :ref:`modindex` 77 | * :ref:`search` 78 | 79 | 80 | Benchmark datasets 81 | ------------------------- 82 | 83 | We provide a suite of 3 benchmark datasets to evaluate the performance of our models 84 | in the folder `torch_explain/datasets`. These 3 datasets were proposed as benchmarks 85 | for concept-based models in the paper "Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off". 86 | 87 | Real-world datasets can be downloaded from the links provided in the supplementary material of the paper. 88 | 89 | 90 | Theory 91 | -------- 92 | Theoretical foundations can be found in the following papers. 93 | 94 | Deep Concept Reasoning:: 95 | 96 | @article{barbiero2023interpretable, 97 | title={Interpretable Neural-Symbolic Concept Reasoning}, 98 | author={Barbiero, Pietro and Ciravegna, Gabriele and Giannini, Francesco and Zarlenga, Mateo Espinosa and Magister, Lucie Charlotte and Tonda, Alberto and Lio, Pietro and Precioso, Frederic and Jamnik, Mateja and Marra, Giuseppe}, 99 | journal={arXiv preprint arXiv:2304.14068}, 100 | year={2023} 101 | } 102 | 103 | Concept Embedding Models:: 104 | 105 | @article{espinosa2022concept, 106 | title={Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off}, 107 | author={Espinosa Zarlenga, Mateo and Barbiero, Pietro and Ciravegna, Gabriele and Marra, Giuseppe and Giannini, Francesco and Diligenti, Michelangelo and Shams, Zohreh and Precioso, Frederic and Melacci, Stefano and Weller, Adrian and others}, 108 | journal={Advances in Neural Information Processing Systems}, 109 | volume={35}, 110 | pages={21400--21413}, 111 | year={2022} 112 | } 113 | 114 | 115 | Logic Explained Networks:: 116 | 117 | @article{ciravegna2023logic, 118 | title={Logic explained networks}, 119 | author={Ciravegna, Gabriele and 
Barbiero, Pietro and Giannini, Francesco and Gori, Marco and Li{\'o}, Pietro and Maggini, Marco and Melacci, Stefano}, 120 | journal={Artificial Intelligence}, 121 | volume={314}, 122 | pages={103822}, 123 | year={2023}, 124 | publisher={Elsevier} 125 | } 126 | 127 | Entropy-based LENs:: 128 | 129 | @inproceedings{barbiero2022entropy, 130 | title={Entropy-based logic explanations of neural networks}, 131 | author={Barbiero, Pietro and Ciravegna, Gabriele and Giannini, Francesco and Li{\'o}, Pietro and Gori, Marco and Melacci, Stefano}, 132 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 133 | volume={36}, 134 | number={6}, 135 | pages={6046--6054}, 136 | year={2022} 137 | } 138 | 139 | Psi network ("learning of constraints"):: 140 | 141 | @inproceedings{ciravegna2020constraint, 142 | title={A constraint-based approach to learning and explanation}, 143 | author={Ciravegna, Gabriele and Giannini, Francesco and Melacci, Stefano and Maggini, Marco and Gori, Marco}, 144 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 145 | volume={34}, 146 | number={04}, 147 | pages={3658--3665}, 148 | year={2020} 149 | } 150 | 151 | 152 | Learning with constraints:: 153 | 154 | @inproceedings{marra2019lyrics, 155 | title={LYRICS: A General Interface Layer to Integrate Logic Inference and Deep Learning}, 156 | author={Marra, Giuseppe and Giannini, Francesco and Diligenti, Michelangelo and Gori, Marco}, 157 | booktitle={Joint European Conference on Machine Learning and Knowledge Discovery in Databases}, 158 | pages={283--298}, 159 | year={2019}, 160 | organization={Springer} 161 | } 162 | 163 | Constraints theory in machine learning:: 164 | 165 | @book{gori2017machine, 166 | title={Machine Learning: A constraint-based approach}, 167 | author={Gori, Marco}, 168 | year={2017}, 169 | publisher={Morgan Kaufmann} 170 | } 171 | 172 | 173 | Authors 174 | ------- 175 | 176 | * `Pietro Barbiero `__, University of Cambridge, UK. 
177 | * Mateo Espinosa Zarlenga, University of Cambridge, UK. 178 | * Giuseppe Marra, Katholieke Universiteit Leuven, BE. 179 | * Steve Azzolin, University of Trento, IT. 180 | * Francesco Giannini, University of Florence, IT. 181 | * Gabriele Ciravegna, University of Florence, IT. 182 | * Dobrik Georgiev, University of Cambridge, UK. 183 | 184 | 185 | Licence 186 | ------- 187 | 188 | Copyright 2020 Pietro Barbiero, Mateo Espinosa Zarlenga, Giuseppe Marra, 189 | Steve Azzolin, Francesco Giannini, Gabriele Ciravegna, and Dobrik Georgiev. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); you may 192 | not use this file except in compliance with the License. You may obtain 193 | a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0. 194 | 195 | Unless required by applicable law or agreed to in writing, software 196 | distributed under the License is distributed on an "AS IS" BASIS, 197 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 198 | 199 | See the License for the specific language governing permissions and 200 | limitations under the License. 201 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. 
singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 
76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 
137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 
205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /doc/modules/logic/metrics.rst: -------------------------------------------------------------------------------- 1 | Logic explanation metrics 2 | =========================================== 3 | 4 | :mod:`torch_explain.logic.metrics` 5 | 6 | .. automodule:: torch_explain.logic.metrics 7 | :members: -------------------------------------------------------------------------------- /doc/modules/logic/nn/entropy.rst: -------------------------------------------------------------------------------- 1 | Logic explanations for entropy-based LENs 2 | ============================================== 3 | 4 | :mod:`torch_explain.logic.nn.entropy` 5 | 6 | .. 
automodule:: torch_explain.logic.nn.entropy 7 | :members: -------------------------------------------------------------------------------- /doc/modules/logic/nn/psi.rst: -------------------------------------------------------------------------------- 1 | Logic explanations for :math:`\psi` LENs 2 | ============================================== 3 | 4 | :mod:`torch_explain.logic.nn.psi` 5 | 6 | .. automodule:: torch_explain.logic.nn.psi 7 | :members: -------------------------------------------------------------------------------- /doc/modules/logic/nn/utils.rst: -------------------------------------------------------------------------------- 1 | Utils for extracting logic explanations 2 | ============================================== 3 | 4 | :mod:`torch_explain.logic.nn.utils` 5 | 6 | .. automodule:: torch_explain.logic.nn.utils 7 | :members: -------------------------------------------------------------------------------- /doc/modules/logic/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | =========================================== 3 | 4 | :mod:`torch_explain.logic.utils` 5 | 6 | .. automodule:: torch_explain.logic.utils 7 | :members: -------------------------------------------------------------------------------- /doc/modules/nn/concepts.rst: -------------------------------------------------------------------------------- 1 | APIs for concepts 2 | ============================================== 3 | 4 | :mod:`torch_explain.nn.concepts` 5 | 6 | .. automodule:: torch_explain.nn.concepts 7 | :members: -------------------------------------------------------------------------------- /doc/modules/nn/functional/loss.rst: -------------------------------------------------------------------------------- 1 | Loss functions to regularize the neural model 2 | ============================================== 3 | 4 | :mod:`torch_explain.nn.functional.loss` 5 | 6 | .. 
automodule:: torch_explain.nn.functional.loss 7 | :members: -------------------------------------------------------------------------------- /doc/modules/nn/functional/prune.rst: -------------------------------------------------------------------------------- 1 | Prune functions to simplify the neural model 2 | ============================================== 3 | 4 | :mod:`torch_explain.nn.functional.prune` 5 | 6 | .. automodule:: torch_explain.nn.functional.prune 7 | :members: -------------------------------------------------------------------------------- /doc/modules/nn/logic.rst: -------------------------------------------------------------------------------- 1 | Logic layers 2 | ============================================== 3 | 4 | :mod:`torch_explain.nn.logic` 5 | 6 | .. automodule:: torch_explain.nn.logic 7 | :members: -------------------------------------------------------------------------------- /doc/modules/nn/semantics.rst: -------------------------------------------------------------------------------- 1 | Logic T-Norms 2 | ============================================== 3 | 4 | :mod:`torch_explain.nn.semantics` 5 | 6 | .. automodule:: torch_explain.nn.semantics 7 | :members: -------------------------------------------------------------------------------- /doc/user_guide/authors.rst: -------------------------------------------------------------------------------- 1 | Authors 2 | ======= 3 | 4 | * `Pietro Barbiero `__, University of Cambridge, UK. 5 | * Mateo Espinosa Zarlenga, University of Cambridge, UK. 6 | * Giuseppe Marra, Katholieke Universiteit Leuven, BE. 7 | * Steve Azzolin, University of Trento, IT. 8 | * Francesco Giannini, University of Florence, IT. 9 | * Gabriele Ciravegna, University of Florence, IT. 10 | * Dobrik Georgiev, University of Cambridge, UK. 
11 | -------------------------------------------------------------------------------- /doc/user_guide/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing to Pytorch Explain 2 | ================================ 3 | 4 | First off, thanks for taking the time to contribute! :+1: 5 | 6 | How Can I Contribute? 7 | --------------------- 8 | 9 | * Obviously source code: patches, as well as completely new files 10 | * Bug reports 11 | * Code review 12 | 13 | Coding Style 14 | ------------ 15 | 16 | **Nota Bene**: All these rules are meant to be broken, **BUT** you need a very good reason **AND** you must explain it in a comment. 17 | 18 | * Names (TL;DR): `module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`, `global_var_name`, `instance_var_name`, `function_parameter_name`, `local_var_name`. 19 | 20 | * Start names internal to a module or protected or private within a class with a single underscore (`_`); don't dunder (`__`). 21 | 22 | * Use nouns for variable and property names (`y = foo.baz`). Use full sentences for function and method names (`x = foo.calculate_next_bar(previous_bar)`); functions returning a boolean value (a.k.a., predicates) should start with the `is_` prefix (`if is_gargled(quz)`). 23 | 24 | * Do not implement getters and setters, use properties instead. If a function does not need parameters, consider using a property (`foo.first_bar` instead of `foo.calculate_first_bar()`). However, do not hide complexity: if a task is computationally intensive, use an explicit method (e.g., `big_number.get_prime_factors()`). 25 | 26 | * Do not override `__repr__`. 27 | 28 | * Use `assert` to check the internal consistency and verify the correct usage of methods, not to check for the occurrence of unexpected events.
That is: The optimized bytecode should not waste time verifying the correct invocation of methods or running sanity checks. 29 | 30 | * Explain the purpose of all classes and functions in docstrings; be verbose when needed, otherwise use single-line descriptions (note: each verbose description also includes a concise one as its first line). Be terse describing methods, but verbose in the class docstring, possibly including usage examples. Comment public attributes and properties in the `Attributes` section of the class docstring (even though PyCharm is not supporting it, yet); don't explain basic customizations (e.g., `__str__`). Comment `__init__` only when its parameters are not obvious. 31 | Use the formats suggested in the `Google's style guide `__). 32 | 33 | * Annotate all functions (refer to `PEP-483 `__) and `PEP-484 `__) for details). 34 | 35 | * Use English for names, in docstrings and in comments (favor formal language over slang, wit over humor, and American English over British). 36 | 37 | * Format source code using `Yapf `__)'s style `"{based_on_style: google, column_limit=120, blank_line_before_module_docstring=true}"` 38 | 39 | * Follow `PEP-440 `__) for version identification. 40 | 41 | * Follow the `Google's style guide `__) whenever in doubt. 42 | 43 | -------------------------------------------------------------------------------- /doc/user_guide/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | You can install ``torch_explain`` along with all its dependencies from 5 | `PyPI `__: 6 | 7 | .. code:: bash 8 | 9 | pip install torch-explain 10 | 11 | or from source code: 12 | 13 | .. code:: bash 14 | 15 | git clone https://github.com/pietrobarbiero/pytorch_explain.git 16 | cd ./torch_explain 17 | pip install -r requirements.txt . 18 | 19 | PyTorch Explain is compatible with Python 3.7 and above. 
-------------------------------------------------------------------------------- /doc/user_guide/licence.rst: -------------------------------------------------------------------------------- 1 | ============== 2 | Apache License 3 | ============== 4 | 5 | :Version: 2.0 6 | :Date: January 2004 7 | :URL: http://www.apache.org/licenses/ 8 | 9 | ------------------------------------------------------------ 10 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 11 | ------------------------------------------------------------ 12 | 13 | 1. Definitions. 14 | --------------- 15 | 16 | **"License"** shall mean the terms and conditions for use, reproduction, and 17 | distribution as defined by Sections 1 through 9 of this document. 18 | 19 | **"Licensor"** shall mean the copyright owner or entity authorized by the 20 | copyright owner that is granting the License. 21 | 22 | **"Legal Entity"** shall mean the union of the acting entity and all other 23 | entities that control, are controlled by, or are under common control with that 24 | entity. For the purposes of this definition, "control" means *(i)* the power, 25 | direct or indirect, to cause the direction or management of such entity, 26 | whether by contract or otherwise, or *(ii)* ownership of fifty percent (50%) or 27 | more of the outstanding shares, or *(iii)* beneficial ownership of such entity. 28 | 29 | **"You"** (or **"Your"**) shall mean an individual or Legal Entity exercising 30 | permissions granted by this License. 31 | 32 | **"Source"** form shall mean the preferred form for making modifications, 33 | including but not limited to software source code, documentation source, and 34 | configuration files. 35 | 36 | **"Object"** form shall mean any form resulting from mechanical transformation 37 | or translation of a Source form, including but not limited to compiled object 38 | code, generated documentation, and conversions to other media types. 
39 | 40 | **"Work"** shall mean the work of authorship, whether in Source or Object form, 41 | made available under the License, as indicated by a copyright notice that is 42 | included in or attached to the work (an example is provided in the Appendix 43 | below). 44 | 45 | **"Derivative Works"** shall mean any work, whether in Source or Object form, 46 | that is based on (or derived from) the Work and for which the editorial 47 | revisions, annotations, elaborations, or other modifications represent, as a 48 | whole, an original work of authorship. For the purposes of this License, 49 | Derivative Works shall not include works that remain separable from, or merely 50 | link (or bind by name) to the interfaces of, the Work and Derivative Works 51 | thereof. 52 | 53 | **"Contribution"** shall mean any work of authorship, including the original 54 | version of the Work and any modifications or additions to that Work or 55 | Derivative Works thereof, that is intentionally submitted to Licensor for 56 | inclusion in the Work by the copyright owner or by an individual or Legal 57 | Entity authorized to submit on behalf of the copyright owner. For the purposes 58 | of this definition, "submitted" means any form of electronic, verbal, or 59 | written communication sent to the Licensor or its representatives, including 60 | but not limited to communication on electronic mailing lists, source code 61 | control systems, and issue tracking systems that are managed by, or on behalf 62 | of, the Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise designated in 64 | writing by the copyright owner as "Not a Contribution." 65 | 66 | **"Contributor"** shall mean Licensor and any individual or Legal Entity on 67 | behalf of whom a Contribution has been received by Licensor and subsequently 68 | incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. 
71 | ------------------------------ 72 | 73 | Subject to the terms and conditions of this License, each Contributor hereby 74 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 75 | irrevocable copyright license to reproduce, prepare Derivative Works of, 76 | publicly display, publicly perform, sublicense, and distribute the Work and 77 | such Derivative Works in Source or Object form. 78 | 79 | 3. Grant of Patent License. 80 | --------------------------- 81 | 82 | Subject to the terms and conditions of this License, each Contributor hereby 83 | grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, 84 | irrevocable (except as stated in this section) patent license to make, have 85 | made, use, offer to sell, sell, import, and otherwise transfer the Work, where 86 | such license applies only to those patent claims licensable by such Contributor 87 | that are necessarily infringed by their Contribution(s) alone or by combination 88 | of their Contribution(s) with the Work to which such Contribution(s) was 89 | submitted. If You institute patent litigation against any entity (including a 90 | cross-claim or counterclaim in a lawsuit) alleging that the Work or a 91 | Contribution incorporated within the Work constitutes direct or contributory 92 | patent infringement, then any patent licenses granted to You under this License 93 | for that Work shall terminate as of the date such litigation is filed. 94 | 95 | 4. Redistribution. 
96 | ------------------ 97 | 98 | You may reproduce and distribute copies of the Work or Derivative Works thereof 99 | in any medium, with or without modifications, and in Source or Object form, 100 | provided that You meet the following conditions: 101 | 102 | - You must give any other recipients of the Work or Derivative Works a copy of 103 | this License; and 104 | 105 | - You must cause any modified files to carry prominent notices stating that You 106 | changed the files; and 107 | 108 | - You must retain, in the Source form of any Derivative Works that You 109 | distribute, all copyright, patent, trademark, and attribution notices from 110 | the Source form of the Work, excluding those notices that do not pertain to 111 | any part of the Derivative Works; and 112 | 113 | - If the Work includes a ``"NOTICE"`` text file as part of its distribution, 114 | then any Derivative Works that You distribute must include a readable copy of 115 | the attribution notices contained within such ``NOTICE`` file, excluding 116 | those notices that do not pertain to any part of the Derivative Works, in at 117 | least one of the following places: within a ``NOTICE`` text file distributed 118 | as part of the Derivative Works; within the Source form or documentation, if 119 | provided along with the Derivative Works; or, within a display generated by 120 | the Derivative Works, if and wherever such third-party notices normally 121 | appear. The contents of the ``NOTICE`` file are for informational purposes 122 | only and do not modify the License. You may add Your own attribution notices 123 | within Derivative Works that You distribute, alongside or as an addendum to 124 | the ``NOTICE`` text from the Work, provided that such additional attribution 125 | notices cannot be construed as modifying the License. 
You may add Your own 126 | copyright statement to Your modifications and may provide additional or 127 | different license terms and conditions for use, reproduction, or distribution 128 | of Your modifications, or for any such Derivative Works as a whole, provided 129 | Your use, reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. 133 | ------------------------------- 134 | 135 | Unless You explicitly state otherwise, any Contribution intentionally submitted 136 | for inclusion in the Work by You to the Licensor shall be under the terms and 137 | conditions of this License, without any additional terms or conditions. 138 | Notwithstanding the above, nothing herein shall supersede or modify the terms 139 | of any separate license agreement you may have executed with Licensor regarding 140 | such Contributions. 141 | 142 | 6. Trademarks. 143 | -------------- 144 | 145 | This License does not grant permission to use the trade names, trademarks, 146 | service marks, or product names of the Licensor, except as required for 147 | reasonable and customary use in describing the origin of the Work and 148 | reproducing the content of the ``NOTICE`` file. 149 | 150 | 7. Disclaimer of Warranty. 151 | -------------------------- 152 | 153 | Unless required by applicable law or agreed to in writing, Licensor provides 154 | the Work (and each Contributor provides its Contributions) on an **"AS IS" 155 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND**, either express or 156 | implied, including, without limitation, any warranties or conditions of 157 | **TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR 158 | PURPOSE**. You are solely responsible for determining the appropriateness of 159 | using or redistributing the Work and assume any risks associated with Your 160 | exercise of permissions under this License. 161 | 162 | 8. Limitation of Liability. 
163 | --------------------------- 164 | 165 | In no event and under no legal theory, whether in tort (including negligence), 166 | contract, or otherwise, unless required by applicable law (such as deliberate 167 | and grossly negligent acts) or agreed to in writing, shall any Contributor be 168 | liable to You for damages, including any direct, indirect, special, incidental, 169 | or consequential damages of any character arising as a result of this License 170 | or out of the use or inability to use the Work (including but not limited to 171 | damages for loss of goodwill, work stoppage, computer failure or malfunction, 172 | or any and all other commercial damages or losses), even if such Contributor 173 | has been advised of the possibility of such damages. 174 | 175 | 9. Accepting Warranty or Additional Liability. 176 | ---------------------------------------------- 177 | 178 | While redistributing the Work or Derivative Works thereof, You may choose to 179 | offer, and charge a fee for, acceptance of support, warranty, indemnity, or 180 | other liability obligations and/or rights consistent with this License. 181 | However, in accepting such obligations, You may act only on Your own behalf and 182 | on Your sole responsibility, not on behalf of any other Contributor, and only 183 | if You agree to indemnify, defend, and hold each Contributor harmless for any 184 | liability incurred by, or claims asserted against, such Contributor by reason 185 | of your accepting any such warranty or additional liability. 186 | 187 | **END OF TERMS AND CONDITIONS** 188 | 189 | APPENDIX: How to apply the Apache License to your work 190 | ------------------------------------------------------ 191 | 192 | To apply the Apache License to your work, attach the following boilerplate 193 | notice, with the fields enclosed by brackets "[]" replaced with your own 194 | identifying information. (Don't include the brackets!) 
The text should be 195 | enclosed in the appropriate comment syntax for the file format. We also 196 | recommend that a file or class name and description of purpose be included on 197 | the same "printed page" as the copyright notice for easier identification within 198 | third-party archives. :: 199 | 200 | Copyright [yyyy] [name of copyright owner] 201 | 202 | Licensed under the Apache License, Version 2.0 (the "License"); 203 | you may not use this file except in compliance with the License. 204 | You may obtain a copy of the License at 205 | 206 | http://www.apache.org/licenses/LICENSE-2.0 207 | 208 | Unless required by applicable law or agreed to in writing, software 209 | distributed under the License is distributed on an "AS IS" BASIS, 210 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 211 | See the License for the specific language governing permissions and 212 | limitations under the License. -------------------------------------------------------------------------------- /doc/user_guide/running_tests.rst: -------------------------------------------------------------------------------- 1 | Running tests 2 | ============= 3 | 4 | You can run all unittests from command line after having 5 | downloaded the source code from 6 | `GitHub `__: 7 | 8 | .. code:: bash 9 | 10 | $ git clone https://github.com/pietrobarbiero/pytorch_explain.git 11 | $ cd ./torch_explain 12 | 13 | You can use either python: 14 | 15 | .. code:: bash 16 | 17 | $ python -m unittest discover 18 | 19 | or coverage: 20 | 21 | .. 
code:: bash 22 | 23 | $ coverage run -m unittest discover -------------------------------------------------------------------------------- /doc/user_guide/tutorial_cem.rst: -------------------------------------------------------------------------------- 1 | Concept Embeddings tutorial 2 | ========================================== 3 | 4 | Limits of Concept Bottleneck Models 5 | ------------------------------------------ 6 | 7 | For this simple tutorial, let's approach 8 | the trigonometry benchmark dataset with a concept bottleneck model: 9 | 10 | .. code:: python 11 | 12 | import torch 13 | import torch_explain as te 14 | from torch_explain import datasets 15 | from sklearn.metrics import accuracy_score 16 | from sklearn.model_selection import train_test_split 17 | 18 | x, c, y = datasets.trigonometry(500) 19 | x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42) 20 | 21 | We can instantiate a simple concept encoder 22 | to map the input features to the concept space and then 23 | a task predictor to map concepts to task predictions: 24 | 25 | .. code:: python 26 | 27 | concept_encoder = torch.nn.Sequential( 28 | torch.nn.Linear(x.shape[1], 10), 29 | torch.nn.LeakyReLU(), 30 | torch.nn.Linear(10, 8), 31 | torch.nn.LeakyReLU(), 32 | torch.nn.Linear(8, c.shape[1]), 33 | torch.nn.Sigmoid(), 34 | ) 35 | task_predictor = torch.nn.Sequential( 36 | torch.nn.Linear(c.shape[1], 1), 37 | ) 38 | model = torch.nn.Sequential(concept_encoder, task_predictor) 39 | 40 | We can now train the network by optimizing the cross entropy loss 41 | on both concepts and tasks: 42 | 43 | .. 
code:: python 44 | 45 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) 46 | loss_form_c = torch.nn.BCELoss() 47 | loss_form_y = torch.nn.BCEWithLogitsLoss() 48 | model.train() 49 | for epoch in range(501): 50 | optimizer.zero_grad() 51 | 52 | # generate concept and task predictions 53 | c_pred = concept_encoder(x_train) 54 | y_pred = task_predictor(c_pred) 55 | 56 | # update loss 57 | concept_loss = loss_form_c(c_pred, c_train) 58 | task_loss = loss_form_y(y_pred, y_train) 59 | loss = concept_loss + 0.5*task_loss 60 | 61 | loss.backward() 62 | optimizer.step() 63 | 64 | Once trained we can check the performance of the model on the test set: 65 | 66 | .. code:: python 67 | 68 | c_pred = concept_encoder(x_test) 69 | y_pred = task_predictor(c_pred) 70 | 71 | concept_accuracy = accuracy_score(c_test, c_pred > 0.5) 72 | task_accuracy = accuracy_score(y_test, y_pred > 0) 73 | 74 | As you can see the performance of the model is not great as the task 75 | task accuracy is around ~80%. Can we do better? 76 | 77 | 78 | Concept Embeddings 79 | ------------------------------ 80 | 81 | Using concept embeddings we can solve our problem much more efficiently. 82 | We just need to define a task predictor and a concept encoder using a 83 | concept embedding layer: 84 | 85 | .. code:: python 86 | 87 | import torch 88 | import torch_explain as te 89 | 90 | embedding_size = 8 91 | concept_encoder = torch.nn.Sequential( 92 | torch.nn.Linear(x.shape[1], 10), 93 | torch.nn.LeakyReLU(), 94 | te.nn.ConceptEmbedding(10, c.shape[1], embedding_size), 95 | ) 96 | task_predictor = torch.nn.Sequential( 97 | torch.nn.Linear(c.shape[1]*embedding_size, 1), 98 | ) 99 | model = torch.nn.Sequential(concept_encoder, task_predictor) 100 | 101 | We can now train the network by optimizing the cross entropy loss 102 | on concepts and tasks: 103 | 104 | .. 
code:: python 105 | 106 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) 107 | loss_form_c = torch.nn.BCELoss() 108 | loss_form_y = torch.nn.BCEWithLogitsLoss() 109 | model.train() 110 | for epoch in range(501): 111 | optimizer.zero_grad() 112 | 113 | # generate concept and task predictions 114 | c_emb, c_pred = concept_encoder(x_train) 115 | y_pred = task_predictor(c_emb.reshape(len(c_emb), -1)) 116 | 117 | # compute loss 118 | concept_loss = loss_form_c(c_pred, c_train) 119 | task_loss = loss_form_y(y_pred, y_train) 120 | loss = concept_loss + 0.5*task_loss 121 | 122 | loss.backward() 123 | optimizer.step() 124 | 125 | Once trained we can check the performance of the model on the test set: 126 | 127 | .. code:: python 128 | 129 | c_emb, c_pred = concept_encoder.forward(x_test) 130 | y_pred = task_predictor(c_emb.reshape(len(c_emb), -1)) 131 | 132 | concept_accuracy = accuracy_score(c_test, c_pred > 0.5) 133 | task_accuracy = accuracy_score(y_test, y_pred > 0) 134 | 135 | As you can see the performance of the model is now great as the task 136 | task accuracy is around ~100%. 137 | -------------------------------------------------------------------------------- /doc/user_guide/tutorial_dcr.rst: -------------------------------------------------------------------------------- 1 | Deep Concept Reasoning tutorial 2 | ========================================== 3 | 4 | Limits of Concept Embeddings 5 | -------------------------------- 6 | 7 | For this simple tutorial, let's use 8 | the XOR benchmark dataset: 9 | 10 | .. code:: python 11 | 12 | import torch 13 | import torch_explain as te 14 | from torch_explain import datasets 15 | from sklearn.metrics import accuracy_score 16 | from sklearn.model_selection import train_test_split 17 | 18 | x, c, y = datasets.xor(500) 19 | x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42) 20 | 21 | Using concept embeddings we can solve our problem efficiently. 
As you can see the performance of the model is now great, as the task accuracy is around ~100%.
81 | 82 | 83 | Deep Concept Reasoning 84 | ---------------------------- 85 | 86 | Using deep concept reasoning we can solve the same problem as above, 87 | but with an intrinsically interpretable model! In fact, Deep Concept Reasoners (Deep CoRes) 88 | make task predictions by means of interpretable logic rules using concept embeddings. 89 | 90 | Using the same example as before, we can just change the task predictor 91 | using a Deep CoRe layer: 92 | 93 | .. code:: python 94 | 95 | from torch_explain.nn.concepts import ConceptReasoningLayer 96 | import torch.nn.functional as F 97 | 98 | y_train = F.one_hot(y_train.long().ravel()).float() 99 | y_test = F.one_hot(y_test.long().ravel()).float() 100 | 101 | task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1]) 102 | model = torch.nn.Sequential(concept_encoder, task_predictor) 103 | 104 | 105 | We can now train the network by optimizing the cross entropy loss 106 | on concepts and tasks: 107 | 108 | .. code:: python 109 | 110 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) 111 | loss_form = torch.nn.BCELoss() 112 | model.train() 113 | for epoch in range(501): 114 | optimizer.zero_grad() 115 | 116 | # generate concept and task predictions 117 | c_emb, c_pred = concept_encoder(x_train) 118 | y_pred = task_predictor(c_emb, c_pred) 119 | 120 | # compute loss 121 | concept_loss = loss_form(c_pred, c_train) 122 | task_loss = loss_form(y_pred, y_train) 123 | loss = concept_loss + 0.5*task_loss 124 | 125 | loss.backward() 126 | optimizer.step() 127 | 128 | Once trained the Deep CoRe layer can explain its predictions by 129 | providing both local and global logic rules: 130 | 131 | 132 | .. 
code:: python 133 | 134 | local_explanations = task_predictor.explain(c_emb, c_pred, 'local') 135 | global_explanations = task_predictor.explain(c_emb, c_pred, 'global') 136 | 137 | 138 | For global explanations, the reasoner will return a dictionary with entries such as 139 | ``{'class': 'y_0', 'explanation': '~c_0 & ~c_1', 'count': 94}``, specifying 140 | for each logic rule, the task it is associated with and the number of samples 141 | associated with the explanation. 142 | 143 | 144 | -------------------------------------------------------------------------------- /doc/user_guide/tutorial_lens.rst: -------------------------------------------------------------------------------- 1 | Logic Explained Network (LENs) tutorial 2 | ========================================== 3 | 4 | Entropy-based LENs 5 | ----------------------- 6 | 7 | For this simple tutorial, let's solve the XOR problem 8 | (augmented with 100 dummy features): 9 | 10 | .. code:: python 11 | 12 | import torch 13 | import torch_explain as te 14 | from torch.nn.functional import one_hot 15 | 16 | x0 = torch.zeros((4, 100)) 17 | x_train = torch.tensor([ 18 | [0, 0], 19 | [0, 1], 20 | [1, 0], 21 | [1, 1], 22 | ], dtype=torch.float) 23 | x_train = torch.cat([x_train, x0], dim=1) 24 | y_train = torch.tensor([0, 1, 1, 0], dtype=torch.long) 25 | y_train_1h = one_hot(y_train).to(torch.float) 26 | 27 | We can instantiate a simple feed-forward neural network 28 | with 3 layers using the ``EntropyLayer`` as the first one: 29 | 30 | .. code:: python 31 | 32 | layers = [ 33 | te.nn.EntropyLinear(x_train.shape[1], 10, n_classes=y_train_1h.shape[1]), 34 | torch.nn.LeakyReLU(), 35 | torch.nn.Linear(10, 4), 36 | torch.nn.LeakyReLU(), 37 | torch.nn.Linear(4, 1), 38 | ] 39 | model = torch.nn.Sequential(*layers) 40 | 41 | We can now train the network by optimizing the cross entropy loss and the 42 | ``entropy_logic_loss`` loss function incorporating the human prior towards 43 | simple explanations: 44 | 45 | .. 
code:: python 46 | 47 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) 48 | loss_form = torch.nn.BCEWithLogitsLoss() 49 | model.train() 50 | for epoch in range(2001): 51 | optimizer.zero_grad() 52 | y_pred = model(x_train).squeeze(-1) 53 | loss = loss_form(y_pred, y_train_1h) + 0.0001 * te.nn.functional.entropy_logic_loss(model) 54 | loss.backward() 55 | optimizer.step() 56 | 57 | Once trained we can extract first-order logic formulas describing 58 | how the network composed the input features to obtain the predictions: 59 | 60 | .. code:: python 61 | 62 | from torch_explain.logic.nn import entropy 63 | from torch.nn.functional import one_hot 64 | 65 | y1h = one_hot(y_train) 66 | global_explanations, local_explanations = entropy.explain_classes(model, x_train, y_train, c_threshold=0.5, y_threshold=0.) 67 | 68 | Explanations will be logic formulas in disjunctive normal form. 69 | In this case, the explanation will be ``y=1`` if and only if ``(f1 AND ~f2) OR (f2 AND ~f1)`` 70 | corresponding to ``f1 XOR f2``. 71 | 72 | The function automatically assesses the quality of logic explanations in terms 73 | of classification accuracy and rule complexity. 74 | In this case the accuracy is 100% and the complexity is 4. 75 | 76 | 77 | :math:`\psi` LENs 78 | ----------------------- 79 | 80 | For this simple tutorial, let's solve the XOR problem 81 | using a :math:`\psi` LEN: 82 | 83 | .. code:: python 84 | 85 | import torch 86 | import torch_explain as te 87 | 88 | x_train = torch.tensor([ 89 | [0, 0], 90 | [0, 1], 91 | [1, 0], 92 | [1, 1], 93 | ], dtype=torch.float) 94 | y_train = torch.tensor([0, 1, 1, 0], dtype=torch.float).unsqueeze(1) 95 | 96 | We can instantiate a simple :math:`\psi` network 97 | with 3 layers using **sigmoid activation functions only**: 98 | 99 | .. 
The quality of the logic explanation can be **quantitatively** assessed in terms of classification accuracy and rule complexity as follows:
code:: python 151 | 152 | from torch_explain.logic.metrics import test_explanation, complexity 153 | 154 | accuracy, preds = test_explanation(explanation, x_train, y1h, target_class=1) 155 | explanation_complexity = complexity(explanation) 156 | 157 | In this case the accuracy is 100% and the complexity is 4. -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: project-template 2 | dependencies: 3 | - numpy 4 | - scipy 5 | - scikit-learn 6 | - pandas 7 | -------------------------------------------------------------------------------- /experiments/bio/tabula_muris.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('../')\n", 11 | "from experiments.data import load_datasets" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "D:\\research\\coding\\neural_networks\\pytorch_explain\\experiments\\bio\\../..\\experiments\\data\\tabula_muris_comet\\preprocess.py:74: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.\n", 24 | " self.adata.obs['label'] = pd.Categorical(values=truth_labels)\n", 25 | "C:\\Users\\pietr\\anaconda3\\envs\\torch110\\lib\\site-packages\\scanpy\\preprocessing\\_simple.py:373: UserWarning: Received a view of an AnnData. 
Making a copy.\n", 26 | " view_to_actual(adata)\n" 27 | ] 28 | }, 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "_________Gene count processed_________\n", 34 | "mgi2go_set 24581\n", 35 | "adata_set 2866\n", 36 | "union 2644\n", 37 | "../data\\tabula_muris_comet\\go-basic.obo: fmt(1.2) rel(2018-10-24) 47,358 Terms\n", 38 | "Num filtered GOs: 153\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "data_loader = load_datasets.load_tabula_muris(base_dir='../data', batch_size=30, mode='train')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 22, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "153" 55 | ] 56 | }, 57 | "execution_count": 22, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "len(data_loader.dataset.go_mask)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 11, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "alldata = next(iter(data_loader))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 13, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "torch.Size([30, 2866])" 84 | ] 85 | }, 86 | "execution_count": 13, 87 | "metadata": {}, 88 | "output_type": "execute_result" 89 | } 90 | ], 91 | "source": [ 92 | "alldata[0].shape" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 14, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "torch.Size([30])" 104 | ] 105 | }, 106 | "execution_count": 14, 107 | "metadata": {}, 108 | "output_type": "execute_result" 109 | } 110 | ], 111 | "source": [ 112 | "alldata[1].shape" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3 (ipykernel)", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | 
class SimpleDataManager(DataManager):
    """Data manager producing one shuffled DataLoader over the flat dataset."""

    def __init__(self, batch_size):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size

    def get_data_loader(self, root='./filelists/tabula_muris', mode='train'):
        # parameters that would change on train/val set
        ds = SimpleDataset(root=root, mode=mode, min_samples=self.batch_size)
        return torch.utils.data.DataLoader(
            ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
def create_go_mask(adata, go2gene):
    """For each GO term (in sorted key order) list the column indices of its genes.

    adata: annotated data object whose ``var_names`` are gene identifiers
    go2gene: mapping of GO term -> iterable of gene identifiers
    return: list of index lists, one per GO term, ordered by sorted GO ID
    """
    # gene identifier -> column index in adata
    index_of = {gene: idx for idx, gene in enumerate(adata.var_names)}
    return [
        [index_of[gene] for gene in go2gene[term]]
        for term in sorted(go2gene.keys())
    ]
class SimpleDataset:
    """Flat (non-episodic) view of the tabula muris data as (sample, target) pairs."""

    def __init__(self, root='./filelists/tabula_muris', mode='train', min_samples=20):
        # load_tabular_muris returns (samples, targets, go_mask) for the split
        self.samples, self.targets, self.go_mask = load_tabular_muris(
            root=root, mode=mode, min_samples=min_samples)

    def __getitem__(self, i):
        # expression vector and its integer class code
        return self.samples[i], self.targets[i]

    def __len__(self):
        # number of cells
        return self.samples.shape[0]

    def get_dim(self):
        # number of genes (features)
        return self.samples.shape[1]
class EpisodicBatchSampler(object):
    """Yields, for each episode, ``n_way`` distinct class indices drawn at random."""

    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes    # size of the class pool
        self.n_way = n_way            # classes sampled per episode
        self.n_episodes = n_episodes  # episodes per epoch

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        # A fresh random permutation per episode; its prefix gives
        # n_way distinct class indices.
        return (torch.randperm(self.n_classes)[:self.n_way]
                for _ in range(self.n_episodes))
def init_loader(filename):
    """Load extracted features/labels from an HDF5 file, grouped by class.

    filename: path to an HDF5 file with 'all_feats', 'all_labels' and
              'count' datasets (as produced by the feature-extraction step)
    return: dict mapping class label -> list of feature vectors
    """
    with h5py.File(filename, 'r') as f:
        # SimpleHDF5Dataset copies the datasets into memory, so the file
        # can safely be closed before we use them below.
        fileset = SimpleHDF5Dataset(f)

    feats = fileset.all_feats_dset
    labels = fileset.all_labels

    # Trim trailing all-zero rows (unused pre-allocated slots) with a single
    # slice instead of the original repeated O(n) np.delete calls; also
    # guard against a fully-zero array.
    n = len(feats)
    while n > 0 and np.sum(feats[n - 1]) == 0:
        n -= 1
    feats = feats[:n]
    labels = labels[:n]

    # Group feature vectors by their class label.
    cl_data_file = {cl: [] for cl in np.unique(np.array(labels)).tolist()}
    for feat, label in zip(feats, labels):
        cl_data_file[label].append(feat)

    return cl_data_file
def check_conditions(GOterm, num_genes, min_genes, max_genes, min_level, max_level):
    """Check whether a GO term satisfies the gene-count and level constraints.

    GOterm: GO DAG entry exposing a ``level`` attribute
    num_genes: number of genes annotated with the term
    min_genes / max_genes: inclusive bounds on gene count (None disables a bound)
    min_level / max_level: inclusive bounds on the term's level (None disables)
    return: True iff every enabled constraint holds
    """
    if min_genes is not None and num_genes < min_genes:
        return False
    if max_genes is not None and num_genes > max_genes:
        return False
    if min_level is not None and GOterm.level < min_level:
        return False
    # BUG FIX: the original re-tested ``min_level`` here (copy-paste error),
    # so ``max_level`` was silently ignored and any term with
    # level > min_level was wrongly rejected.
    if max_level is not None and GOterm.level > max_level:
        return False
    return True
def remove_ERCC_genes(adata):
    """Remove ERCC spike-in genes as done in Tabula Muris preprocessing.

    adata: annotated data object; genes whose name starts with "ERCC"
           are dropped column-wise
    return: ``adata`` restricted to the non-ERCC gene columns
    """
    # column indices of genes NOT starting with ERCC
    idx = [i for i, g in enumerate(adata.var_names) if not g.startswith("ERCC")]
    # NOTE: the original also rebuilt the gene-name list with an O(n^2)
    # ``if i in idx`` comprehension, but never used the result — removed.
    return adata[:, idx]
def get_go2gene(adata, GO_min_genes=500, GO_max_genes=None, GO_min_level=3, GO_max_level=3,
                data_dir='./filelists/tabula_muris/'):
    """Build the GO-term -> gene-IDs mapping for the genes present in ``adata``.

    adata: annotated data object whose ``var_names`` are gene identifiers
    GO_min_genes: minimum number of genes assigned to a GO term to keep it (default: 500)
    GO_max_genes: maximum number of genes assigned to a GO term to keep it (default: None)
    GO_min_level / GO_max_level: bounds on the GO term level (default: 3)
    data_dir: directory containing ``gene_association.mgi`` and ``go-basic.obo``
    return: dict mapping filtered GO term -> set of gene IDs
    """
    mgi2go = map_mgi2go(os.path.join(data_dir, "gene_association.mgi"))
    mgi2go_set = set(mgi2go.keys())
    adata_set = set(adata.var_names)
    print("_________Gene count processed_________")
    print("mgi2go_set", len(mgi2go_set))
    print("adata_set", len(adata_set))
    # BUG FIX: ``adata_set & mgi2go_set`` is the INTERSECTION of the two gene
    # sets; the original labelled it "union", which misreported the overlap.
    print("intersection", len(adata_set & mgi2go_set))
    # print("Not found", adata_set - mgi2go_set)
    GOobo_file = os.path.join(data_dir, "go-basic.obo")

    go2gene = prepare_GO_data(adata, mgi2go, GO_file=GOobo_file,
              GO_min_genes=GO_min_genes, GO_max_genes=GO_max_genes, GO_min_level=GO_min_level, GO_max_level=GO_max_level)

    return go2gene
class MacaData():
    """Loader/preprocessor for the Tabula Muris (senis) single-cell dataset.

    Reads an h5ad file, drops cells whose annotation is missing, converts
    the chosen annotation column to integer labels (stored in
    ``adata.obs['label']``), and applies a scanpy preprocessing pipeline.
    """

    def __init__(self, annotation_type='cell_ontology_class_reannotated', src_file = 'dataset/cell_data/tabula-muris-senis-facs-official-annotations.h5ad', filter_genes=True):

        """
        annotation_type: cell_ontology_class, cell_ontology id or free_annotation
        src_file: path to the AnnData (.h5ad) file to load
        filter_genes: if True, drop genes expressed in fewer than 5 cells
        """
        self.adata = read_h5ad(src_file)
        # drop cells whose annotation was serialized as the strings 'nan'/'NA'
        self.adata.obs[annotation_type] = self.adata.obs[annotation_type].astype(str)
        self.adata = self.adata[self.adata.obs[annotation_type]!='nan',:]
        self.adata = self.adata[self.adata.obs[annotation_type]!='NA',:]

        #print(Counter(self.adata.obs.loc[self.adata.obs['age']=='18m', 'free_annotation']))

        # annotation string -> integer label mapping; also writes
        # adata.obs['label'] as a side effect (see cellannotation2ID)
        self.cells2names = self.cellannotation2ID(annotation_type)

        if filter_genes:
            # keep only genes detected in at least 5 cells
            sc.pp.filter_genes(self.adata, min_cells=5)

        self.adata = self.preprocess_data(self.adata)


    def preprocess_data(self, adata):
        """Run the scanpy preprocessing pipeline and return the processed AnnData.

        Steps: cell filtering, library-size normalization, highly-variable
        gene selection, log1p transform, scaling, and NaN cleanup.
        NOTE(review): the step order (normalize before log/scale) is
        deliberate — do not reorder.
        """
        # discard low-quality cells
        sc.pp.filter_cells(adata, min_counts=5000)
        sc.pp.filter_cells(adata, min_genes=500)

        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4) #simple lib size normalization?
        # keep the normalized (pre-selection) matrix accessible via adata.raw
        adata.raw = adata
        adata = sc.pp.filter_genes_dispersion(adata, subset = False, min_disp=.5, max_disp=None,
                                              min_mean=.0125, max_mean=10, n_bins=20, n_top_genes=None,
                                              log=True, copy=True)
        # restrict to the highly-variable genes flagged above
        adata = adata[:,adata.var.highly_variable]
        sc.pp.log1p(adata)
        sc.pp.scale(adata, max_value=10, zero_center=True)
        # scaling can introduce NaNs for zero-variance genes; zero them out
        adata.X[np.isnan(adata.X)] = 0
        #sc.tl.pca(self.adata)

        return adata

    def get_tissue_data(self, tissue, age=None):
        """Select data for given tissue.

        tissue: tissue name to select (matched against adata.obs['tissue'])
        age: '3m','18m', '24m'; if None all ages are included
        return: AnnData view restricted to the requested tissue (and age)
        """

        tiss = self.adata[self.adata.obs['tissue'] == tissue,:]

        if age:
            return tiss[tiss.obs['age']==age]

        return tiss


    def cellannotation2ID(self, annotation_type):
        """Map annotation strings to integer labels (ground-truth clusters).

        Side effect: writes the integer labels to ``adata.obs['label']``
        as a pandas Categorical.
        return: dict mapping annotation string -> integer label
        """
        annotations = list(self.adata.obs[annotation_type])
        # sort for a deterministic string -> id assignment
        annotations_set = sorted(set(annotations))

        mapping = {a:idx for idx,a in enumerate(annotations_set)}

        truth_labels = [mapping[a] for a in annotations]
        self.adata.obs['label'] = pd.Categorical(values=truth_labels)
        #18m-unannotated
        #
        return mapping
11 | 12 | import torch_explain as te 13 | from torch_explain.logic.metrics import test_explanation, complexity 14 | from torch_explain.logic.nn import entropy, psi 15 | from torch_explain.nn.functional import prune_equal_fanin 16 | 17 | RESULTS_DIR = './results/L1_vs_entropy' 18 | # RESULTS_DIR = './experiments/results/L1_vs_entropy' 19 | seed_everything(51) 20 | 21 | def main(): 22 | os.makedirs(RESULTS_DIR, exist_ok=True) 23 | 24 | # eye, nose, window, wheel, hand, radio 25 | x = torch.tensor([ 26 | [0, 0, 0, 0], 27 | [0, 1, 0, 0], 28 | [1, 0, 0, 0], 29 | [1, 1, 0, 0], 30 | [0, 0, 0, 0], 31 | [0, 0, 0, 1], 32 | [0, 0, 1, 0], 33 | [0, 0, 1, 1], 34 | ], dtype=torch.float) 35 | # human, car 36 | y = torch.tensor([ # 1, 0, 0, 1], dtype=torch.long) 37 | [0, 1, 0, 1], 38 | [1, 0, 0, 1], 39 | [1, 0, 0, 1], 40 | [0, 1, 0, 1], 41 | [0, 1, 0, 1], 42 | [0, 1, 1, 0], 43 | [0, 1, 1, 0], 44 | [0, 1, 1, 0], 45 | ], dtype=torch.float) 46 | y1h = y # one_hot(y) 47 | 48 | layers = [ 49 | te.nn.EntropyLinear(x.shape[1], 20, n_classes=y1h.shape[1], temperature=0.3), 50 | torch.nn.LeakyReLU(), 51 | torch.nn.Linear(20, 10), 52 | torch.nn.LeakyReLU(), 53 | torch.nn.Linear(10, 1), 54 | ] 55 | model = torch.nn.Sequential(*layers) 56 | concept_names = ['x1', 'x2', 'x3', 'x4'] # , 'hand', 'radio'] 57 | target_class_names = ['y', '¬y', 'z', '¬z'] 58 | 59 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) 60 | loss_form = torch.nn.BCEWithLogitsLoss() 61 | model.train() 62 | 63 | max_epochs = 18001 64 | 65 | for epoch in range(max_epochs): 66 | # train step 67 | optimizer.zero_grad() 68 | y_pred = model(x).squeeze(-1) 69 | loss = loss_form(y_pred, y) + 0.0001 * te.nn.functional.entropy_logic_loss(model) 70 | # loss = loss_form(y_pred, y.argmax(dim=1)) + 0.00001 * te.nn.functional.entropy_logic_loss(model) 71 | # loss = loss_form(y_pred, y) + 0.001 * te.nn.functional.entropy_logic_loss(model) 72 | loss.backward() 73 | optimizer.step() 74 | 75 | # print() 76 | # 
print(layers[0].weight.grad[0].norm(dim=1)) 77 | # print(layers[0].weight.grad[1].norm(dim=1)) 78 | # print() 79 | 80 | # compute accuracy 81 | if epoch % 100 == 0: 82 | accuracy = (y_pred > 0.5).eq(y).sum().item() / (y.size(0) * y.size(1)) 83 | 84 | target_class = 0 85 | explanation_class_1, _ = entropy.explain_class(model, x, y1h, x, y1h, target_class, concept_names=concept_names) 86 | if explanation_class_1: explanation_class_1 = explanation_class_1.replace('&', '∧').replace('|', '∨').replace('~', '¬') 87 | explanation_class_1 = f'∀x: {explanation_class_1} ↔ {target_class_names[target_class]}' 88 | target_class = 1 89 | explanation_class_2, _ = entropy.explain_class(model, x, y1h, x, y1h, target_class, concept_names=concept_names) 90 | if explanation_class_2: explanation_class_2 = explanation_class_2.replace('&', '∧').replace('|', '∨').replace('~', '¬') 91 | explanation_class_2 = f'∀x: {explanation_class_2} ↔ {target_class_names[target_class]}' 92 | target_class = 2 93 | explanation_class_3, _ = entropy.explain_class(model, x, y1h, x, y1h, target_class, concept_names=concept_names) 94 | if explanation_class_3: explanation_class_3 = explanation_class_3.replace('&', '∧').replace('|', '∨').replace('~', '¬') 95 | explanation_class_3 = f'∀x: {explanation_class_3} ↔ {target_class_names[target_class]}' 96 | target_class = 3 97 | explanation_class_4, _ = entropy.explain_class(model, x, y1h, x, y1h, target_class, concept_names=concept_names) 98 | if explanation_class_4: explanation_class_4 = explanation_class_4.replace('&', '∧').replace('|', '∨').replace('~', '¬') 99 | explanation_class_4 = f'∀x: {explanation_class_4} ↔ {target_class_names[target_class]}' 100 | 101 | # update loss and accuracy 102 | print(f'Epoch {epoch}: loss {loss.item():.4f} train accuracy: {accuracy * 100:.2f}') 103 | print(f'\tAlphas class 1: {layers[0].alpha_norm}') 104 | print() 105 | 106 | df = pd.DataFrame(layers[0].alpha_norm.detach().numpy(), 107 | index=['y', '¬y', 'z', '¬z'], 108 | 
columns=['x1', 'x2', 'x3', 'x4']) 109 | 110 | plt.figure(figsize=[4, 3]) 111 | plt.title(r"Entropy concept scores $\tilde{\alpha}$") 112 | sns.heatmap(df, annot=True, fmt=".4f", vmin=0, vmax=1) 113 | plt.tight_layout() 114 | plt.savefig(os.path.join(RESULTS_DIR, 'entropy_heatmap.png')) 115 | plt.savefig(os.path.join(RESULTS_DIR, 'entropy_heatmap.pdf')) 116 | plt.show() 117 | 118 | 119 | layers = [ 120 | te.nn.EntropyLinear(x.shape[1], 20, n_classes=y1h.shape[1], temperature=0.3), 121 | torch.nn.LeakyReLU(), 122 | torch.nn.Linear(20, 10), 123 | torch.nn.LeakyReLU(), 124 | torch.nn.Linear(10, 1), 125 | ] 126 | model = torch.nn.Sequential(*layers) 127 | 128 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) 129 | loss_form = torch.nn.BCEWithLogitsLoss() 130 | model.train() 131 | 132 | for epoch in range(max_epochs): 133 | # train step 134 | optimizer.zero_grad() 135 | y_pred = model(x).squeeze(-1) 136 | loss = loss_form(y_pred, y) + 0.0001 * te.nn.functional.l1_loss(model) 137 | loss.backward() 138 | optimizer.step() 139 | 140 | # compute accuracy 141 | if epoch % 100 == 0: 142 | accuracy = (y_pred > 0.5).eq(y).sum().item() / (y.size(0) * y.size(1)) 143 | 144 | # update loss and accuracy 145 | print(f'Epoch {epoch}: loss {loss.item():.4f} train accuracy: {accuracy * 100:.2f}') 146 | print(f'\tAlphas class 1: {layers[0].alpha_norm}') 147 | print() 148 | 149 | df = pd.DataFrame(layers[0].alpha_norm.detach().numpy(), 150 | index=['y', '¬y', 'z', '¬z'], 151 | columns=['x1', 'x2', 'x3', 'x4']) 152 | 153 | plt.figure(figsize=[4, 3]) 154 | plt.title(r"L1$^*$ concept scores $\tilde{\alpha}$") 155 | sns.heatmap(df, annot=True, fmt=".4f", vmin=0, vmax=1) 156 | plt.tight_layout() 157 | plt.savefig(os.path.join(RESULTS_DIR, 'L1_heatmap.png')) 158 | plt.savefig(os.path.join(RESULTS_DIR, 'L1_heatmap.pdf')) 159 | plt.show() 160 | 161 | 162 | 163 | layers = [ 164 | torch.nn.Linear(x.shape[1], 20), 165 | torch.nn.LeakyReLU(), 166 | torch.nn.Linear(20, 10), 167 | 
torch.nn.LeakyReLU(), 168 | torch.nn.Linear(10, y.shape[1]), 169 | ] 170 | model = torch.nn.Sequential(*layers) 171 | 172 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) 173 | loss_form = torch.nn.BCEWithLogitsLoss() 174 | model.train() 175 | 176 | for epoch in range(max_epochs): 177 | # train step 178 | optimizer.zero_grad() 179 | y_pred = model(x).squeeze(-1) 180 | loss = loss_form(y_pred, y) + 0.0001 * te.nn.functional.l1_loss(model) 181 | loss.backward() 182 | optimizer.step() 183 | 184 | # compute accuracy 185 | if epoch % 100 == 0: 186 | accuracy = (y_pred > 0.5).eq(y).sum().item() / (y.size(0) * y.size(1)) 187 | 188 | # update loss and accuracy 189 | print(f'Epoch {epoch}: loss {loss.item():.4f} train accuracy: {accuracy * 100:.2f}') 190 | print(f'\tAlphas class 1: {layers[0].weight}') 191 | print() 192 | 193 | df = pd.DataFrame(layers[0].weight.detach().numpy(), 194 | index=[f'h{i}' for i in range(layers[0].weight.shape[0])], 195 | columns=['x1', 'x2', 'x3', 'x4']) 196 | 197 | 198 | plt.figure(figsize=[4, 6]) 199 | plt.title(r"L1 weights $W$") 200 | sns.heatmap(df, annot=True, fmt=".4f") 201 | plt.tight_layout() 202 | plt.savefig(os.path.join(RESULTS_DIR, 'L1_linear_heatmap.png')) 203 | plt.savefig(os.path.join(RESULTS_DIR, 'L1_linear_heatmap.pdf')) 204 | plt.show() 205 | 206 | # print(layers[0].alpha_norm) 207 | # print(layers[0].alpha_norm[0]) 208 | # print(layers[0].alpha_norm[1]) 209 | # print(layers[0].alpha_norm[2]) 210 | # print(layers[0].alpha_norm[3]) 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /experiments/elens/blackbox/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | from torch.utils.data import DataLoader 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning import Trainer, 
seed_everything 8 | 9 | from torch_explain.models.explainer import Explainer 10 | from torch_explain.logic.metrics import formula_consistency 11 | from experiments.data.load_datasets import load_cub 12 | 13 | train_data, val_data, test_data, concept_names = load_cub('../data') 14 | 15 | train_loader = DataLoader(train_data, batch_size=len(train_data)) 16 | val_loader = DataLoader(val_data, batch_size=len(val_data)) 17 | test_loader = DataLoader(test_data, batch_size=len(test_data)) 18 | n_concepts = next(iter(train_loader))[0].shape[1] 19 | n_classes = next(iter(train_loader))[1].shape[1] 20 | print(concept_names) 21 | print(n_concepts) 22 | print(n_classes) 23 | 24 | # %% md 25 | 26 | ## 5-fold cross-validation with explainer network 27 | 28 | base_dir = f'./results/CUB/blackbox' 29 | os.makedirs(base_dir, exist_ok=True) 30 | 31 | n_seeds = 5 32 | results_list = [] 33 | explanations = {i: [] for i in range(n_classes)} 34 | for seed in range(n_seeds): 35 | seed_everything(seed) 36 | print(f'Seed [{seed + 1}/{n_seeds}]') 37 | train_loader = DataLoader(train_data, batch_size=len(train_data)) 38 | val_loader = DataLoader(val_data, batch_size=len(val_data)) 39 | test_loader = DataLoader(test_data, batch_size=len(test_data)) 40 | 41 | checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1) 42 | trainer = Trainer(max_epochs=500, gpus=1, auto_lr_find=True, deterministic=True, 43 | check_val_every_n_epoch=1, default_root_dir=base_dir, 44 | weights_save_path=base_dir, callbacks=[checkpoint_callback]) 45 | model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01, explainer_hidden=[10]) 46 | 47 | trainer.fit(model, train_loader, val_loader) 48 | print(f"Concept mask: {model.model[0].concept_mask}") 49 | model.freeze() 50 | model_results = trainer.test(model, test_dataloaders=test_loader) 51 | for j in range(n_classes): 52 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 53 | print(f"Extracted concepts: 
{n_used_concepts}") 54 | results = {} 55 | results['model_accuracy'] = model_results[0]['test_acc'] 56 | 57 | results_list.append(results) 58 | 59 | results_df = pd.DataFrame(results_list) 60 | results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv')) 61 | 62 | results_df = pd.DataFrame(results_list) 63 | results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv')) 64 | results_df 65 | 66 | # %% 67 | 68 | results_df.mean() 69 | 70 | # %% 71 | 72 | results_df.sem() 73 | 74 | 75 | print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})') -------------------------------------------------------------------------------- /experiments/elens/blackbox/mimic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset, random_split 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning import Trainer, seed_everything 9 | from sklearn.ensemble import RandomForestClassifier 10 | from sklearn.tree import DecisionTreeClassifier 11 | from sklearn.model_selection import StratifiedKFold, train_test_split 12 | from sklearn.feature_selection import mutual_info_classif, chi2 13 | from sklearn.linear_model import LassoCV 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | 17 | from torch_explain.models.explainer import Explainer 18 | from torch_explain.logic.metrics import formula_consistency 19 | from experiments.data.load_datasets import load_mimic 20 | 21 | x, y, concept_names = load_mimic(base_dir='../data') 22 | 23 | dataset = TensorDataset(x, y) 24 | train_size = int(len(dataset) * 0.5) 25 | val_size = (len(dataset) - train_size) // 2 26 | test_size = len(dataset) - train_size - val_size 27 | train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size]) 28 | 
# Full-split loaders built from the 50/25/25 random split above.
train_loader = DataLoader(train_data, batch_size=train_size)
val_loader = DataLoader(val_data, batch_size=val_size)
test_loader = DataLoader(test_data, batch_size=test_size)
n_concepts = next(iter(train_loader))[0].shape[1]
n_classes = 2
print(concept_names)
print(n_concepts)
print(n_classes)

# %% md

## 5-fold cross-validation with explainer network

# %%

seed_everything(42)

base_dir = f'./results/mimic-ii/blackbox'
os.makedirs(base_dir, exist_ok=True)

# Stratified CV over the concept matrix; labels come from argmax of the
# one-hot targets y.
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
results_list = []
feature_selection = []
explanations = {i: [] for i in range(n_classes)}
for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(),
                                                               y.argmax(dim=1).cpu().detach().numpy())):
    print(f'Split [{split + 1}/{n_splits}]')
    x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
    y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index])
    x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.2, random_state=42)
    print(f'{len(y_train)}/{len(y_val)}/{len(y_test)}')

    # NOTE(review): this rebinds the module-level train_data/val_data/test_data
    # from the earlier random_split, and the batch sizes below still come from
    # that earlier split, not from this fold's sizes — confirm intended.
    train_data = TensorDataset(x_train, y_train)
    val_data = TensorDataset(x_val, y_val)
    test_data = TensorDataset(x_test, y_test)
    train_loader = DataLoader(train_data, batch_size=train_size)
    val_loader = DataLoader(val_data, batch_size=val_size)
    test_loader = DataLoader(test_data, batch_size=test_size)

    checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True,
                      check_val_every_n_epoch=1, default_root_dir=base_dir,
                      weights_save_path=base_dir, callbacks=[checkpoint_callback])
    # l1=0: black-box baseline, no sparsity pressure on the concept mask.
    model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01, explainer_hidden=[20])

    trainer.fit(model, train_loader, val_loader)
    print(f"Gamma: {model.model[0].concept_mask}")
    model.freeze()
    model_results = trainer.test(model, test_dataloaders=test_loader)
    # Count concepts whose mask weight exceeds 0.5 for each class (log only).
    for j in range(n_classes):
        n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
        print(f"Extracted concepts: {n_used_concepts}")
    results = {}
    results['model_accuracy'] = model_results[0]['test_acc']

    results_list.append(results)

results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_mimic.csv'))
# Notebook-style cell: displays the frame interactively, no effect as a script.
results_df

# %%

results_df.mean()

# %%

results_df.sem()

# %% md

## Compare with out-of-the-box models

# %%

# Sklearn baselines on the same CV folds (no inner train/val split needed).
dt_scores, rf_scores = [], []
for split, (trainval_index, test_index) in enumerate(
        skf.split(x.cpu().detach().numpy(), y.argmax(dim=1).cpu().detach().numpy())):
    print(f'Split [{split + 1}/{n_splits}]')
    x_trainval, x_test = x[trainval_index], x[test_index]
    y_trainval, y_test = y[trainval_index].argmax(dim=1), y[test_index].argmax(dim=1)

    dt_model = DecisionTreeClassifier(max_depth=5, random_state=split)
    dt_model.fit(x_trainval, y_trainval)
    dt_scores.append(dt_model.score(x_test, y_test))

    rf_model = RandomForestClassifier(random_state=split)
    rf_model.fit(x_trainval, y_trainval)
    rf_scores.append(rf_model.score(x_test, y_test))

print(f'Random forest scores: {np.mean(rf_scores)} (+/- {np.std(rf_scores)})')
print(f'Decision tree scores: {np.mean(dt_scores)} (+/- {np.std(dt_scores)})')
print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})')
--------------------------------------------------------------------------------
/experiments/elens/blackbox/mnist.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import time
from torch.utils.data import DataLoader, TensorDataset, random_split
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything

from torch_explain.models.explainer import Explainer
from torch_explain.logic.metrics import formula_consistency
from experiments.data.load_datasets import load_mnist

# %% md

## Import MNIST dataset

# %%
x, y, concept_names = load_mnist('../data')


# 90/5/5 random split, full-batch loaders.
dataset = TensorDataset(x, y)
train_size = int(len(dataset) * 0.9)
val_size = (len(dataset) - train_size) // 2
test_size = len(dataset) - train_size - val_size
train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_data, batch_size=len(train_data))
val_loader = DataLoader(val_data, batch_size=len(val_data))
test_loader = DataLoader(test_data, batch_size=len(test_data))
n_concepts = next(iter(train_loader))[0].shape[1]
n_classes = 2
print(concept_names)
print(n_concepts)
print(n_classes)

# %% md

## Repeated runs (5 seeds) with explainer network

base_dir = f'./results/MNIST/blackbox'
os.makedirs(base_dir, exist_ok=True)

# Black-box baseline: one training run per seed, report only test accuracy.
n_seeds = 5
results_list = []
explanations = {i: [] for i in range(n_classes)}
for seed in range(n_seeds):
    seed_everything(seed)
    print(f'Seed [{seed + 1}/{n_seeds}]')
    train_loader = DataLoader(train_data, batch_size=len(train_data))
    val_loader = DataLoader(val_data, batch_size=len(val_data))
    test_loader = DataLoader(test_data, batch_size=len(test_data))

    checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer = Trainer(max_epochs=10, gpus=1, auto_lr_find=True, deterministic=True,
                      check_val_every_n_epoch=1, default_root_dir=base_dir,
                      weights_save_path=base_dir, callbacks=[checkpoint_callback])
    # conceptizator='identity_bool': inputs are treated as boolean concepts.
    model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01,
                      explainer_hidden=[10], conceptizator='identity_bool')

    trainer.fit(model, train_loader, val_loader)
    print(f"Concept mask: {model.model[0].concept_mask}")
    model.freeze()
    model_results = trainer.test(model, test_dataloaders=test_loader)
    for j in range(n_classes):
        n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
        print(f"Extracted concepts: {n_used_concepts}")
    results = {}
    results['model_accuracy'] = model_results[0]['test_acc']

    results_list.append(results)

results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_mnist.csv'))
# Notebook-style cell: displays the frame interactively, no effect as a script.
results_df

# %%

results_df.mean()

# %%

results_df.sem()


print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})')
--------------------------------------------------------------------------------
/experiments/elens/blackbox/vdem.py:
--------------------------------------------------------------------------------
# %% md

# Varieties of Democracy (vDem)
import os
import pandas as pd
import numpy as np
import time
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.feature_selection import mutual_info_classif, chi2
from sklearn.linear_model import LassoCV
import matplotlib.pyplot as plt
import seaborn as sns

from torch_explain.models.explainer import Explainer
from torch_explain.logic.metrics import formula_consistency
from experiments.data.load_datasets import load_vDem

# %% md

## Import v-Dem dataset

# %%

# x: raw features, c: concept annotations, y: one-hot task labels.
x, c, y, concept_names = load_vDem('../data')

dataset_xc = TensorDataset(x, c)
dataset_cy = TensorDataset(c, y)

# 50/25/25 split sizes; reused as batch sizes throughout.
train_size = int(len(dataset_cy) * 0.5)
val_size = (len(dataset_cy) - train_size) // 2
test_size = len(dataset_cy) - train_size - val_size
train_data, val_data, test_data = random_split(dataset_cy, [train_size, val_size, test_size])
train_loader = DataLoader(train_data, batch_size=train_size)
val_loader = DataLoader(val_data, batch_size=val_size)
test_loader = DataLoader(test_data, batch_size=test_size)

n_concepts = next(iter(train_loader))[0].shape[1]
n_classes = 2

print(concept_names)
print(n_concepts)
print(n_classes)

# %% md

## 5-fold cross-validation with explainer network (n_splits = 5)

# %%

seed_everything(42)

base_dir = f'./results/vdem/blackbox'
os.makedirs(base_dir, exist_ok=True)

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
results_list = []
feature_selection = []
explanations = {i: [] for i in range(n_classes)}
for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(),
                                                               y.argmax(dim=1).cpu().detach().numpy())):
    print(f'Split [{split + 1}/{n_splits}]')
    x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
    c_trainval, c_test = torch.FloatTensor(c[trainval_index]), torch.FloatTensor(c[test_index])
    y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index])
    x_train, x_val, c_train, c_val, y_train, y_val = train_test_split(x_trainval, c_trainval, y_trainval,
                                                                      test_size=0.2, random_state=42)
    print(f'{len(y_train)}/{len(y_val)}/{len(y_test)}')

    # train X->C: predict concept annotations from raw features.
    train_data_xc = TensorDataset(x_train, c_train)
    val_data_xc = TensorDataset(x_val, c_val)
    test_data_xc = TensorDataset(x_test, c_test)
    # NOTE(review): batch sizes are the global split sizes, not this fold's
    # sizes — confirm intended.
    train_loader_xc = DataLoader(train_data_xc, batch_size=train_size)
    val_loader_xc = DataLoader(val_data_xc, batch_size=val_size)
    test_loader_xc = DataLoader(test_data_xc, batch_size=test_size)

    checkpoint_callback_xc = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer_xc = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True,
                         check_val_every_n_epoch=1, default_root_dir=base_dir + '_xc',
                         weights_save_path=base_dir, callbacks=[checkpoint_callback_xc])
    model_xc = Explainer(n_concepts=x.shape[1], n_classes=c.shape[1], l1=0, lr=0.01,
                         explainer_hidden=[100, 50], temperature=5000, loss=torch.nn.BCEWithLogitsLoss())
    trainer_xc.fit(model_xc, train_loader_xc, val_loader_xc)
    model_xc.freeze()
    # Concept predictions from the frozen X->C model feed the second stage.
    c_train_pred = model_xc.model(x_train)
    c_val_pred = model_xc.model(x_val)
    c_test_pred = model_xc.model(x_test)

    # train C->Y: predict task labels from the predicted concepts.
    train_data = TensorDataset(c_train_pred.squeeze(), y_train)
    val_data = TensorDataset(c_val_pred.squeeze(), y_val)
    test_data = TensorDataset(c_test_pred.squeeze(), y_test)
    train_loader = DataLoader(train_data, batch_size=train_size)
    val_loader = DataLoader(val_data, batch_size=val_size)
    test_loader = DataLoader(test_data, batch_size=test_size)

    checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True,
                      check_val_every_n_epoch=1, default_root_dir=base_dir,
                      weights_save_path=base_dir, callbacks=[checkpoint_callback])
    model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01, explainer_hidden=[20, 20])

    trainer.fit(model, train_loader, val_loader)
    model.freeze()
    model_results = trainer.test(model, test_dataloaders=test_loader)
    for j in range(n_classes):
        n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
        print(f"Extracted concepts: {n_used_concepts}")
    results = {}
    results['model_accuracy'] = model_results[0]['test_acc']

    results_list.append(results)

results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_vdem.csv'))
# Notebook-style cell: displays the frame interactively, no effect as a script.
results_df

# %%

results_df.mean()

# %%

results_df.sem()

# %% md

## Compare with out-of-the-box models

# %%

# Sklearn baselines on the same folds, trained on raw features X.
dt_scores, rf_scores = [], []
for split, (trainval_index, test_index) in enumerate(
        skf.split(x.cpu().detach().numpy(), y.argmax(dim=1).cpu().detach().numpy())):
    print(f'Split [{split + 1}/{n_splits}]')
    x_trainval, x_test = x[trainval_index], x[test_index]
    y_trainval, y_test = y[trainval_index].argmax(dim=1), y[test_index].argmax(dim=1)

    dt_model = DecisionTreeClassifier(max_depth=5, random_state=split)
    dt_model.fit(x_trainval, y_trainval)
    dt_scores.append(dt_model.score(x_test, y_test))

    rf_model = RandomForestClassifier(random_state=split)
    rf_model.fit(x_trainval, y_trainval)
    rf_scores.append(rf_model.score(x_test, y_test))

print(f'Random forest scores: {np.mean(rf_scores)} (+/- {np.std(rf_scores)})')
print(f'Decision tree scores: {np.mean(dt_scores)} (+/- {np.std(dt_scores)})')
print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})')
--------------------------------------------------------------------------------
/experiments/elens/cub.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import time
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything

from torch_explain.models.explainer import Explainer
from torch_explain.logic.metrics import formula_consistency
from experiments.data.load_datasets import load_cub

train_data, val_data, test_data, concept_names = load_cub()

# Full-batch loaders over the pre-made CUB splits.
train_loader = DataLoader(train_data, batch_size=len(train_data))
val_loader = DataLoader(val_data, batch_size=len(val_data))
test_loader = DataLoader(test_data, batch_size=len(test_data))
n_concepts = next(iter(train_loader))[0].shape[1]
n_classes = next(iter(train_loader))[1].shape[1]
print(concept_names)
print(n_concepts)
print(n_classes)

# %% md

## 5-fold cross-validation with explainer network

base_dir = f'./results/CUB/explainer'
os.makedirs(base_dir, exist_ok=True)

# Explainer run: l1 > 0 enforces sparsity, and logic explanations are
# extracted after each training run.
n_seeds = 5
results_list = []
explanations = {i: [] for i in range(n_classes)}
for seed in range(n_seeds):
    seed_everything(seed)
    print(f'Seed [{seed + 1}/{n_seeds}]')
    train_loader = DataLoader(train_data, batch_size=len(train_data))
    val_loader = DataLoader(val_data, batch_size=len(val_data))
    test_loader = DataLoader(test_data, batch_size=len(test_data))

    checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer = Trainer(max_epochs=500, gpus=1, auto_lr_find=True, deterministic=True,
                      check_val_every_n_epoch=1, default_root_dir=base_dir,
                      weights_save_path=base_dir, callbacks=[checkpoint_callback])
    model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0.0001,
                      temperature=0.7, lr=0.01, explainer_hidden=[10])

    # Timing covers training + explanation extraction.
    start = time.time()
    trainer.fit(model, train_loader, val_loader)
    print(f"Concept mask: {model.model[0].concept_mask}")
    model.freeze()
    model_results = trainer.test(model, test_dataloaders=test_loader)
    for j in range(n_classes):
        n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
        print(f"Extracted concepts: {n_used_concepts}")
    results, f = model.explain_class(val_loader, train_loader, test_loader, topk_explanations=50,
                                     concept_names=concept_names, verbose=True)
    end = time.time() - start
    results['model_accuracy'] = model_results[0]['test_acc']
    results['extraction_time'] = end

    results_list.append(results)
    # Track per-class concept usage and the union/intersection of used
    # concepts across classes.
    extracted_concepts = []
    all_concepts = model.model[0].concept_mask[0] > 0.5
    common_concepts = model.model[0].concept_mask[0] > 0.5
    for j in range(n_classes):
        n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
        print(f"Extracted concepts: {n_used_concepts}")
        print(f"Explanation: {f[j]['explanation']}")
        print(f"Explanation accuracy: {f[j]['explanation_accuracy']}")
        # NOTE(review): the collapsed dump does not preserve indentation —
        # reconstructed with both appends inside the guard; confirm against
        # the repository history.
        if f[j]['explanation'] is not None:
            explanations[j].append(f[j]['explanation'])
            extracted_concepts.append(n_used_concepts)
        all_concepts += model.model[0].concept_mask[j] > 0.5
        common_concepts *= model.model[0].concept_mask[j] > 0.5

    results['extracted_concepts'] = np.mean(extracted_concepts)
    results['common_concepts_ratio'] = sum(common_concepts) / sum(all_concepts)
    # break

results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv'))

# Formula consistency across seeds, averaged over classes.
consistencies = []
for j in range(n_classes):
    consistencies.append(formula_consistency(explanations[j]))
explanation_consistency = np.mean(consistencies)

results_df = pd.DataFrame(results_list)
results_df['explanation_consistency'] = explanation_consistency
results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv'))
# Notebook-style cell: displays the frame interactively, no effect as a script.
results_df

# %%

results_df.mean()

# %%

results_df.sem()
--------------------------------------------------------------------------------
/experiments/elens/hyperparams/cub.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import time
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything

from torch_explain.models.explainer import Explainer
from torch_explain.logic.metrics import formula_consistency
from experiments.data.load_datasets import load_cub

train_data, val_data, test_data, concept_names = load_cub('../data')

train_loader = DataLoader(train_data, batch_size=len(train_data))
val_loader = DataLoader(val_data, batch_size=len(val_data))
test_loader = DataLoader(test_data, batch_size=len(test_data))
n_concepts = next(iter(train_loader))[0].shape[1]
n_classes = next(iter(train_loader))[1].shape[1]
print(concept_names)
print(n_concepts)
print(n_classes)

# %% md

## 5-fold cross-validation with explainer network

base_dir = f'./results/CUB/blackbox'
os.makedirs(base_dir, exist_ok=True)

n_seeds = 5
results_list = []
explanations = {i: [] for i in range(n_classes)}
for seed in range(n_seeds):
    seed_everything(seed)
    print(f'Seed [{seed + 1}/{n_seeds}]')
    train_loader = DataLoader(train_data, batch_size=len(train_data))
    val_loader = DataLoader(val_data, batch_size=len(val_data))
    test_loader = DataLoader(test_data, batch_size=len(test_data))

    checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer = Trainer(max_epochs=500, gpus=1, auto_lr_find=True, deterministic=True,
                      check_val_every_n_epoch=1, default_root_dir=base_dir,
                      weights_save_path=base_dir, callbacks=[checkpoint_callback])
    model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01, explainer_hidden=[10])

    trainer.fit(model, train_loader, val_loader)
    print(f"Concept mask: {model.model[0].concept_mask}")
    model.freeze()
    model_results = trainer.test(model, test_dataloaders=test_loader)
    for j in range(n_classes):
        n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
        print(f"Extracted concepts: {n_used_concepts}")
    results = {}
    results['model_accuracy'] = model_results[0]['test_acc']

    results_list.append(results)

results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv'))

# NOTE(review): the DataFrame is rebuilt and the same CSV rewritten here — a
# duplicate of the two lines above; harmless but redundant.
results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv'))
results_df

# %%

results_df.mean()

# %%

results_df.sem()


print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})')
--------------------------------------------------------------------------------
/experiments/elens/hyperparams/mnist.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import time
from torch.utils.data import DataLoader, TensorDataset, random_split
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything

from torch_explain.models.explainer import Explainer
from torch_explain.logic.metrics import formula_consistency
from experiments.data.load_datasets import load_mnist

# %% md

## Import MNIST dataset

# %%
x, y, concept_names = load_mnist('../data')


dataset = TensorDataset(x, y)
train_size = int(len(dataset) * 0.9) 23 | val_size = (len(dataset) - train_size) // 2 24 | test_size = len(dataset) - train_size - val_size 25 | train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size]) 26 | train_loader = DataLoader(train_data, batch_size=len(train_data)) 27 | val_loader = DataLoader(val_data, batch_size=len(val_data)) 28 | test_loader = DataLoader(test_data, batch_size=len(test_data)) 29 | n_concepts = next(iter(train_loader))[0].shape[1] 30 | n_classes = 2 31 | print(concept_names) 32 | print(n_concepts) 33 | print(n_classes) 34 | 35 | # %% md 36 | 37 | ## 5-fold cross-validation with explainer network 38 | 39 | base_dir = f'./results/MNIST/blackbox' 40 | os.makedirs(base_dir, exist_ok=True) 41 | 42 | n_seeds = 5 43 | results_list = [] 44 | explanations = {i: [] for i in range(n_classes)} 45 | for seed in range(n_seeds): 46 | seed_everything(seed) 47 | print(f'Seed [{seed + 1}/{n_seeds}]') 48 | train_loader = DataLoader(train_data, batch_size=len(train_data)) 49 | val_loader = DataLoader(val_data, batch_size=len(val_data)) 50 | test_loader = DataLoader(test_data, batch_size=len(test_data)) 51 | 52 | checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1) 53 | trainer = Trainer(max_epochs=10, gpus=1, auto_lr_find=True, deterministic=True, 54 | check_val_every_n_epoch=1, default_root_dir=base_dir, 55 | weights_save_path=base_dir, callbacks=[checkpoint_callback]) 56 | model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01, 57 | explainer_hidden=[10], conceptizator='identity_bool') 58 | 59 | trainer.fit(model, train_loader, val_loader) 60 | print(f"Concept mask: {model.model[0].concept_mask}") 61 | model.freeze() 62 | model_results = trainer.test(model, test_dataloaders=test_loader) 63 | for j in range(n_classes): 64 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 65 | print(f"Extracted concepts: {n_used_concepts}") 66 | results = {} 67 | 
results['model_accuracy'] = model_results[0]['test_acc'] 68 | 69 | results_list.append(results) 70 | 71 | results_df = pd.DataFrame(results_list) 72 | results_df.to_csv(os.path.join(base_dir, 'results_aware_mnist.csv')) 73 | results_df 74 | 75 | # %% 76 | 77 | results_df.mean() 78 | 79 | # %% 80 | 81 | results_df.sem() 82 | 83 | 84 | print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})') -------------------------------------------------------------------------------- /experiments/elens/hyperparams/vdem.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset, random_split 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning import Trainer, seed_everything 9 | from sklearn.ensemble import RandomForestClassifier 10 | from sklearn.tree import DecisionTreeClassifier 11 | from sklearn.model_selection import StratifiedKFold, train_test_split 12 | from sklearn.feature_selection import mutual_info_classif, chi2 13 | from sklearn.linear_model import LassoCV 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | 17 | from torch_explain.models.explainer import Explainer 18 | from torch_explain.logic.metrics import formula_consistency 19 | from experiments.data.load_datasets import load_vDem 20 | 21 | 22 | x, c, y, concept_names = load_vDem('../data') 23 | 24 | dataset_xc = TensorDataset(x, c) 25 | dataset_cy = TensorDataset(c, y) 26 | 27 | train_size = int(len(dataset_cy) * 0.5) 28 | val_size = (len(dataset_cy) - train_size) // 2 29 | test_size = len(dataset_cy) - train_size - val_size 30 | train_data, val_data, test_data = random_split(dataset_cy, [train_size, val_size, test_size]) 31 | train_loader = DataLoader(train_data, batch_size=train_size) 32 | val_loader = DataLoader(val_data, 
batch_size=val_size) 33 | test_loader = DataLoader(test_data, batch_size=test_size) 34 | 35 | n_concepts = next(iter(train_loader))[0].shape[1] 36 | n_classes = 2 37 | 38 | print(concept_names) 39 | print(n_concepts) 40 | print(n_classes) 41 | 42 | # %% md 43 | 44 | ## 5-fold cross-validation with explainer network 45 | 46 | # %% 47 | 48 | seed_everything(42) 49 | 50 | base_dir = f'./results/vdem/blackbox' 51 | os.makedirs(base_dir, exist_ok=True) 52 | 53 | n_splits = 5 54 | results_list = [] 55 | 56 | for l1 in [1e-3, 1e-4, 1e-5, 1e-6]: 57 | for tau in [4, 5, 6]: 58 | skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) 59 | for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), 60 | y.argmax(dim=1).cpu().detach().numpy())): 61 | print(f'Split [{split + 1}/{n_splits}]') 62 | x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index]) 63 | c_trainval, c_test = torch.FloatTensor(c[trainval_index]), torch.FloatTensor(c[test_index]) 64 | y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index]) 65 | x_train, x_val, c_train, c_val, y_train, y_val = train_test_split(x_trainval, c_trainval, y_trainval, 66 | test_size=0.2, random_state=42) 67 | print(f'{len(y_train)}/{len(y_val)}/{len(y_test)}') 68 | 69 | # train X->C 70 | train_data_xc = TensorDataset(x_train, c_train) 71 | val_data_xc = TensorDataset(x_val, c_val) 72 | test_data_xc = TensorDataset(x_test, c_test) 73 | train_loader_xc = DataLoader(train_data_xc, batch_size=train_size) 74 | val_loader_xc = DataLoader(val_data_xc, batch_size=val_size) 75 | test_loader_xc = DataLoader(test_data_xc, batch_size=test_size) 76 | 77 | checkpoint_callback_xc = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1) 78 | trainer_xc = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True, 79 | check_val_every_n_epoch=1, default_root_dir=base_dir + '_xc', 80 | weights_save_path=base_dir, 
callbacks=[checkpoint_callback_xc]) 81 | model_xc = Explainer(n_concepts=x.shape[1], n_classes=c.shape[1], l1=0, lr=0.01, 82 | explainer_hidden=[100, 50], temperature=5000, loss=torch.nn.BCEWithLogitsLoss()) 83 | trainer_xc.fit(model_xc, train_loader_xc, val_loader_xc) 84 | model_xc.freeze() 85 | c_train_pred = model_xc.model(x_train) 86 | c_val_pred = model_xc.model(x_val) 87 | c_test_pred = model_xc.model(x_test) 88 | 89 | # train C->Y 90 | train_data = TensorDataset(c_train_pred.squeeze(), y_train) 91 | val_data = TensorDataset(c_val_pred.squeeze(), y_val) 92 | test_data = TensorDataset(c_test_pred.squeeze(), y_test) 93 | train_loader = DataLoader(train_data, batch_size=train_size) 94 | val_loader = DataLoader(val_data, batch_size=val_size) 95 | test_loader = DataLoader(test_data, batch_size=test_size) 96 | 97 | checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1) 98 | trainer = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True, 99 | check_val_every_n_epoch=1, default_root_dir=base_dir, 100 | weights_save_path=base_dir, callbacks=[checkpoint_callback]) 101 | model = Explainer(n_concepts=n_concepts, n_classes=n_classes, lr=0.01, explainer_hidden=[20, 20], 102 | temperature=tau, l1=l1) 103 | 104 | start = time.time() 105 | trainer.fit(model, train_loader, val_loader) 106 | print(f"Gamma: {model.model[0].concept_mask}") 107 | model.freeze() 108 | model_results = trainer.test(model, test_dataloaders=test_loader) 109 | for j in range(n_classes): 110 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 111 | print(f"Extracted concepts: {n_used_concepts}") 112 | results, f = model.explain_class(val_loader, train_loader, test_loader, 113 | topk_explanations=10, 114 | concept_names=concept_names) 115 | end = time.time() - start 116 | results['model_accuracy'] = model_results[0]['test_acc'] 117 | results['extraction_time'] = end 118 | results['tau'] = tau 119 | results['lambda'] = l1 120 | 121 | 
results_list.append(results) 122 | 123 | results_df = pd.DataFrame(results_list) 124 | results_df.to_csv(os.path.join(base_dir, f'results_aware_vdem_l_{l1}_tau_{tau}.csv')) 125 | 126 | 127 | print(results_list) 128 | -------------------------------------------------------------------------------- /experiments/elens/mimic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset, random_split 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning import Trainer, seed_everything 9 | from sklearn.ensemble import RandomForestClassifier 10 | from sklearn.tree import DecisionTreeClassifier 11 | from sklearn.model_selection import StratifiedKFold, train_test_split 12 | from sklearn.feature_selection import mutual_info_classif, chi2 13 | from sklearn.linear_model import LassoCV 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | 17 | from torch_explain.models.explainer import Explainer 18 | from torch_explain.logic.metrics import formula_consistency 19 | from experiments.data.load_datasets import load_mimic 20 | 21 | x, y, concept_names = load_mimic() 22 | 23 | dataset = TensorDataset(x, y) 24 | train_size = int(len(dataset) * 0.5) 25 | val_size = (len(dataset) - train_size) // 2 26 | test_size = len(dataset) - train_size - val_size 27 | train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size]) 28 | train_loader = DataLoader(train_data, batch_size=train_size) 29 | val_loader = DataLoader(val_data, batch_size=val_size) 30 | test_loader = DataLoader(test_data, batch_size=test_size) 31 | n_concepts = next(iter(train_loader))[0].shape[1] 32 | n_classes = 2 33 | print(concept_names) 34 | print(n_concepts) 35 | print(n_classes) 36 | 37 | # %% md 38 | 39 | ## 5-fold cross-validation with explainer network 40 | 41 | # %% 42 
| 43 | seed_everything(42) 44 | 45 | base_dir = f'./results/mimic-ii/explainer' 46 | os.makedirs(base_dir, exist_ok=True) 47 | 48 | n_splits = 5 49 | skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) 50 | results_list = [] 51 | feature_selection = [] 52 | explanations = {i: [] for i in range(n_classes)} 53 | for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), 54 | y.argmax(dim=1).cpu().detach().numpy())): 55 | print(f'Split [{split + 1}/{n_splits}]') 56 | x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index]) 57 | y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index]) 58 | x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.2, random_state=42) 59 | print(f'{len(y_train)}/{len(y_val)}/{len(y_test)}') 60 | 61 | train_data = TensorDataset(x_train, y_train) 62 | val_data = TensorDataset(x_val, y_val) 63 | test_data = TensorDataset(x_test, y_test) 64 | train_loader = DataLoader(train_data, batch_size=train_size) 65 | val_loader = DataLoader(val_data, batch_size=val_size) 66 | test_loader = DataLoader(test_data, batch_size=test_size) 67 | 68 | checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1) 69 | trainer = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True, 70 | check_val_every_n_epoch=1, default_root_dir=base_dir, 71 | weights_save_path=base_dir, callbacks=[checkpoint_callback]) 72 | model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=1e-3, lr=0.01, 73 | explainer_hidden=[20], temperature=0.7) 74 | 75 | start = time.time() 76 | trainer.fit(model, train_loader, val_loader) 77 | print(f"Gamma: {model.model[0].concept_mask}") 78 | model.freeze() 79 | model_results = trainer.test(model, test_dataloaders=test_loader) 80 | for j in range(n_classes): 81 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 82 | print(f"Extracted 
concepts: {n_used_concepts}") 83 | results, f = model.explain_class(val_loader, train_loader, test_loader, 84 | topk_explanations=10, 85 | concept_names=concept_names) 86 | end = time.time() - start 87 | results['model_accuracy'] = model_results[0]['test_acc'] 88 | results['extraction_time'] = end 89 | 90 | results_list.append(results) 91 | extracted_concepts = [] 92 | all_concepts = model.model[0].concept_mask[0] > 0.5 93 | common_concepts = model.model[0].concept_mask[0] > 0.5 94 | for j in range(n_classes): 95 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 96 | print(f"Extracted concepts: {n_used_concepts}") 97 | print(f"Explanation: {f[j]['explanation']}") 98 | print(f"Explanation accuracy: {f[j]['explanation_accuracy']}") 99 | explanations[j].append(f[j]['explanation']) 100 | extracted_concepts.append(n_used_concepts) 101 | all_concepts += model.model[0].concept_mask[j] > 0.5 102 | common_concepts *= model.model[0].concept_mask[j] > 0.5 103 | 104 | results['extracted_concepts'] = np.mean(extracted_concepts) 105 | results['common_concepts_ratio'] = sum(common_concepts) / sum(all_concepts) 106 | 107 | # compare against standard feature selection 108 | i_mutual_info = mutual_info_classif(x_trainval, y_trainval[:, 1]) 109 | i_chi2 = chi2(x_trainval, y_trainval[:, 1])[0] 110 | i_chi2[np.isnan(i_chi2)] = 0 111 | lasso = LassoCV(cv=5, random_state=0).fit(x_trainval, y_trainval[:, 1]) 112 | i_lasso = np.abs(lasso.coef_) 113 | i_mu = model.model[0].concept_mask[1] 114 | df = pd.DataFrame(np.hstack([ 115 | i_mu.numpy(), 116 | i_mutual_info / np.max(i_mutual_info), 117 | i_chi2 / np.max(i_chi2), 118 | i_lasso / np.max(i_lasso), 119 | ]).T, columns=['feature importance']) 120 | df['method'] = 'explainer' 121 | df.iloc[90:, 1] = 'MI' 122 | df.iloc[180:, 1] = 'CHI2' 123 | df.iloc[270:, 1] = 'Lasso' 124 | df['feature'] = np.hstack([np.arange(0, 90)] * 4) 125 | feature_selection.append(df) 126 | 127 | consistencies = [] 128 | for j in range(n_classes): 129 | 
consistencies.append(formula_consistency(explanations[j])) 130 | explanation_consistency = np.mean(consistencies) 131 | 132 | feature_selection = pd.concat(feature_selection, axis=0) 133 | 134 | # %% md 135 | 136 | ## Print results 137 | 138 | # %% 139 | 140 | f1 = feature_selection[feature_selection['feature'] <= 30] 141 | f2 = feature_selection[(feature_selection['feature'] > 30) & (feature_selection['feature'] <= 60)] 142 | f3 = feature_selection[feature_selection['feature'] > 60] 143 | 144 | # %% 145 | 146 | plt.figure(figsize=[10, 10]) 147 | plt.subplot(1, 3, 1) 148 | ax = sns.barplot(y=f1['feature'], x=f1.iloc[:, 0], 149 | hue=f1['method'], orient='h', errwidth=0.5, errcolor='k') 150 | ax.get_legend().remove() 151 | plt.subplot(1, 3, 2) 152 | ax = sns.barplot(y=f2['feature'], x=f2.iloc[:, 0], 153 | hue=f2['method'], orient='h', errwidth=0.5, errcolor='k') 154 | plt.xlabel('') 155 | ax.get_legend().remove() 156 | plt.subplot(1, 3, 3) 157 | sns.barplot(y=f3['feature'], x=f3.iloc[:, 0], 158 | hue=f3['method'], orient='h', errwidth=0.5, errcolor='k') 159 | plt.xlabel('') 160 | plt.tight_layout() 161 | plt.savefig(os.path.join(base_dir, 'barplot_mimic.png')) 162 | plt.savefig(os.path.join(base_dir, 'barplot_mimic.pdf')) 163 | plt.show() 164 | 165 | # %% 166 | 167 | plt.figure(figsize=[6, 4]) 168 | sns.boxplot(x=feature_selection.iloc[:, 1], y=feature_selection.iloc[:, 0]) 169 | plt.tight_layout() 170 | plt.savefig(os.path.join(base_dir, 'boxplot_mimic.png')) 171 | plt.savefig(os.path.join(base_dir, 'boxplot_mimic.pdf')) 172 | plt.show() 173 | 174 | # %% 175 | 176 | results_df = pd.DataFrame(results_list) 177 | results_df['explanation_consistency'] = explanation_consistency 178 | results_df.to_csv(os.path.join(base_dir, 'results_aware_mimic.csv')) 179 | results_df 180 | 181 | # %% 182 | 183 | results_df.mean() 184 | 185 | # %% 186 | 187 | results_df.sem() 188 | 189 | # %% md 190 | 191 | ## Compare with out-of-the-box models 192 | 193 | # %% 194 | 195 | dt_scores, 
rf_scores = [], [] 196 | for split, (trainval_index, test_index) in enumerate( 197 | skf.split(x.cpu().detach().numpy(), y.argmax(dim=1).cpu().detach().numpy())): 198 | print(f'Split [{split + 1}/{n_splits}]') 199 | x_trainval, x_test = x[trainval_index], x[test_index] 200 | y_trainval, y_test = y[trainval_index].argmax(dim=1), y[test_index].argmax(dim=1) 201 | 202 | dt_model = DecisionTreeClassifier(max_depth=5, random_state=split) 203 | dt_model.fit(x_trainval, y_trainval) 204 | dt_scores.append(dt_model.score(x_test, y_test)) 205 | 206 | rf_model = RandomForestClassifier(random_state=split) 207 | rf_model.fit(x_trainval, y_trainval) 208 | rf_scores.append(rf_model.score(x_test, y_test)) 209 | 210 | print(f'Random forest scores: {np.mean(rf_scores)} (+/- {np.std(rf_scores)})') 211 | print(f'Decision tree scores: {np.mean(dt_scores)} (+/- {np.std(dt_scores)})') 212 | print(f'Mu net scores (model): {results_df["model_accuracy"].mean()} (+/- {results_df["model_accuracy"].std()})') 213 | print( 214 | f'Mu net scores (exp): {results_df["explanation_accuracy"].mean()} (+/- {results_df["explanation_accuracy"].std()})') 215 | -------------------------------------------------------------------------------- /experiments/elens/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | from torch.utils.data import DataLoader, TensorDataset, random_split 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning import Trainer, seed_everything 8 | 9 | from torch_explain.models.explainer import Explainer 10 | from torch_explain.logic.metrics import formula_consistency 11 | from experiments.data.load_datasets import load_mnist 12 | 13 | # %% md 14 | 15 | ## Import MIMIC-II dataset 16 | 17 | # %% 18 | x, y, concept_names = load_mnist() 19 | 20 | 21 | dataset = TensorDataset(x, y) 22 | train_size = int(len(dataset) * 0.9) 23 | val_size = 
(len(dataset) - train_size) // 2 24 | test_size = len(dataset) - train_size - val_size 25 | train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size]) 26 | train_loader = DataLoader(train_data, batch_size=len(train_data)) 27 | val_loader = DataLoader(val_data, batch_size=len(val_data)) 28 | test_loader = DataLoader(test_data, batch_size=len(test_data)) 29 | n_concepts = next(iter(train_loader))[0].shape[1] 30 | n_classes = 2 31 | print(concept_names) 32 | print(n_concepts) 33 | print(n_classes) 34 | 35 | # %% md 36 | 37 | ## 5-fold cross-validation with explainer network 38 | 39 | base_dir = f'./results/MNIST/explainer' 40 | os.makedirs(base_dir, exist_ok=True) 41 | 42 | n_seeds = 5 43 | results_list = [] 44 | explanations = {i: [] for i in range(n_classes)} 45 | for seed in range(n_seeds): 46 | seed_everything(seed) 47 | print(f'Seed [{seed + 1}/{n_seeds}]') 48 | train_loader = DataLoader(train_data, batch_size=len(train_data)) 49 | val_loader = DataLoader(val_data, batch_size=len(val_data)) 50 | test_loader = DataLoader(test_data, batch_size=len(test_data)) 51 | 52 | checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1) 53 | trainer = Trainer(max_epochs=10, gpus=1, auto_lr_find=True, deterministic=True, 54 | check_val_every_n_epoch=1, default_root_dir=base_dir, 55 | weights_save_path=base_dir, callbacks=[checkpoint_callback]) 56 | model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0.0000001, temperature=5, lr=0.01, 57 | explainer_hidden=[10], conceptizator='identity_bool') 58 | 59 | start = time.time() 60 | trainer.fit(model, train_loader, val_loader) 61 | print(f"Concept mask: {model.model[0].concept_mask}") 62 | model.freeze() 63 | model_results = trainer.test(model, test_dataloaders=test_loader) 64 | for j in range(n_classes): 65 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 66 | print(f"Extracted concepts: {n_used_concepts}") 67 | results, f = 
model.explain_class(val_loader, val_loader, test_loader, topk_explanations=5, 68 | x_to_bool=None, max_accuracy=True, concept_names=concept_names) 69 | end = time.time() - start 70 | results['model_accuracy'] = model_results[0]['test_acc'] 71 | results['extraction_time'] = end 72 | 73 | results_list.append(results) 74 | extracted_concepts = [] 75 | all_concepts = model.model[0].concept_mask[0] > 0.5 76 | common_concepts = model.model[0].concept_mask[0] > 0.5 77 | for j in range(n_classes): 78 | n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5) 79 | print(f"Extracted concepts: {n_used_concepts}") 80 | print(f"Explanation: {f[j]['explanation']}") 81 | print(f"Explanation accuracy: {f[j]['explanation_accuracy']}") 82 | explanations[j].append(f[j]['explanation']) 83 | extracted_concepts.append(n_used_concepts) 84 | all_concepts += model.model[0].concept_mask[j] > 0.5 85 | common_concepts *= model.model[0].concept_mask[j] > 0.5 86 | 87 | results['extracted_concepts'] = np.mean(extracted_concepts) 88 | results['common_concepts_ratio'] = sum(common_concepts) / sum(all_concepts) 89 | 90 | consistencies = [] 91 | for j in range(n_classes): 92 | consistencies.append(formula_consistency(explanations[j])) 93 | explanation_consistency = np.mean(consistencies) 94 | 95 | results_df = pd.DataFrame(results_list) 96 | results_df['explanation_consistency'] = explanation_consistency 97 | results_df.to_csv(os.path.join(base_dir, 'results_aware_mnist.csv')) 98 | results_df 99 | 100 | # %% 101 | 102 | results_df.mean() 103 | 104 | # %% 105 | 106 | results_df.sem() 107 | -------------------------------------------------------------------------------- /experiments/elens/results_old/CUB/explainer/results_aware_cub.csv: -------------------------------------------------------------------------------- 1 | ,explanation_accuracy,explanation_fidelity,explanation_complexity,model_accuracy,extraction_time,extracted_concepts,common_concepts_ratio,explanation_consistency 2 | 
0,0.9533668557045459,0.9985423728813558,3.725,0.9271186590194702,174.71739959716797,15.31,tensor(0.),0.3551646090321788 3 | 1,0.9533697419817596,0.9987288135593221,3.71,0.9322033524513245,169.01225447654724,15.06,tensor(0.),0.3551646090321788 4 | 2,0.9523896598887611,0.9984999999999999,3.825,0.9355931878089905,166.97806310653687,14.55,tensor(0.),0.3551646090321788 5 | 3,0.9506608260858387,0.9984237288135592,3.665,0.9237288236618042,170.8761112689972,14.635,tensor(0.),0.3551646090321788 6 | 4,0.9524511272288261,0.9986186440677965,3.755,0.9288135766983032,177.751935005188,14.625,tensor(0.),0.3551646090321788 7 | -------------------------------------------------------------------------------- /experiments/elens/results_old/MNIST/explainer/results_aware_mnist.csv: -------------------------------------------------------------------------------- 1 | ,explanation_accuracy,explanation_fidelity,explanation_complexity,model_accuracy,extraction_time,extracted_concepts,common_concepts_ratio,explanation_consistency 2 | 0,0.9962499594787191,0.9962500000000001,50.0,0.9976666569709778,140.34969568252563,10.0,tensor(1.),1.0 3 | 1,0.9962499594787191,0.9964166666666667,50.0,0.9984999895095825,136.56005907058716,10.0,tensor(1.),1.0 4 | 2,0.9962499594787191,0.9962500000000001,50.0,0.9986666440963745,138.76298713684082,10.0,tensor(1.),1.0 5 | 3,0.9962499594787191,0.99625,50.0,0.9975000023841858,138.25102996826172,10.0,tensor(1.),1.0 6 | 4,0.9962499594787191,0.9964166666666667,50.0,0.9983333349227905,137.67692637443542,10.0,tensor(1.),1.0 7 | -------------------------------------------------------------------------------- /experiments/elens/results_old/vdem/explainer/results_aware_vdem.csv: -------------------------------------------------------------------------------- 1 | ,explanation_accuracy,explanation_fidelity,explanation_complexity,model_accuracy,extraction_time,extracted_concepts,common_concepts_ratio,explanation_consistency 2 | 
0,0.885761036532425,0.861890694239291,4.0,0.9615952372550964,24.49960970878601,12.5,tensor(0.7857),0.4625 3 | 1,0.9084284469271398,0.9217134416543574,2.0,0.942392885684967,26.258350372314453,11.0,tensor(0.5714),0.4625 4 | 2,0.9069456622598302,0.929837518463811,3.0,0.9394387006759644,20.23900842666626,7.5,tensor(0.5000),0.4625 5 | 3,0.8873662457750302,0.9084194977843427,2.0,0.9335302710533142,183.4691743850708,13.5,tensor(0.9286),0.4625 6 | 4,0.9056895193065406,0.9231905465288035,4.5,0.9483013153076172,45.040387868881226,13.0,tensor(0.8571),0.4625 7 | -------------------------------------------------------------------------------- /experiments/elens/summary_to_latex.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "indonesian-daniel", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import glob\n", 12 | "import numpy as np\n", 13 | "import pandas as pd" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "centered-president", 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "data": { 24 | "text/plain": [ 25 | "['./results\\\\CUB\\\\explainer\\\\results_aware_cub.csv',\n", 26 | " './results\\\\mimic-ii\\\\explainer\\\\results_aware_mimic.csv',\n", 27 | " './results\\\\MNIST\\\\explainer\\\\results_aware_mnist.csv',\n", 28 | " './results\\\\vdem\\\\explainer\\\\results_aware_vdem.csv']" 29 | ] 30 | }, 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "output_type": "execute_result" 34 | } 35 | ], 36 | "source": [ 37 | "files = glob.glob('./results/**/results**.csv', recursive = True)\n", 38 | "files" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "derived-mining", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "file = files[0]" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "id": 
"needed-colleague", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "explanation_accuracy 0.952448\n", 61 | "explanation_fidelity 0.998563\n", 62 | "explanation_complexity 3.736000\n", 63 | "model_accuracy 0.929492\n", 64 | "extraction_time 171.867153\n", 65 | "extracted_concepts 14.836000\n", 66 | "explanation_consistency 0.355165\n", 67 | "dtype: float64" 68 | ] 69 | }, 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "df = pd.read_csv(file, index_col=0)\n", 77 | "dfm = df.mean()\n", 78 | "dfs = df.sem()\n", 79 | "dfm" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "id": "imposed-cornwall", 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "name": "stdout", 90 | "output_type": "stream", 91 | "text": [ 92 | "Dataset: ./results\\CUB\\explainer\\results_aware_cub.csv\n", 93 | "Model accuracy\n", 94 | "$92.95 \\pm 0.20$\n", 95 | "Explanation accuracy\n", 96 | "$95.24 \\pm 0.05$\n", 97 | "Complexity\n", 98 | "$3.74 \\pm 0.03$\n", 99 | "Fidelity\n", 100 | "$99.86 \\pm 0.01$\n", 101 | "Consistency\n", 102 | "$35.52$\n", 103 | "Time\n", 104 | "$171.87 \\pm 1.95$\n", 105 | "\n", 106 | "Dataset: ./results\\mimic-ii\\explainer\\results_aware_mimic.csv\n", 107 | "Model accuracy\n", 108 | "$79.05 \\pm 1.35$\n", 109 | "Explanation accuracy\n", 110 | "$66.93 \\pm 2.14$\n", 111 | "Complexity\n", 112 | "$3.50 \\pm 0.88$\n", 113 | "Fidelity\n", 114 | "$79.11 \\pm 2.02$\n", 115 | "Consistency\n", 116 | "$28.75$\n", 117 | "Time\n", 118 | "$23.08 \\pm 3.53$\n", 119 | "\n", 120 | "Dataset: ./results\\MNIST\\explainer\\results_aware_mnist.csv\n", 121 | "Model accuracy\n", 122 | "$99.81 \\pm 0.02$\n", 123 | "Explanation accuracy\n", 124 | "$99.62 \\pm 0.00$\n", 125 | "Complexity\n", 126 | "$50.00 \\pm 0.00$\n", 127 | "Fidelity\n", 128 | "$99.63 \\pm 0.00$\n", 129 | "Consistency\n", 130 | "$100.00$\n", 131 | "Time\n", 132 | "$138.32 \\pm 
0.63$\n", 133 | "\n", 134 | "Dataset: ./results\\vdem\\explainer\\results_aware_vdem.csv\n", 135 | "Model accuracy\n", 136 | "$94.51 \\pm 0.48$\n", 137 | "Explanation accuracy\n", 138 | "$89.88 \\pm 0.50$\n", 139 | "Complexity\n", 140 | "$3.10 \\pm 0.51$\n", 141 | "Fidelity\n", 142 | "$90.90 \\pm 1.23$\n", 143 | "Consistency\n", 144 | "$46.25$\n", 145 | "Time\n", 146 | "$59.90 \\pm 31.18$\n", 147 | "\n" 148 | ] 149 | } 150 | ], 151 | "source": [ 152 | "for file in files:\n", 153 | " df = pd.read_csv(file, index_col=0)\n", 154 | " dfm = df.mean()\n", 155 | " dfs = df.sem()\n", 156 | " print(f'Dataset: {file}')\n", 157 | " print(f'Model accuracy')\n", 158 | " print(f\"${100*dfm['model_accuracy']:.2f} \\pm {100*dfs['model_accuracy']:.2f}$\")\n", 159 | " print(f'Explanation accuracy')\n", 160 | " print(f\"${100*dfm['explanation_accuracy']:.2f} \\pm {100*dfs['explanation_accuracy']:.2f}$\")\n", 161 | " print(f'Complexity')\n", 162 | " print(f\"${dfm['explanation_complexity']:.2f} \\pm {dfs['explanation_complexity']:.2f}$\")\n", 163 | " print(f'Fidelity')\n", 164 | " print(f\"${100*dfm['explanation_fidelity']:.2f} \\pm {100*dfs['explanation_fidelity']:.2f}$\")\n", 165 | " print(f'Consistency')\n", 166 | " print(f\"${100*dfm['explanation_consistency']:.2f}$\")\n", 167 | " print(f'Time')\n", 168 | " print(f\"${dfm['extraction_time']:.2f} \\pm {dfs['extraction_time']:.2f}$\")\n", 169 | " print()" 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python 3", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.8.5" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 5 194 | } 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | scikit-learn 4 | pandas 5 | torch 6 | sympy -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | 4 | [aliases] 5 | test = pytest 6 | 7 | [tool:pytest] 8 | addopts = --doctest-modules 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """A template.""" 3 | 4 | import codecs 5 | import os 6 | 7 | from setuptools import find_packages, setup 8 | 9 | # get __version__ from _version.py 10 | ver_file = os.path.join('torch_explain', '_version.py') 11 | with open(ver_file) as f: 12 | exec(f.read()) 13 | 14 | DISTNAME = 'torch_explain' 15 | DESCRIPTION = 'PyTorch Explain: Explainable Deep Learning in Python.' 16 | with codecs.open('README.rst') as f: 17 | LONG_DESCRIPTION = f.read() 18 | MAINTAINER = 'P. 
Barbiero' 19 | MAINTAINER_EMAIL = 'barbiero@tutanota.com' 20 | URL = 'https://github.com/pietrobarbiero/pytorch_explain' 21 | LICENSE = 'Apache 2.0' 22 | DOWNLOAD_URL = 'https://github.com/pietrobarbiero/pytorch_explain' 23 | VERSION = __version__ 24 | INSTALL_REQUIRES = ['numpy', 'scipy', 'scikit-learn', 'pandas', 'torch', 'sympy'] 25 | CLASSIFIERS = ['Intended Audience :: Science/Research', 26 | 'Intended Audience :: Developers', 27 | 'License :: OSI Approved', 28 | 'Programming Language :: Python', 29 | 'Topic :: Software Development', 30 | 'Topic :: Scientific/Engineering', 31 | 'Operating System :: Microsoft :: Windows', 32 | 'Operating System :: POSIX', 33 | 'Operating System :: Unix', 34 | 'Operating System :: MacOS', 35 | 'Programming Language :: Python :: 2.7', 36 | 'Programming Language :: Python :: 3.5', 37 | 'Programming Language :: Python :: 3.6', 38 | 'Programming Language :: Python :: 3.7'] 39 | EXTRAS_REQUIRE = { 40 | 'tests': [ 41 | 'pytest', 42 | 'pytest-cov'], 43 | 'docs': [ 44 | 'sphinx', 45 | 'sphinx-gallery', 46 | 'sphinx_rtd_theme', 47 | 'numpydoc', 48 | 'matplotlib' 49 | ] 50 | } 51 | 52 | setup(name=DISTNAME, 53 | maintainer=MAINTAINER, 54 | maintainer_email=MAINTAINER_EMAIL, 55 | description=DESCRIPTION, 56 | license=LICENSE, 57 | url=URL, 58 | version=VERSION, 59 | download_url=DOWNLOAD_URL, 60 | long_description=LONG_DESCRIPTION, 61 | long_description_content_type='text/x-rst', 62 | zip_safe=False, # the package can run out of an .egg file 63 | classifiers=CLASSIFIERS, 64 | packages=find_packages(), 65 | install_requires=INSTALL_REQUIRES, 66 | extras_require=EXTRAS_REQUIRE) 67 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietrobarbiero/pytorch_explain/4986e0b75e38967bcdb6923c53872ceea5191760/tests/__init__.py 
def train_concept_bottleneck_model(x, c, y, embedding_size=1):
    """Train a concept-based model on (x, c, y) and print test metrics.

    With ``embedding_size > 1`` a concept embedding model (CEM) is trained,
    otherwise a standard concept bottleneck model (CBM) with a sigmoid
    concept layer.

    :param x: input features
    :param c: binary concept labels
    :param y: binary task labels
    :param embedding_size: size of each concept embedding (1 = plain CBM)
    :return: the trained model
    """
    x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
        x, c, y, test_size=0.33, random_state=42)

    use_cem = embedding_size > 1
    if use_cem:
        # concept embedding model
        encoder = torch.nn.Sequential(
            torch.nn.Linear(x.shape[1], 10),
            torch.nn.LeakyReLU(),
        )
        concept_embedder = te.nn.ConceptEmbedding(10, c.shape[1], embedding_size)
        task_predictor = torch.nn.Sequential(
            torch.nn.Linear(c.shape[1] * embedding_size, 1),
        )
        model = torch.nn.Sequential(encoder, concept_embedder, task_predictor)
    else:
        # standard concept bottleneck model
        concept_embedder = torch.nn.Sequential(
            torch.nn.Linear(x.shape[1], 10),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(10, 8),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(8, c.shape[1]),
            torch.nn.Sigmoid(),
        )
        task_predictor = torch.nn.Sequential(
            torch.nn.Linear(c.shape[1], 1),
        )
        model = torch.nn.Sequential(concept_embedder, task_predictor)

    def _forward(x_batch, c_batch, training):
        # single forward pass shared by the train and eval phases
        if use_cem:
            latent = encoder(x_batch)
            c_prob_emb, c_prob = concept_embedder.forward(latent, [0, 1], c_batch, train=training)
            y_logits = task_predictor(c_prob_emb.reshape(len(c_prob_emb), -1))
        else:
            c_prob = concept_embedder(x_batch)
            y_logits = task_predictor(c_prob)
        return c_prob, y_logits

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    loss_form_c = torch.nn.BCELoss()
    loss_form_y = torch.nn.BCEWithLogitsLoss()
    model.train()
    for epoch in range(501):
        optimizer.zero_grad()

        c_pred, y_pred = _forward(x_train, c_train, True)

        # joint loss: concepts plus (down-weighted) task
        loss = loss_form_c(c_pred, c_train) + 0.5 * loss_form_y(y_pred, y_train)
        loss.backward()
        optimizer.step()

        # periodically report test-set accuracy
        if epoch % 100 == 0:
            c_pred, y_pred = _forward(x_test, c_test, False)
            task_accuracy = accuracy_score(y_test, y_pred > 0)
            concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
            print(f'Epoch {epoch}: loss {loss:.4f} task accuracy: {task_accuracy:.4f} concept accuracy: {concept_accuracy:.4f}')

    return model
def train_concept_bottleneck_model(x, c, y, embedding_size=1, logic=GodelTNorm(), temperature=100):
    """Train a concept embedding model with a DCR (ConceptReasoningLayer) head.

    :param x: input features
    :param c: binary concept labels
    :param y: task labels (converted to one-hot internally)
    :param embedding_size: size of each concept embedding
    :param logic: t-norm semantics used by the reasoning layer
    :param temperature: soft-selection temperature of the reasoning layer
    :return: the trained model
    """
    x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
        x, c, y, test_size=0.33, random_state=42)

    y_train = F.one_hot(y_train.long().ravel()).float()
    y_test = F.one_hot(y_test.long().ravel()).float()

    # concept embedding model with a logic-based task predictor
    encoder = torch.nn.Sequential(
        torch.nn.Linear(x.shape[1], 10),
        torch.nn.LeakyReLU(),
    )
    concept_embedder = te.nn.ConceptEmbedding(10, c.shape[1], embedding_size)
    task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1], logic, temperature)
    model = torch.nn.Sequential(encoder, concept_embedder, task_predictor)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    concept_criterion = torch.nn.BCELoss()
    task_criterion = torch.nn.BCELoss()
    model.train()
    for epoch in range(501):
        optimizer.zero_grad()

        c_emb, c_pred = concept_embedder.forward(encoder(x_train), [0, 1], c_train, train=True)
        y_pred = task_predictor(c_emb, c_pred)

        # joint loss: concepts plus (down-weighted) task
        loss = concept_criterion(c_pred, c_train) + 0.5 * task_criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        # periodically report test-set accuracy
        if epoch % 100 == 0:
            c_emb, c_pred = concept_embedder.forward(encoder(x_test), [0, 1], c_test, train=False)
            y_pred = task_predictor(c_emb, c_pred)

            task_accuracy = accuracy_score(y_test, y_pred > 0.5)
            concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
            print(f'Epoch {epoch}: loss {loss:.4f} task accuracy: {task_accuracy:.4f} concept accuracy: {concept_accuracy:.4f}')

    # extract explanations from the last evaluation pass
    local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
    global_explanations = task_predictor.explain(c_emb, c_pred, 'global')
    print(global_explanations)

    return model
class TestTemplateObject(unittest.TestCase):
    """Smoke tests for ``torch_explain.logic.utils.get_predictions``."""

    def test_get_predictions(self):
        # get_predictions must accept both a non-trivial formula and an
        # empty one (the latter returns early without evaluation)
        x, c, y = te.datasets.xor(500)
        for formula in ('(x1 & x2)', ''):
            te.logic.utils.get_predictions(formula, c, 0.5)
def xor(size, random_state=42):
    """Generate the XOR benchmark dataset.

    :param size: number of samples.
    :param random_state: seed for the numpy RNG (reproducibility).
    :return: tuple ``(x, c, y)`` of float tensors: inputs of shape
        ``(size, 2)``, concepts ``(size, 2)``, task labels ``(size, 1)``.
    """
    np.random.seed(random_state)
    x = np.random.uniform(0, 1, (size, 2))
    # concepts: whether each input coordinate exceeds 0.5
    c = np.stack([
        x[:, 0] > 0.5,
        x[:, 1] > 0.5,
    ]).T
    # task: XOR of the two concepts
    y = np.logical_xor(c[:, 0], c[:, 1])

    x = torch.FloatTensor(x)
    c = torch.FloatTensor(c)
    y = torch.FloatTensor(y)
    return x, c, y.unsqueeze(-1)


def trigonometry(size, random_state=42):
    """Generate the trigonometry benchmark dataset.

    Three latent gaussian factors are mapped to seven non-linear input
    features; concepts are the signs of the latent factors.

    :param size: number of samples.
    :param random_state: seed for the numpy RNG (reproducibility).
    :return: tuple ``(x, c, y)`` of float tensors: inputs of shape
        ``(size, 7)``, concepts ``(size, 3)``, task labels ``(size, 1)``.
    """
    np.random.seed(random_state)
    h = np.random.normal(0, 2, (size, 3))
    x, y, z = h[:, 0], h[:, 1], h[:, 2]

    # raw features: non-linear transformations of the latent factors
    input_features = np.stack([
        np.sin(x) + x,
        np.cos(x) + x,
        np.sin(y) + y,
        np.cos(y) + y,
        np.sin(z) + z,
        np.cos(z) + z,
        x ** 2 + y ** 2 + z ** 2,
    ]).T

    # concepts (fixed typo: was "concetps"): sign of each latent factor
    concepts = np.stack([
        x > 0,
        y > 0,
        z > 0,
    ]).T

    # task: sum of the latent factors above threshold
    downstream_task = (x + y + z) > 1

    input_features = torch.FloatTensor(input_features)
    concepts = torch.FloatTensor(concepts)
    downstream_task = torch.FloatTensor(downstream_task)
    return input_features, concepts, downstream_task.unsqueeze(-1)


def dot(size, random_state=42):
    """Generate the dot-product benchmark dataset.

    :param size: number of samples.
    :param random_state: seed for the numpy RNG (reproducibility).
    :return: tuple ``(x, c, y)`` of tensors: inputs of shape
        ``(size, 4)``, concepts ``(size, 2)``, task labels ``(size, 1)``.
    """
    # BUG FIX: random_state was accepted but never used, making the
    # dataset non-reproducible (unlike xor/trigonometry). Seed the RNG.
    np.random.seed(random_state)
    emb_size = 2
    v1 = np.random.randn(size, emb_size) * 2
    v2 = np.ones(emb_size)
    v3 = np.random.randn(size, emb_size) * 2
    v4 = -np.ones(emb_size)
    x = np.hstack([v1 + v3, v1 - v3])
    # concepts: sign of each latent vector projected on a fixed direction
    c = np.stack([
        np.dot(v1, v2).ravel() > 0,
        np.dot(v3, v4).ravel() > 0,
    ]).T
    # task: sign of the dot product between the two latent vectors
    y = ((v1 * v3).sum(axis=-1) > 0).astype(np.int64)

    x = torch.FloatTensor(x)
    c = torch.FloatTensor(c)
    y = torch.Tensor(y)
    return x, c, y.unsqueeze(-1)
def test_explanations(formulas: List[str], x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor = None,
                      threshold: float = 0.5, material: bool = False) -> Tuple[float, torch.Tensor]:
    """
    Tests all together the logic formulas of different classes.
    When a sample fires more than one formula (or none), the sample is
    considered wrongly predicted.

    :param formulas: list of logic formulas, one for each class
    :param x: input data
    :param y: input labels (MUST be one-hot encoded)
    :param mask: sample mask
    :param threshold: threshold to get concept truth values
    :param material: score with material implication instead of the
        material biconditional
    :return: Accuracy of the explanations and predictions
    """
    if formulas is None or formulas == []:
        return 0.0, None
    for formula in formulas:
        if formula in ['True', 'False', '']:
            return 0.0, None
    assert len(y.shape) == 2

    y2 = y.argmax(-1)
    x = x.cpu().detach().numpy()
    concept_list = [f"feature{i:010}" for i in range(x.shape[1])]

    # evaluate each class formula on the thresholded concept truth values
    class_predictions = torch.zeros(len(formulas), x.shape[0])
    for i, formula in enumerate(formulas):
        explanation = to_dnf(formula)
        fun = lambdify(concept_list, explanation, 'numpy')

        predictions = torch.LongTensor(fun(*[x[:, j] > threshold for j in range(x.shape[1])]))
        class_predictions[i] = predictions

    # BUG FIX: the previous check `sum(...) in [0, 2]` only detected
    # multi-fire conflicts for exactly two classes; per the docstring, ANY
    # sample firing != 1 formulas must count as an error. Also vectorized
    # (resolves the old "todo: vectorize").
    fired = class_predictions.sum(dim=0)
    class_predictions_filtered_by_pred = class_predictions.argmax(dim=0).float()
    class_predictions_filtered_by_pred[fired != 1] = -1  # ambiguous or no fire -> error

    if material:
        # material implication: (p=>q) <=> (not p or q)
        # NOTE(review): as in the original implementation, this branch scores
        # only the predictions of the LAST formula evaluated in the loop
        # above — confirm this is intended before relying on it.
        accuracy = torch.sum(torch.logical_or(torch.logical_not(predictions[mask]), y2[mask])) / len(y2[mask])
        accuracy = accuracy.item()
    else:
        # material biconditional: (p<=>q) <=> (p and q) or (not p and not q)
        accuracy = accuracy_score(class_predictions_filtered_by_pred[mask], y2[mask])
    return accuracy, class_predictions_filtered_by_pred
def concept_consistency(formula_list: List[str]) -> dict:
    """
    Computes the frequency of concepts in a list of logic formulas.

    :param formula_list: list of logic formulas.
    :return: Frequency of concepts, normalised by the number of formulas.
    """
    n_formulas = len(formula_list)
    counts = _generate_consistency_dict(formula_list)
    return {concept: count / n_formulas for concept, count in counts.items()}
130 | """ 131 | concept_dict = _generate_consistency_dict(formula_list) 132 | concept_consistency = np.array([c for c in concept_dict.values()]) / len(formula_list) 133 | return concept_consistency.mean() 134 | 135 | 136 | def _generate_consistency_dict(formula_list: List[str]) -> dict: 137 | concept_dict = {} 138 | for i, formula in enumerate(formula_list): 139 | concept_dict_i = {} 140 | for minterm_list in formula.split(' | '): 141 | for term in minterm_list.split(' & '): 142 | concept = term.replace('(', '').replace(')', '').replace('~', '') 143 | if concept in concept_dict_i: 144 | continue 145 | elif concept in concept_dict: 146 | concept_dict_i[concept] = 1 147 | concept_dict[concept] += 1 148 | else: 149 | concept_dict_i[concept] = 1 150 | concept_dict[concept] = 1 151 | return concept_dict 152 | -------------------------------------------------------------------------------- /torch_explain/logic/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietrobarbiero/pytorch_explain/4986e0b75e38967bcdb6923c53872ceea5191760/torch_explain/logic/nn/__init__.py -------------------------------------------------------------------------------- /torch_explain/logic/nn/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def _collect_parameters(model: torch.nn.Module, 8 | device: torch.device = torch.device('cpu')) -> Tuple[List[np.ndarray], List[np.ndarray]]: 9 | """ 10 | Collect network parameters in two lists of numpy arrays. 
def replace_names(explanation: str, concept_names: List[str]) -> str:
    """
    Replace the auto-generated ``featureNNNNNNNNNN`` placeholders in a
    formula with human-readable concept names.

    :param explanation: formula containing feature placeholders
    :param concept_names: new concept names, indexed by feature position
    :return: Formula with renamed concepts
    """
    for position, new_name in enumerate(concept_names):
        explanation = explanation.replace(f'feature{position:010}', new_name)
    return explanation
def get_the_good_and_bad_terms(
    model, c, edge_index, sample_pos, explanation, target_class, concept_names=None, threshold=0.5
):
    """
    Split the literals of a local explanation into "good" and "bad" terms by
    perturbing each mentioned concept and observing the model's prediction
    for ``target_class``.

    A positive literal is removed (deactivated) and a negated literal is
    added (activated); if the perturbation lowers the target prediction the
    term is "good", otherwise "bad".

    :param model: callable mapping concepts (and optionally an edge index) to predictions
    :param c: concept tensor for all samples
    :param edge_index: optional graph connectivity forwarded to the model
    :param sample_pos: index of the sample being explained
    :param explanation: conjunction of literals joined by " & "
    :param target_class: class whose prediction is monitored
    :param concept_names: optional concept names matching the literals
    :param threshold: concept activation threshold (selects the "off" value)
    :return: (good_terms, bad_terms)
    """
    def _remove_concept(inputs, idx):
        # deactivate the concept; the "off" value depends on the threshold
        if threshold == 0.5:
            inputs[:, idx] = 0.0
        elif threshold == 0.:
            inputs[:, idx] = -1.0
        return inputs

    def _add_concept(inputs, idx):
        # activate the concept
        inputs[:, idx] = 1
        return inputs

    # baseline prediction on the unperturbed concepts
    if edge_index is None:
        base = model(c)[sample_pos].view(1, -1)
    else:
        base = model(c, edge_index)[sample_pos].view(1, -1)

    good, bad = [], []
    for term in explanation.split(" & "):
        negated = term.startswith("~")
        atom = term[1:] if negated else term

        if concept_names is not None:
            idx = concept_names.index(atom)
        else:
            idx = int(atom[len("feature"):])

        sample = c[sample_pos].clone().detach().view(1, -1)
        sample = _add_concept(sample, idx) if negated else _remove_concept(sample, idx)

        perturbed = copy.deepcopy(c)
        perturbed[sample_pos] = sample
        if edge_index is None:
            new_pred = model(perturbed)[sample_pos].view(1, -1)
        else:
            new_pred = model(perturbed, edge_index)[sample_pos].view(1, -1)

        # if the prediction did not drop, perturbing this term did not help
        if new_pred[:, target_class] >= base[:, target_class]:
            bad.append(term)
        else:
            good.append(term)
        del sample
    return good, bad
sign_attn is None: 35 | # compute attention scores to build logic sentence 36 | # each attention score will represent whether the concept should be active or not in the logic sentence 37 | sign_attn = torch.sigmoid(self.sign_nn(x)) 38 | 39 | # attention scores need to be aligned with predicted concept truth values (attn <-> values) 40 | # (not A or V) and (A or not V) <-> (A <-> V) 41 | sign_terms = self.logic.iff_pair(sign_attn, values) 42 | 43 | if filter_attn is None: 44 | # compute attention scores to identify only relevant concepts for each class 45 | filter_attn = softselect(self.filter_nn(x), self.temperature) 46 | 47 | # filter value 48 | # filtered implemented as "or(a, not b)", corresponding to "b -> a" 49 | filtered_values = self.logic.disj_pair(sign_terms, self.logic.neg(filter_attn)) 50 | 51 | # generate minterm 52 | preds = self.logic.conj(filtered_values, dim=1).squeeze(1).float() 53 | 54 | if return_attn: 55 | return preds, sign_attn, filter_attn 56 | else: 57 | return preds 58 | 59 | def explain(self, x, c, mode, concept_names=None, class_names=None, filter_attn=None): 60 | assert mode in ['local', 'global', 'exact'] 61 | 62 | if concept_names is None: 63 | concept_names = [f'c_{i}' for i in range(c.shape[1])] 64 | if class_names is None: 65 | class_names = [f'y_{i}' for i in range(self.n_classes)] 66 | 67 | # make a forward pass to get predictions and attention weights 68 | y_preds, sign_attn_mask, filter_attn_mask = self.forward(x, c, return_attn=True, filter_attn=filter_attn) 69 | 70 | explanations = [] 71 | all_class_explanations = {cn: [] for cn in class_names} 72 | for sample_idx in range(len(x)): 73 | prediction = y_preds[sample_idx] > 0.5 74 | active_classes = torch.argwhere(prediction).ravel() 75 | 76 | if len(active_classes) == 0: 77 | # if no class is active for this sample, then we cannot extract any explanation 78 | explanations.append({ 79 | 'class': -1, 80 | 'explanation': '', 81 | 'attention': [], 82 | }) 83 | else: 84 | # else we 
can extract an explanation for each active class! 85 | for target_class in active_classes: 86 | attentions = [] 87 | minterm = [] 88 | for concept_idx in range(len(concept_names)): 89 | c_pred = c[sample_idx, concept_idx] 90 | sign_attn = sign_attn_mask[sample_idx, concept_idx, target_class] 91 | filter_attn = filter_attn_mask[sample_idx, concept_idx, target_class] 92 | 93 | # we first check if the concept was relevant 94 | # a concept is relevant <-> the filter attention score is lower than the concept probability 95 | at_score = 0 96 | sign_terms = self.logic.iff_pair(sign_attn, c_pred).item() 97 | if self.logic.neg(filter_attn) < sign_terms: 98 | if sign_attn >= 0.5: 99 | # if the concept is relevant and the sign is positive we just take its attention score 100 | at_score = filter_attn.item() 101 | if mode == 'exact': 102 | minterm.append(f'{sign_terms:.3f} ({concept_names[concept_idx]})') 103 | else: 104 | minterm.append(f'{concept_names[concept_idx]}') 105 | else: 106 | # if the concept is relevant and the sign is positive we take (-1) * its attention score 107 | at_score = -filter_attn.item() 108 | if mode == 'exact': 109 | minterm.append(f'{sign_terms:.3f} (~{concept_names[concept_idx]})') 110 | else: 111 | minterm.append(f'~{concept_names[concept_idx]}') 112 | attentions.append(at_score) 113 | 114 | # add explanation to list 115 | target_class_name = class_names[target_class] 116 | minterm = ' & '.join(minterm) 117 | all_class_explanations[target_class_name].append(minterm) 118 | explanations.append({ 119 | 'sample-id': sample_idx, 120 | 'class': target_class_name, 121 | 'explanation': minterm, 122 | 'attention': attentions, 123 | }) 124 | 125 | if mode == 'global': 126 | # count most frequent explanations for each class 127 | explanations = [] 128 | for class_id, class_explanations in all_class_explanations.items(): 129 | explanation_count = Counter(class_explanations) 130 | for explanation, count in explanation_count.items(): 131 | explanations.append({ 
class ConceptEmbedding(torch.nn.Module):
    """Concept Embedding layer (CEM).

    For every concept, a dedicated context generator maps the input into a
    pair of embeddings (an "active" and an "inactive" state) and a shared
    scorer turns the context into a concept probability. The returned
    concept embedding is the probability-weighted mixture of the two states.

    Interventions can replace the predicted probability used for the mixture
    with ground-truth-derived values, either at fixed indices or randomly
    during training.
    """

    def __init__(
            self,
            in_features,
            n_concepts,
            emb_size,
            active_intervention_values=None,
            inactive_intervention_values=None,
            intervention_idxs=None,
            training_intervention_prob=0.25,
    ):
        super().__init__()
        self.emb_size = emb_size
        self.intervention_idxs = intervention_idxs
        self.training_intervention_prob = training_intervention_prob
        if self.training_intervention_prob != 0:
            # per-concept Bernoulli template for random training interventions
            self.ones = torch.ones(n_concepts)

        # one context generator per concept, plus a shared probability scorer
        self.concept_context_generators = torch.nn.ModuleList()
        for _ in range(n_concepts):
            self.concept_context_generators.append(torch.nn.Sequential(
                torch.nn.Linear(in_features, 2 * emb_size),
                torch.nn.LeakyReLU(),
            ))
        self.concept_prob_predictor = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, 1),
            torch.nn.Sigmoid(),
        )

        # default values used when a concept is intervened on
        if active_intervention_values is None:
            self.active_intervention_values = torch.ones(n_concepts)
        else:
            self.active_intervention_values = torch.tensor(
                active_intervention_values
            )
        if inactive_intervention_values is None:
            self.inactive_intervention_values = torch.zeros(n_concepts)
        else:
            self.inactive_intervention_values = torch.tensor(
                inactive_intervention_values
            )

    def _after_interventions(
            self,
            prob,
            concept_idx,
            intervention_idxs=None,
            c_true=None,
            train=False,
    ):
        # during training (and without explicit indices) we may randomly
        # pick a subset of concepts to intervene on
        if train and (self.training_intervention_prob != 0) and (intervention_idxs is None):
            mask = torch.bernoulli(self.ones * self.training_intervention_prob)
            intervention_idxs = torch.nonzero(mask).reshape(-1)
        if (c_true is None) or (intervention_idxs is None):
            return prob
        if concept_idx not in intervention_idxs:
            return prob
        # replace the predicted probability with the intervention values
        # scaled by the ground-truth concept label
        return (c_true[:, concept_idx:concept_idx + 1] * self.active_intervention_values[concept_idx]) + \
               ((c_true[:, concept_idx:concept_idx + 1] - 1) * -self.inactive_intervention_values[concept_idx])

    def forward(self, x, intervention_idxs=None, c=None, train=False):
        # explicit arguments take precedence over defaults set at init time
        active_idxs = intervention_idxs if intervention_idxs is not None else self.intervention_idxs

        embeddings, probabilities = [], []
        for concept_idx, generator in enumerate(self.concept_context_generators):
            context = generator(x)
            prob = self.concept_prob_predictor(context)
            # NOTE: the returned probabilities are always the predicted ones;
            # interventions only affect the embedding mixture below
            probabilities.append(prob)
            prob = self._after_interventions(
                prob=prob,
                concept_idx=concept_idx,
                intervention_idxs=active_idxs,
                c_true=c,
                train=train,
            )

            positive = context[:, :self.emb_size]
            negative = context[:, self.emb_size:]
            mixed = positive * prob + negative * (1 - prob)
            embeddings.append(mixed.unsqueeze(1))

        return torch.cat(embeddings, axis=1), torch.cat(probabilities, axis=1)
def prune_equal_fanin(model: torch.nn.Module, epoch: int, prune_epoch: int, k: int = 2,
                      device: torch.device = torch.device('cpu')) -> torch.nn.Module:
    """
    Prune the linear layers of the network such that each neuron has the same fan-in.

    For every ``torch.nn.Linear`` direct child of ``model``, the
    ``in_features - k`` weights with the smallest absolute value in each row
    are masked out, leaving exactly ``k`` active incoming weights per neuron.
    Pruning only fires when ``epoch == prune_epoch``; otherwise the model is
    returned untouched.

    :param model: pytorch model.
    :param epoch: current training epoch.
    :param prune_epoch: training epoch when pruning needs to be applied.
    :param k: fan-in (number of weights kept per neuron).
    :param device: cpu or cuda device.
    :return: Pruned model.
    """
    if epoch != prune_epoch:
        return model

    model.eval()
    for module in model.children():
        # prune only Linear layers (direct children only, matching the
        # original behavior — nested submodules are not visited)
        if not isinstance(module, torch.nn.Linear):
            continue
        # number of weights to drop per row; clamp at 0 so k >= in_features
        # degenerates to a no-op mask instead of a topk error
        n_prune = max(module.weight.shape[1] - k, 0)
        # indices of the n_prune smallest-magnitude weights in each row
        idx = torch.topk(module.weight.abs(), k=n_prune, dim=1, largest=False)[1]
        # build the mask with one vectorized scatter instead of a Python loop
        mask = torch.ones(module.weight.shape, device=device)
        mask.scatter_(1, idx.to(device), 0)
        prune.custom_from_mask(module, name="weight", mask=mask)

    return model
nn.init._calculate_fan_in_and_fan_out(self.weight) 33 | bound = 1 / math.sqrt(fan_in) 34 | nn.init.uniform_(self.bias, -bound, bound) 35 | 36 | def forward(self, input: Tensor) -> Tensor: 37 | if len(input.shape) == 2: 38 | input = input.unsqueeze(0) 39 | # compute concept-awareness scores 40 | gamma = self.weight.norm(dim=1, p=1) 41 | self.alpha = torch.exp(gamma/self.temperature) / torch.sum(torch.exp(gamma/self.temperature), dim=1, keepdim=True) 42 | 43 | # weight the input concepts by awareness scores 44 | self.alpha_norm = self.alpha / self.alpha.max(dim=1)[0].unsqueeze(1) 45 | if self.remove_attention: 46 | self.concept_mask = torch.ones_like(self.alpha_norm, dtype=torch.bool) 47 | x = input 48 | else: 49 | self.concept_mask = self.alpha_norm > 0.5 50 | x = input.multiply(self.alpha_norm.unsqueeze(1)) 51 | 52 | # compute linear map 53 | x = x.matmul(self.weight.permute(0, 2, 1)) 54 | if self.has_bias: 55 | x += self.bias 56 | return x.permute(1, 0, 2) 57 | 58 | def extra_repr(self) -> str: 59 | return 'in_features={}, out_features={}, n_classes={}'.format( 60 | self.in_features, self.out_features, self.n_classes 61 | ) 62 | -------------------------------------------------------------------------------- /torch_explain/nn/semantics.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | class Logic: 7 | @abc.abstractmethod 8 | def update(self): 9 | raise NotImplementedError 10 | 11 | @abc.abstractmethod 12 | def conj(self, a, dim=1): 13 | raise NotImplementedError 14 | 15 | @abc.abstractmethod 16 | def disj(self, a, dim=1): 17 | raise NotImplementedError 18 | 19 | def conj_pair(self, a, b): 20 | raise NotImplementedError 21 | 22 | def disj_pair(self, a, b): 23 | raise NotImplementedError 24 | 25 | def iff_pair(self, a, b): 26 | raise NotImplementedError 27 | 28 | @abc.abstractmethod 29 | def neg(self, a): 30 | raise NotImplementedError 31 | 32 | 33 
class ProductTNorm(Logic):
    """Product t-norm fuzzy-logic semantics (conj = product, neg = 1 - a)."""

    def __init__(self):
        super().__init__()
        self.current_truth = torch.tensor(1)
        self.current_false = torch.tensor(0)

    def update(self):
        # nothing to refresh for this semantics
        pass

    def conj(self, a, dim=1):
        # t-norm reduction: product along `dim`, keeping the reduced axis
        return torch.prod(a, dim=dim, keepdim=True)

    def disj(self, a, dim=1):
        # t-conorm via De Morgan: 1 - prod(1 - a)
        return 1 - torch.prod(1 - a, dim=dim, keepdim=True)

    def conj_pair(self, a, b):
        return a * b

    def disj_pair(self, a, b):
        return a + b - a * b

    def neg(self, a):
        return 1 - a

    def iff_pair(self, a, b):
        # a <-> b as (a -> b) AND (b -> a)
        forward = self.disj_pair(self.neg(a), b)
        backward = self.disj_pair(a, self.neg(b))
        return self.conj_pair(forward, backward)

    def predict_proba(self, a):
        return a.squeeze(-1)