├── tests ├── __init__.py └── test_template.py ├── MANIFEST.in ├── reasoning ├── _version.py ├── __init__.py ├── _utils.py └── _fol_extractor.py ├── template ├── _version.py ├── _template.py └── __init__.py ├── requirements.txt ├── environment.yml ├── img ├── learning_of_constraints.png └── learning_with_constraints.png ├── setup.cfg ├── .readthedocs.yml ├── doc ├── _templates │ ├── function.rst │ ├── numpydoc_docstring.py │ └── class.rst ├── _static │ ├── css │ │ └── project-template.css │ └── js │ │ └── copybutton.js ├── api.rst ├── index.rst ├── conf.py ├── quick_start.rst ├── user_guide.rst ├── Makefile └── make.bat ├── .coveragerc ├── .gitignore ├── .travis.yml ├── appveyor.yml ├── .circleci └── config.yml ├── setup.py ├── README.rst ├── notebooks ├── learning_of_constraints_digits.py ├── intoCNF.py └── intoCNF_with_prints.py └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /reasoning/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.0' -------------------------------------------------------------------------------- /template/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.3" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | scikit-learn 4 | pandas 5 | -------------------------------------------------------------------------------- /template/_template.py: 
-------------------------------------------------------------------------------- 1 | 2 | class TemplateObject: 3 | 4 | def __init__(self): 5 | return 6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: project-template 2 | dependencies: 3 | - numpy 4 | - scipy 5 | - scikit-learn 6 | - pandas 7 | -------------------------------------------------------------------------------- /img/learning_of_constraints.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietrobarbiero/constraint-learning/HEAD/img/learning_of_constraints.png -------------------------------------------------------------------------------- /img/learning_with_constraints.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietrobarbiero/constraint-learning/HEAD/img/learning_with_constraints.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | 4 | [aliases] 5 | test = pytest 6 | 7 | [tool:pytest] 8 | addopts = --doctest-modules 9 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | formats: 2 | - none 3 | requirements_file: requirements.txt 4 | python: 5 | pip_install: true 6 | extra_requirements: 7 | - tests 8 | - docs 9 | -------------------------------------------------------------------------------- /template/__init__.py: -------------------------------------------------------------------------------- 1 | from ._template import TemplateObject 2 | 3 | from ._version import __version__ 4 | 5 | __all__ = [ 6 | 
'TemplateObject', 7 | '__version__' 8 | ] 9 | -------------------------------------------------------------------------------- /reasoning/__init__.py: -------------------------------------------------------------------------------- 1 | from ._fol_extractor import generate_fol_explanations 2 | from ._version import __version__ 3 | 4 | __all__ = [ 5 | 'generate_fol_explanations', 6 | '__version__', 7 | ] -------------------------------------------------------------------------------- /doc/_templates/function.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | 8 | .. include:: {{module}}.{{objname}}.examples 9 | 10 | .. raw:: html 11 | 12 |
13 | -------------------------------------------------------------------------------- /doc/_templates/numpydoc_docstring.py: -------------------------------------------------------------------------------- 1 | {{index}} 2 | {{summary}} 3 | {{extended_summary}} 4 | {{parameters}} 5 | {{returns}} 6 | {{yields}} 7 | {{other_parameters}} 8 | {{attributes}} 9 | {{raises}} 10 | {{warns}} 11 | {{warnings}} 12 | {{see_also}} 13 | {{notes}} 14 | {{references}} 15 | {{examples}} 16 | {{methods}} 17 | -------------------------------------------------------------------------------- /tests/test_template.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestTemplateObject(unittest.TestCase): 5 | def test_object(self): 6 | import template 7 | 8 | t = template.TemplateObject() 9 | self.assertTrue(isinstance(t, template.TemplateObject)) 10 | return 11 | 12 | 13 | if __name__ == '__main__': 14 | unittest.main() 15 | -------------------------------------------------------------------------------- /doc/_static/css/project-template.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .highlight a { 4 | text-decoration: underline; 5 | } 6 | 7 | .deprecated p { 8 | padding: 10px 7px 10px 10px; 9 | color: #b94a48; 10 | background-color: #F3E5E5; 11 | border: 1px solid #eed3d7; 12 | } 13 | 14 | .deprecated p span.versionmodified { 15 | font-weight: bold; 16 | } 17 | -------------------------------------------------------------------------------- /doc/_templates/class.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | .. automethod:: __init__ 10 | {% endblock %} 11 | 12 | .. include:: {{module}}.{{objname}}.examples 13 | 14 | .. 
raw:: html 15 | 16 |
17 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # Configuration for coverage.py 2 | 3 | [run] 4 | branch = True 5 | source = template 6 | include = */template/* 7 | omit = 8 | */setup.py 9 | 10 | [report] 11 | exclude_lines = 12 | pragma: no cover 13 | def __repr__ 14 | if self.debug: 15 | if settings.DEBUG 16 | raise AssertionError 17 | raise NotImplementedError 18 | if 0: 19 | if __name__ == .__main__.: 20 | if self.verbose: 21 | show_missing = True -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | #################### 2 | project-template API 3 | #################### 4 | 5 | This is an example on how to document the API of your own project. 6 | 7 | .. currentmodule:: skltemplate 8 | 9 | Estimator 10 | ========= 11 | 12 | .. autosummary:: 13 | :toctree: generated/ 14 | :template: class.rst 15 | 16 | TemplateEstimator 17 | 18 | Transformer 19 | =========== 20 | 21 | .. autosummary:: 22 | :toctree: generated/ 23 | :template: class.rst 24 | 25 | TemplateTransformer 26 | 27 | Predictor 28 | ========= 29 | 30 | .. 
autosummary:: 31 | :toctree: generated/ 32 | :template: class.rst 33 | 34 | TemplateClassifier 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # scikit-learn specific 10 | doc/_build/ 11 | doc/auto_examples/ 12 | doc/modules/generated/ 13 | doc/datasets/generated/ 14 | 15 | # Distribution / packaging 16 | 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | 62 | # Sphinx documentation 63 | doc/_build/ 64 | doc/generated/ 65 | 66 | # PyBuilder 67 | target/ 68 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | sudo: false 3 | 4 | language: python 5 | 6 | cache: 7 | directories: 8 | - $HOME/.cache/pip 9 | 10 | matrix: 11 | include: 12 | - env: PYTHON_VERSION="3.7" NUMPY_VERSION="*" SCIPY_VERSION="*" 13 | PANDAS_VERSION="*" 14 | 15 | install: 16 | # install miniconda 17 | - deactivate 18 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O 
miniconda.sh 19 | - MINICONDA_PATH=/home/travis/miniconda 20 | - chmod +x miniconda.sh && ./miniconda.sh -b -p $MINICONDA_PATH 21 | - export PATH=$MINICONDA_PATH/bin:$PATH 22 | - conda update --yes conda 23 | # create the testing environment 24 | - conda create -n testenv --yes python=$PYTHON_VERSION pip 25 | - source activate testenv 26 | - conda install --yes numpy==$NUMPY_VERSION scipy==$SCIPY_VERSION pandas=$PANDAS_VERSION cython nose pytest pytest-cov 27 | - pip install codecov 28 | - pip install . 29 | 30 | #script: 31 | # - mkdir for_test 32 | # - cd for_test 33 | # - pytest -v --cov=dbgen --pyargs dbgen 34 | script: 35 | - coverage run -m unittest || python3 -m unittest || python -m unittest 36 | 37 | after_success: 38 | - codecov 39 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | build: false 2 | 3 | environment: 4 | matrix: 5 | - PYTHON: "C:\\Miniconda3-x64" 6 | PYTHON_VERSION: "3.5.x" 7 | PYTHON_ARCH: "32" 8 | NUMPY_VERSION: "1.13.1" 9 | SCIPY_VERSION: "0.19.1" 10 | SKLEARN_VERSION: "0.19.1" 11 | 12 | - PYTHON: "C:\\Miniconda3-x64" 13 | PYTHON_VERSION: "3.6.x" 14 | PYTHON_ARCH: "64" 15 | NUMPY_VERSION: "*" 16 | SCIPY_VERSION: "*" 17 | SKLEARN_VERSION: "*" 18 | 19 | install: 20 | # Prepend miniconda installed Python to the PATH of this build 21 | # Add Library/bin directory to fix issue 22 | # https://github.com/conda/conda/issues/1753 23 | - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\Library\\bin;%PATH%" 24 | # install the dependencies 25 | - "conda install --yes -c conda-forge pip numpy==%NUMPY_VERSION% scipy==%SCIPY_VERSION% scikit-learn==%SKLEARN_VERSION%" 26 | - pip install codecov nose pytest pytest-cov 27 | - pip install . 
28 | 29 | test_script: 30 | - mkdir for_test 31 | - cd for_test 32 | - pytest -v --cov=template --pyargs template 33 | 34 | after_test: 35 | - cp .coverage %APPVEYOR_BUILD_FOLDER% 36 | - cd %APPVEYOR_BUILD_FOLDER% 37 | - codecov 38 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build: 5 | docker: 6 | - image: circleci/python:3.6.1 7 | working_directory: ~/repo 8 | steps: 9 | - checkout 10 | - run: 11 | name: install dependencies 12 | command: | 13 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 14 | chmod +x miniconda.sh && ./miniconda.sh -b -p ~/miniconda 15 | export PATH="~/miniconda/bin:$PATH" 16 | conda update --yes --quiet conda 17 | conda create -n testenv --yes --quiet python=3 18 | source activate testenv 19 | conda install --yes pip numpy scipy scikit-learn matplotlib sphinx sphinx_rtd_theme numpydoc pillow 20 | pip install sphinx-gallery 21 | pip install . 22 | cd doc 23 | make html 24 | - store_artifacts: 25 | path: doc/_build/html/ 26 | destination: doc 27 | - store_artifacts: 28 | path: ~/log.txt 29 | - run: ls -ltrh doc/_build/html 30 | filters: 31 | branches: 32 | ignore: gh-pages 33 | 34 | workflows: 35 | version: 2 36 | workflow: 37 | jobs: 38 | - build 39 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. project-template documentation master file, created by 2 | sphinx-quickstart on Mon Jan 18 14:44:12 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to sklearn-template's documentation! 
7 | ============================================ 8 | 9 | This project is a reference implementation to anyone who wishes to develop 10 | scikit-learn compatible classes. 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | :hidden: 15 | :caption: Getting Started 16 | 17 | quick_start 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :hidden: 22 | :caption: Documentation 23 | 24 | user_guide 25 | api 26 | 27 | .. toctree:: 28 | :maxdepth: 2 29 | :hidden: 30 | :caption: Tutorial - Examples 31 | 32 | auto_examples/index 33 | 34 | `Getting started `_ 35 | ------------------------------------- 36 | 37 | Information regarding this template and how to modify it for your own project. 38 | 39 | `User Guide `_ 40 | ------------------------------- 41 | 42 | An example of narrative documentation. 43 | 44 | `API Documentation `_ 45 | ------------------------------- 46 | 47 | An example of API documentation. 48 | 49 | `Examples `_ 50 | -------------------------------------- 51 | 52 | A set of examples. It complements the `User Guide `_. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """A template.""" 3 | 4 | import codecs 5 | import os 6 | 7 | from setuptools import find_packages, setup 8 | 9 | # get __version__ from _version.py 10 | ver_file = os.path.join('template', '_version.py') 11 | with open(ver_file) as f: 12 | exec(f.read()) 13 | 14 | DISTNAME = 'project-template' 15 | DESCRIPTION = 'A template.' 16 | with codecs.open('README.rst', encoding='utf-8-sig') as f: 17 | LONG_DESCRIPTION = f.read() 18 | MAINTAINER = 'P. 
Barbiero' 19 | MAINTAINER_EMAIL = 'barbiero@tutanota.com' 20 | URL = 'https://github.com/pietrobarbiero/project-template' 21 | LICENSE = 'Apache 2.0' 22 | DOWNLOAD_URL = 'https://github.com/pietrobarbiero/project-template' 23 | VERSION = __version__ 24 | INSTALL_REQUIRES = ['numpy', 'scipy', 'scikit-learn', 'pandas'] 25 | CLASSIFIERS = ['Intended Audience :: Science/Research', 26 | 'Intended Audience :: Developers', 27 | 'License :: OSI Approved', 28 | 'Programming Language :: Python', 29 | 'Topic :: Software Development', 30 | 'Topic :: Scientific/Engineering', 31 | 'Operating System :: Microsoft :: Windows', 32 | 'Operating System :: POSIX', 33 | 'Operating System :: Unix', 34 | 'Operating System :: MacOS', 35 | 'Programming Language :: Python :: 2.7', 36 | 'Programming Language :: Python :: 3.5', 37 | 'Programming Language :: Python :: 3.6', 38 | 'Programming Language :: Python :: 3.7'] 39 | EXTRAS_REQUIRE = { 40 | 'tests': [ 41 | 'pytest', 42 | 'pytest-cov'], 43 | 'docs': [ 44 | 'sphinx', 45 | 'sphinx-gallery', 46 | 'sphinx_rtd_theme', 47 | 'numpydoc', 48 | 'matplotlib' 49 | ] 50 | } 51 | 52 | setup(name=DISTNAME, 53 | maintainer=MAINTAINER, 54 | maintainer_email=MAINTAINER_EMAIL, 55 | description=DESCRIPTION, 56 | license=LICENSE, 57 | url=URL, 58 | version=VERSION, 59 | download_url=DOWNLOAD_URL, 60 | long_description=LONG_DESCRIPTION, 61 | zip_safe=False, # the package can run out of an .egg file 62 | classifiers=CLASSIFIERS, 63 | packages=find_packages(), 64 | install_requires=INSTALL_REQUIRES, 65 | extras_require=EXTRAS_REQUIRE) 66 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | import os 17 | import sys 18 | 19 | sys.path.insert(0, os.path.abspath('../')) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'Project Template' 24 | copyright = '2020, Pietro Barbiero' 25 | author = 'Pietro Barbiero' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | master_doc = 'index' 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.coverage', 'sphinx_rtd_theme'] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ['_templates'] 39 | 40 | # List of patterns, relative to source directory, that match files and 41 | # directories to ignore when looking for source files. 42 | # This pattern also affects html_static_path and html_extra_path. 43 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 
50 | # 51 | # html_theme = 'alabaster' 52 | html_theme = "sphinx_rtd_theme" 53 | 54 | html_theme_options = { 55 | 'canonical_url': 'https://dbgen.readthedocs.io/en/latest/', 56 | 'logo_only': False, 57 | 'display_version': True, 58 | 'prev_next_buttons_location': 'bottom', 59 | 'style_external_links': False, 60 | # Toc options 61 | 'collapse_navigation': False, 62 | 'sticky_navigation': True, 63 | 'navigation_depth': 4, 64 | 'includehidden': True, 65 | 'titles_only': False 66 | } 67 | 68 | 69 | # Add any paths that contain custom static files (such as style sheets) here, 70 | # relative to this directory. They are copied after the builtin static files, 71 | # so a file named "default.css" will overwrite the builtin "default.css". 72 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /reasoning/_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import product 3 | 4 | 5 | def forward(X, weights, bias): 6 | """ 7 | Simulate the forward pass on one layer. 8 | 9 | :param X: input matrix, one sample per row. 10 | :param weights: weight matrix, one row per output neuron; its number of columns must equal the number of columns of X. 11 | :param bias: bias vector with one entry per output neuron. 12 | :return: array of 0/1 predictions with one row per neuron and one column per sample (sigmoid output thresholded at 0.5). 13 | """ 14 | a = np.matmul(weights, np.transpose(X)) 15 | b = np.reshape(np.repeat(bias, np.shape(X)[0], axis=0), np.shape(a)) # repeat each bias entry once per sample so row i of b holds bias[i] (assumes a 1-D bias) 16 | output = sigmoid_activation(a + b) 17 | y_pred = np.where(output < 0.5, 0, 1) # 1 where sigmoid(a + b) >= 0.5, i.e. (up to rounding) where a + b >= 0 18 | return y_pred 19 | 20 | 21 | 22 | def get_nonpruned_weights(weight_matrix, fan_in): 23 | r""" 24 | Get non-pruned weights. 25 | 26 | :param weight_matrix: weight matrix of the reasoning network; shape: $h_{i+1} \times h_{i}$. 27 | :param fan_in: number of incoming active weights for each neuron in the network.
28 | :return: array of shape (n_neurons, fan_in) holding, row by row, the non-zero weights of weight_matrix in column order; every row is assumed to have exactly fan_in non-zero entries. 29 | """ 30 | n_neurons = weight_matrix.shape[0] 31 | weights_active = np.zeros((n_neurons, fan_in)) 32 | for i in range(n_neurons): 33 | nonpruned_positions = np.nonzero(weight_matrix[i]) 34 | weights_active[i] = (weight_matrix)[i, nonpruned_positions] 35 | return weights_active 36 | 37 | 38 | def build_truth_table(fan_in): 39 | """ 40 | Build the truth table over the fan_in non-pruned inputs of a neuron. 41 | 42 | :param fan_in: number of incoming active weights for each neuron in the network. 43 | :return: integer array of shape (2 ** fan_in, fan_in) listing every 0/1 input assignment in itertools.product order (the last column changes fastest). 44 | """ 45 | items = [] 46 | for i in range(fan_in): 47 | items.append([0, 1]) 48 | truth_table = list(product(*items)) 49 | return np.array(truth_table) 50 | 51 | 52 | def get_nonpruned_positions(weights, neuron_list): 53 | r""" 54 | Get the list of the position of non-pruned weights. 55 | 56 | :param weights: list of the weight matrices of the reasoning network; shape: $h_{i+1} \times h_{i}$. 57 | :param neuron_list: list containing the number of neurons for each layer of the network. 58 | :return: nested list whose element [j][i] is the np.nonzero() result (a tuple of index arrays) for row i of weights[j], i.e. the columns of the non-pruned weights of neuron i in layer j. 59 | """ 60 | nonpruned_positions = [] 61 | for j in range(len(weights)): 62 | non_pruned_position_layer_j = [] 63 | for i in range(neuron_list[j]): 64 | non_pruned_position_layer_j.append(np.nonzero(weights[j][i])) 65 | nonpruned_positions.append(non_pruned_position_layer_j) 66 | 67 | return nonpruned_positions 68 | 69 | 70 | def count_neurons(weights): 71 | r""" 72 | Count the number of neurons for each layer of the neural network. 73 | 74 | :param weights: list of the weight matrices of the reasoning network; shape: $h_{i+1} \times h_{i}$.
75 | :return: integer array holding, for each weight matrix, its number of rows (output neurons). 76 | """ 77 | n_layers = len(weights) 78 | neuron_list = np.zeros(n_layers, dtype=int) 79 | for j in range(n_layers): 80 | # for each layer of weights, 81 | # get the shape of the weight matrix (number of output neurons) 82 | neuron_list[j] = np.shape(weights[j])[0] 83 | return neuron_list 84 | 85 | 86 | def sigmoid_activation(x): # element-wise logistic sigmoid, 1 / (1 + exp(-x)) 87 | return (1 / (1 + np.exp(-x))) 88 | -------------------------------------------------------------------------------- /doc/_static/js/copybutton.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | /* Add a [>>>] button on the top-right corner of code samples to hide 3 | * the >>> and ... prompts and the output and thus make the code 4 | * copyable. */ 5 | var div = $('.highlight-python .highlight,' + 6 | '.highlight-python3 .highlight,' + 7 | '.highlight-pycon .highlight,' + 8 | '.highlight-default .highlight') 9 | var pre = div.find('pre'); 10 | 11 | // get the styles from the current theme 12 | pre.parent().parent().css('position', 'relative'); 13 | var hide_text = 'Hide the prompts and output'; 14 | var show_text = 'Show the prompts and output'; 15 | var border_width = pre.css('border-top-width'); 16 | var border_style = pre.css('border-top-style'); 17 | var border_color = pre.css('border-top-color'); 18 | var button_styles = { 19 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 20 | 'border-color': border_color, 'border-style': border_style, 21 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 22 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 23 | 'border-radius': '0 3px 0 0' 24 | } 25 | 26 | // create and add the button to all the code blocks that contain >>> 27 | div.each(function(index) { 28 | var jthis = $(this); 29 | if (jthis.find('.gp').length > 0) { 30 | var button = $('>>>'); 31 | button.css(button_styles) 32 | button.attr('title', hide_text);
button.data('hidden', 'false'); 34 | jthis.prepend(button); 35 | } 36 | // tracebacks (.gt) contain bare text elements that need to be 37 | // wrapped in a span to work with .nextUntil() (see later) 38 | jthis.find('pre:has(.gt)').contents().filter(function() { 39 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 40 | }).wrap(''); 41 | }); 42 | 43 | // define the behavior of the button when it's clicked 44 | $('.copybutton').click(function(e){ 45 | e.preventDefault(); 46 | var button = $(this); 47 | if (button.data('hidden') === 'false') { 48 | // hide the code output 49 | button.parent().find('.go, .gp, .gt').hide(); 50 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 51 | button.css('text-decoration', 'line-through'); 52 | button.attr('title', show_text); 53 | button.data('hidden', 'true'); 54 | } else { 55 | // show the code output 56 | button.parent().find('.go, .gp, .gt').show(); 57 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 58 | button.css('text-decoration', 'none'); 59 | button.attr('title', hide_text); 60 | button.data('hidden', 'false'); 61 | } 62 | }); 63 | }); 64 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Constraint-based Learning with Neural Networks 2 | ================================================== 3 | |DOI| 4 | 5 | .. |DOI| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.4244088.svg?style=for-the-badge 6 | :alt: Zenodo (.org) 7 | :target: https://doi.org/10.5281/zenodo.4244088 8 | 9 | 10 | Notebooks 11 | ---------- 12 | This repository contains two notebooks which will guide you step-by-step towards 13 | the implementation of learning of and with constraints in Pytorch. 14 | 15 | - `Learning with constraints `_: 16 | learn how to train a NN with human-driven constraints 17 | 18 | .. 
figure:: https://github.com/pietrobarbiero/constraint-learning/blob/master/img/learning_with_constraints.png 19 | :height: 200px 20 | 21 | - `Learning of constraints `_: 22 | learn how to make a NN learn how to explain its predictions with logic 23 | 24 | .. figure:: https://github.com/pietrobarbiero/constraint-learning/blob/master/img/learning_of_constraints.png 25 | :height: 200px 26 | 27 | Cite the notebooks! 28 | ********************* 29 | If you find this repository useful, please consider citing:: 30 | 31 | @misc{barbiero2020constraint, 32 | title={pietrobarbiero/constraint-learning: Absolutno}, 33 | DOI={10.5281/zenodo.4244088}, 34 | abstractNote={Constraint-based Learning with Neural Networks}, 35 | publisher={Zenodo}, 36 | author={Pietro Barbiero}, 37 | year={2020}, 38 | month={Nov} 39 | } 40 | 41 | 42 | Theory 43 | -------- 44 | Theoretical foundations can be found in the following papers. 45 | 46 | Learning of constraints:: 47 | 48 | @inproceedings{ciravegna2020constraint, 49 | title={A Constraint-Based Approach to Learning and Explanation.}, 50 | author={Ciravegna, Gabriele and Giannini, Francesco and Melacci, Stefano and Maggini, Marco and Gori, Marco}, 51 | booktitle={AAAI}, 52 | pages={3658--3665}, 53 | year={2020} 54 | } 55 | 56 | Learning with constraints:: 57 | 58 | @inproceedings{marra2019lyrics, 59 | title={LYRICS: A General Interface Layer to Integrate Logic Inference and Deep Learning}, 60 | author={Marra, Giuseppe and Giannini, Francesco and Diligenti, Michelangelo and Gori, Marco}, 61 | booktitle={Joint European Conference on Machine Learning and Knowledge Discovery in Databases}, 62 | pages={283--298}, 63 | year={2019}, 64 | organization={Springer} 65 | } 66 | 67 | Constraints theory in machine learning:: 68 | 69 | @book{gori2017machine, 70 | title={Machine Learning: A constraint-based approach}, 71 | author={Gori, Marco}, 72 | year={2017}, 73 | publisher={Morgan Kaufmann} 74 | } 75 | 76 | 77 | Authors 78 | ------- 79 | 80 | `Pietro 
Barbiero `__ 81 | 82 | Licence 83 | ------- 84 | 85 | Copyright 2020 Pietro Barbiero. 86 | 87 | Licensed under the Apache License, Version 2.0 (the "License"); you may 88 | not use this file except in compliance with the License. You may obtain 89 | a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0. 90 | 91 | Unless required by applicable law or agreed to in writing, software 92 | distributed under the License is distributed on an "AS IS" BASIS, 93 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 94 | 95 | See the License for the specific language governing permissions and 96 | limitations under the License. -------------------------------------------------------------------------------- /reasoning/_fol_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | from sympy import to_cnf 5 | 6 | from reasoning._utils import count_neurons, get_nonpruned_positions, \ 7 | build_truth_table, get_nonpruned_weights, forward 8 | 9 | 10 | def generate_fol_explanations(weights: List[np.array], bias: List[np.array]): 11 | r""" 12 | Generate the FOL formulas corresponding to the parameters of a reasoning network. 13 | 14 | :param weights: list of the weight matrices of the reasoning network; shape: $h_{i+1} \times h_{i}$. 15 | :param bias: list of the bias vectors of the reasoning network; shape: $h_{i} \times 1$.
16 | :return: list of formula strings, one per neuron of the last layer, each wrapped in parentheses. Input features are named f1, f2, ...; each deeper layer is expressed through the formulas of the previous layer. 17 | """ 18 | assert len(weights) == len(bias) # one bias vector per weight matrix 19 | 20 | # count number of layers of the reasoning network 21 | n_layers = len(weights) 22 | fan_in = np.count_nonzero((weights[0])[0, :]) # NOTE(review): fan-in of neuron 0 in layer 0, reused below for every neuron of every layer -- assumes a uniform fan-in across the network 23 | n_features = np.shape(weights[0])[1] 24 | 25 | # create fancy feature names 26 | feature_names = list() 27 | for k in range(n_features): 28 | feature_names.append("f" + str(k + 1)) 29 | 30 | # count the number of hidden neurons for each layer 31 | neuron_list = count_neurons(weights) 32 | # get the position of non-pruned weights 33 | nonpruned_positions = get_nonpruned_positions(weights, neuron_list) 34 | 35 | # generate the query dataset, i.e. a truth table 36 | truth_table = build_truth_table(fan_in) # shared by all layers, relying on the uniform fan-in noted above 37 | 38 | # simulate a forward pass using non-pruned weights only 39 | predictions = list() 40 | for j in range(n_layers): 41 | weights_active = get_nonpruned_weights(weights[j], fan_in) 42 | y_pred = forward(truth_table, weights_active, bias[j]) 43 | predictions.append(y_pred) 44 | 45 | for j in range(n_layers): 46 | formulas = list() 47 | for i in range(neuron_list[j]): 48 | formula = _compute_fol_formula(truth_table, predictions[j][i], feature_names, nonpruned_positions[j][i][0]) # [0]: np.nonzero() wraps the index array of a row in a tuple 49 | formulas.append(f'({formula})') 50 | 51 | # the new feature names are the formulas we just computed 52 | feature_names = formulas 53 | return formulas # formulas of the last layer 54 | 55 | 56 | def _compute_fol_formula(truth_table, predictions, feature_names, nonpruned_positions): 57 | """ 58 | Compute the First Order Logic formula of a single neuron. 59 | 60 | :param truth_table: input truth table. 61 | :param predictions: output predictions for the current neuron. 62 | :param feature_names: name of the input features.
63 | :param nonpruned_positions: indices into feature_names of the neuron's non-pruned inputs, one per truth-table column. 64 | :return: "False" if the neuron is never active on the truth table, "True" if it is always active, otherwise the string form of sympy.to_cnf applied to the DNF built from the active rows. 65 | """ 66 | # select the rows of the input truth table for which the output is true 67 | X = truth_table[np.nonzero(predictions)] 68 | 69 | # if the output is never true, then return false 70 | if np.shape(X)[0] == 0: return "False" 71 | 72 | # if the output is never false, then return true 73 | if np.shape(X)[0] == np.shape(truth_table)[0]: return "True" 74 | 75 | # compute the formula 76 | formula = '' 77 | n_rows, n_features = X.shape 78 | for i in range(n_rows): 79 | # if the formula is not empty, start appending an additional term 80 | if formula != '': 81 | formula = formula + "|" 82 | 83 | # open the bracket 84 | formula = formula + "(" 85 | for j in range(n_features): 86 | # get the name (column index) of the feature 87 | feature_name = feature_names[nonpruned_positions[j]] 88 | 89 | # if the feature is not active, 90 | # then the corresponding predicate is false, 91 | # then we need to negate the feature 92 | if X[i][j] == 0: 93 | formula += "~" 94 | 95 | # append the feature name 96 | formula += feature_name + "&" 97 | 98 | formula = formula[:-1] + ')' 99 | 100 | # replace "not True" with "False" and vice versa 101 | formula = formula.replace('~(True)', 'False') 102 | formula = formula.replace('~(False)', 'True') 103 | 104 | # convert the DNF string to CNF; the second argument is simplify=False, so no boolean minimisation is performed 105 | cnf_formula = to_cnf(formula, False) 106 | return str(cnf_formula) 107 | 108 | 109 | if __name__ == '__main__': 110 | w1 = np.array([[1, 0, 2, 0, 0], [1, 0, 3, 0, 0], [0, 1, 0, -1, 0]]) 111 | w2 = np.array([[1, 0, -2]]) 112 | b1 = [1, 0, -1] 113 | b2 = [1] 114 | 115 | w = [w1, w2] 116 | b = [b1, b2] 117 | 118 | f = generate_fol_explanations(w, b) 119 | print("Formula: ", f) 120 | -------------------------------------------------------------------------------- /doc/quick_start.rst: -------------------------------------------------------------------------------- 1 | ##################################### 2 | Quick Start with the
project-template 3 | ##################################### 4 | 5 | This package serves as a skeleton package that helps you develop 6 | scikit-learn-compatible contributions. 7 | 8 | Creating your own scikit-learn contribution package 9 | =================================================== 10 | 11 | 1. Download and setup your repository 12 | ------------------------------------- 13 | 14 | To create your package, you need to clone the ``project-template`` repository:: 15 | 16 | $ git clone https://github.com/scikit-learn-contrib/project-template.git 17 | 18 | Before reinitializing your git repository, you need to make the following 19 | changes. Replace all occurrences of ``skltemplate`` and ``sklearn-template`` 20 | with the name of your own contribution. You can find all the occurrences using 21 | the following commands:: 22 | 23 | $ git grep skltemplate 24 | $ git grep sklearn-template 25 | 26 | To remove the history of the template package, you need to remove the ``.git`` 27 | directory:: 28 | 29 | $ cd project-template 30 | $ rm -rf .git 31 | 32 | Then, you need to initialize your new git repository:: 33 | 34 | $ git init 35 | $ git add . 36 | $ git commit -m 'Initial commit' 37 | 38 | Finally, create an online repository on GitHub and push your code online:: 39 | 40 | $ git remote add origin https://github.com/your_remote/your_contribution.git 41 | $ git push origin master 42 | 43 | 2. Develop your own scikit-learn estimators 44 | ------------------------------------------- 45 | 46 | .. _check_estimator: http://scikit-learn.org/stable/modules/generated/sklearn.utils.estimator_checks.check_estimator.html#sklearn.utils.estimator_checks.check_estimator 47 | .. _`Contributor's Guide`: http://scikit-learn.org/stable/developers/ 48 | .. _PEP8: https://www.python.org/dev/peps/pep-0008/ 49 | .. _PEP257: https://www.python.org/dev/peps/pep-0257/ 50 | .. _NumPyDoc: https://github.com/numpy/numpydoc 51 | ..
_doctests: https://docs.python.org/3/library/doctest.html 52 | 53 | You can modify the source files as you want. However, your custom estimators 54 | need to pass the check_estimator_ test to be scikit-learn compatible. You can 55 | refer to the :ref:`User Guide <user_guide>` to help you create a compatible 56 | scikit-learn estimator. 57 | 58 | In any case, developers should endeavor to adhere to scikit-learn's 59 | `Contributor's Guide`_ which promotes the use of: 60 | 61 | * algorithm-specific unit tests, in addition to ``check_estimator``'s common 62 | tests; 63 | * PEP8_-compliant code; 64 | * a clearly documented API using NumpyDoc_ and PEP257_-compliant docstrings; 65 | * references to relevant scientific literature in standard citation formats; 66 | * doctests_ to provide succinct usage examples; 67 | * standalone examples to illustrate the usage, model visualisation, and 68 | benefits/benchmarks of particular algorithms; 69 | * efficient code when the need for optimization is supported by benchmarks. 70 | 71 | 3. Edit the documentation 72 | ------------------------- 73 | 74 | .. _Sphinx: http://www.sphinx-doc.org/en/stable/ 75 | 76 | The documentation is created using Sphinx_. In addition, the examples are 77 | created using ``sphinx-gallery``. Therefore, to generate the documentation 78 | locally, you need to install the following packages:: 79 | 80 | $ pip install sphinx sphinx-gallery sphinx_rtd_theme matplotlib numpydoc pillow 81 | 82 | The documentation consists of: 83 | 84 | * a home page, ``doc/index.rst``; 85 | * an API reference, ``doc/api.rst``, in which you should list all public 86 | objects whose docstrings should be exposed publicly; 87 | * a User Guide, ``doc/user_guide.rst``, containing the narrative 88 | documentation of your package, to give as much intuition as possible to your 89 | users; 90 | * examples, which are created in the ``examples/`` folder. Each example 91 | illustrates some usage of the package.
Example file names should match 92 | ``plot_*.py``. 93 | 94 | The documentation is built with the following commands:: 95 | 96 | $ cd doc 97 | $ make html 98 | 99 | 4. Setup the continuous integration 100 | ----------------------------------- 101 | 102 | The project template already contains configuration files for the continuous 103 | integration systems. The following services are set up: 104 | 105 | * Travis CI is used to test the package on Linux. You need to activate Travis 106 | CI for your own repository. Refer to the Travis CI documentation. 107 | * AppVeyor is used to test the package on Windows. You need to activate 108 | AppVeyor for your own repository. Refer to the AppVeyor documentation. 109 | * Circle CI is used to check if the documentation is generated properly. You 110 | need to activate Circle CI for your own repository. Refer to the Circle CI 111 | documentation. 112 | * ReadTheDocs is used to build and host the documentation. You need to activate 113 | ReadTheDocs for your own repository. Refer to the ReadTheDocs documentation. 114 | * CodeCov is used to track the code coverage of the package. You need to activate 115 | CodeCov for your own repository. 116 | * PEP8Speaks automatically checks the PEP8 compliance of your project on 117 | each Pull Request. 118 | 119 | Publish your package 120 | ==================== 121 | 122 | .. _PyPi: https://packaging.python.org/tutorials/packaging-projects/ 123 | .. _conda-forge: https://conda-forge.org/ 124 | 125 | You can make your package available through PyPi_ and conda-forge_. Refer to 126 | the associated documentation to learn how to upload your package so that 127 | it can be installed with ``pip`` and ``conda``.
Once published, it will 128 | be possible to install your package with the following commands:: 129 | 130 | $ pip install your-scikit-learn-contribution 131 | $ conda install -c conda-forge your-scikit-learn-contribution 132 | -------------------------------------------------------------------------------- /notebooks/learning_of_constraints_digits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.datasets import load_digits 4 | from sklearn.preprocessing import OneHotEncoder 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | from sklearn.model_selection import train_test_split 8 | from itertools import product 9 | import pandas as pd 10 | import torch.nn.utils.prune as prune 11 | 12 | 13 | X, y = load_digits(return_X_y=True) 14 | print(f'X shape: {X.shape}\nClasses: {np.unique(y)}') 15 | 16 | enc = OneHotEncoder() 17 | y1h = enc.fit_transform(y.reshape(-1, 1)).toarray() 18 | print(f'Before: {y.shape}\nAfter: {y1h.shape}') 19 | 20 | y2 = np.zeros((len(y), 2)) 21 | for i, yi in enumerate(y): 22 | if yi % 2: 23 | y2[i, 0] = 1 24 | else: 25 | y2[i, 1] = 1 26 | y1h2 = np.hstack((y1h, y2)) 27 | 28 | print(f'Target vector shape: {y1h2.shape}') 29 | for i in range(10): 30 | print(f'Example ({y[i]}): {y1h2[i]}') 31 | 32 | X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(X, y1h2, test_size=0.33, random_state=42) 33 | 34 | x_train = torch.FloatTensor(X_train_np) 35 | y_train = torch.FloatTensor(y_train_np) 36 | x_test = torch.FloatTensor(X_test_np) 37 | y_test = torch.FloatTensor(y_test_np) 38 | 39 | 40 | class FeedForwardNet(torch.nn.Module): 41 | def __init__(self, D_in, H, D_out): 42 | """ 43 | In the constructor we instantiate three nn.Linear modules and assign them as 44 | member variables.
45 | """ 46 | super(FeedForwardNet, self).__init__() 47 | self.linear1 = torch.nn.Linear(D_in, H) 48 | self.linear2 = torch.nn.Linear(H, H) 49 | self.linear3 = torch.nn.Linear(H, D_out) 50 | 51 | def forward(self, x): 52 | """ 53 | Map x through three linear layers, with a ReLU after the first two, 54 | and squash the result with a sigmoid so every output lies in (0, 1). 55 | 56 | """ 57 | h = self.linear1(x) 58 | h = torch.nn.functional.relu(h) 59 | h = self.linear2(h) 60 | h = torch.nn.functional.relu(h) 61 | h = self.linear3(h) 62 | y_pred = torch.sigmoid(h) 63 | return y_pred 64 | 65 | 66 | din, dh, dout = x_train.shape[1], 20, y_train.shape[1] 67 | model = FeedForwardNet(din, dh, dout) 68 | 69 | print(model) 70 | loss = torch.nn.BCELoss() 71 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 72 | model.train() 73 | epoch = 2000 74 | for epoch in range(epoch): 75 | optimizer.zero_grad() 76 | 77 | # Forward pass 78 | y_pred = model(x_train) 79 | y_pred_np = y_pred.detach().numpy() 80 | 81 | # Compute Loss 82 | tot_loss = loss(y_pred, y_train) 83 | 84 | # exact-match accuracy: a sample counts as correct only if every thresholded output matches its target 85 | y_pred_d = (y_pred > 0.5).detach().numpy() 86 | accuracy = ((y_pred_d == y_train_np).sum(axis=1) == y_train_np.shape[1]).mean() 87 | 88 | if epoch % 100 == 0: 89 | print(f'Epoch {epoch + 1}: ' 90 | f'total loss: {tot_loss.item():.4f} ' 91 | f'| accuracy: {accuracy:.4f} ') 92 | 93 | # Backward pass 94 | tot_loss.backward() 95 | optimizer.step() 96 | 97 | y_pred = model(x_test) 98 | 99 | # exact-match accuracy on the held-out test split 100 | y_pred_round = (y_pred > 0.5).to(torch.float).detach().numpy() 101 | accuracy = ((y_pred_round == y_test_np).sum(axis=1) == y_test_np.shape[1]).mean() 102 | 103 | print(f'accuracy: {accuracy:.4f}') 104 | 105 | 106 | class ExplainEven(torch.nn.Module): 107 | def __init__(self, D_in): 108 | """ 109 | In the constructor we instantiate two nn.Linear modules and assign them as 110 | member variables.
111 | """ 112 | super(ExplainEven, self).__init__() 113 | self.linear1 = torch.nn.Linear(D_in, 10) 114 | self.linear2 = torch.nn.Linear(10, 1) 115 | 116 | def forward(self, x): 117 | """ 118 | Map x through two linear layers with no activation in between 119 | (so they compose to a single affine map), then apply a sigmoid 120 | to obtain one output in (0, 1) per sample. 121 | """ 122 | h = self.linear1(x) 123 | h = self.linear2(h) 124 | y_pred = torch.sigmoid(h) 125 | return y_pred 126 | 127 | 128 | y_pred_train = model(x_train).detach().numpy().astype(float) 129 | y_pred_test = model(x_test).detach().numpy().astype(float) 130 | 131 | x_concepts_train_np, y_concepts_train_np = y_pred_train[:, :-1], y_pred_train[:, -1] 132 | x_concepts_test_np, y_concepts_test_np = y_pred_test[:, :-1], y_pred_test[:, -1] 133 | 134 | x_concepts_train, y_concepts_train = torch.FloatTensor(x_concepts_train_np), torch.FloatTensor(y_concepts_train_np) 135 | x_concepts_test, y_concepts_test = torch.FloatTensor(x_concepts_test_np), torch.FloatTensor(y_concepts_test_np) 136 | 137 | D_in = y_pred_train.shape[1] - 1 138 | even_net = ExplainEven(D_in) 139 | print(even_net) 140 | 141 | optimizer = torch.optim.Adam(even_net.parameters(), lr=0.01) 142 | even_net.train() 143 | epoch = 500 144 | accuracy = 0 145 | for epoch in range(epoch): 146 | optimizer.zero_grad() 147 | # Forward pass 148 | y_pred = even_net(x_concepts_train) 149 | y_pred_np = y_pred.detach().numpy() 150 | 151 | # Compute Loss: BCE against the teacher's last output (the 'even' column of y1h2) plus 0.08 times the product of both layers' L1 weight norms, which pushes weights towards zero 152 | tot_loss = loss(y_pred.squeeze(), y_concepts_train) + \ 153 | 0.08 * \ 154 | even_net.linear1.weight.norm(1) * \ 155 | even_net.linear2.weight.norm(1) 156 | 157 | # accuracy w.r.t. the teacher's thresholded last output 158 | y_pred_d = (y_pred > 0.5).detach().numpy().ravel() 159 | accuracy = (y_pred_d == (y_concepts_train_np > 0.5)).mean() 160 | 161 | if epoch % 100 == 0: 162 | print(f'Epoch {epoch + 1}: ' 163 | f'total loss: {tot_loss.item():.4f} ' 164 | f'| accuracy: {accuracy:.4f} ') 165 | 166 | # Backward pass
167 | tot_loss.backward() 168 | optimizer.step() 169 | 170 | # Pruning: in every layer keep only the 2 largest-magnitude weights per output neuron (topk over -|w| picks the k = in_features - 2 smallest, which are masked to zero) 171 | for i, (module) in enumerate(even_net._modules.items()): 172 | mask = torch.ones(module[1].weight.shape) 173 | param_absneg = -torch.abs(module[1].weight) 174 | idx = torch.topk(param_absneg, k=param_absneg.shape[1] - 2)[1] 175 | for i in range(len(idx)): 176 | mask[i, idx[i]] = 0 177 | prune.custom_from_mask(module[1], name="weight", mask=mask) 178 | 179 | # Tuning without the sparsity penalty. NOTE(review): `epoch` still holds 499 from the loop above, so this runs 499 iterations, not 500 - confirm intent 180 | for epoch in range(epoch): 181 | optimizer.zero_grad() 182 | # Forward pass 183 | y_pred = even_net(x_concepts_train) 184 | y_pred_np = y_pred.detach().numpy() 185 | 186 | # Compute Loss 187 | tot_loss = loss(y_pred.squeeze(), y_concepts_train) 188 | 189 | # compute accuracy 190 | y_pred_d = (y_pred > 0.5).detach().numpy().ravel() 191 | accuracy = (y_pred_d == (y_concepts_train_np > 0.5)).mean() 192 | 193 | if epoch % 100 == 0: 194 | print(f'Epoch {epoch + 1}: ' 195 | f'total loss: {tot_loss.item():.4f} ' 196 | f'| accuracy: {accuracy:.4f} ') 197 | 198 | # Backward pass 199 | tot_loss.backward() 200 | optimizer.step() 201 | 202 | weights, bias = [], [] 203 | for i, (module) in enumerate(even_net._modules.items()): 204 | weights.append(module[1].weight.detach().numpy()) 205 | bias.append(module[1].bias.detach().numpy()) 206 | 207 | 208 | from intoCNF_with_prints import booleanConstraint 209 | 210 | f = booleanConstraint(weights, bias) 211 | print(f) 212 | 213 | from sympy.logic import simplify_logic 214 | 215 | sf = simplify_logic(f[0]) 216 | print(sf) 217 | 218 | a = 1 219 | -------------------------------------------------------------------------------- /doc/user_guide.rst: -------------------------------------------------------------------------------- 1 | .. title:: User guide : contents 2 | 3 | ..
_user_guide: 4 | 5 | ================================================== 6 | User guide: create your own scikit-learn estimator 7 | ================================================== 8 | 9 | Estimator 10 | --------- 11 | 12 | The central piece of transformers, regressors, and classifiers is 13 | :class:`sklearn.base.BaseEstimator`. All estimators in scikit-learn are derived 14 | from this class. In more detail, this base class makes it possible to set and get the 15 | parameters of the estimator. It can be imported as:: 16 | 17 | >>> from sklearn.base import BaseEstimator 18 | 19 | Once imported, you can create a class which inherits from this base class:: 20 | 21 | >>> class MyOwnEstimator(BaseEstimator): 22 | ... pass 23 | 24 | Transformer 25 | ----------- 26 | 27 | Transformers are scikit-learn estimators which implement a ``transform`` method. 28 | The use case is the following: 29 | 30 | * at ``fit``, some parameters can be learned from ``X`` and ``y``; 31 | * at ``transform``, ``X`` will be transformed, using the parameters learned 32 | during ``fit``. 33 | 34 | .. _mixin: https://en.wikipedia.org/wiki/Mixin 35 | 36 | In addition, scikit-learn provides a 37 | mixin_, i.e. :class:`sklearn.base.TransformerMixin`, which 38 | implements ``fit_transform``, the combination of ``fit`` and ``transform``. 39 | 40 | One can import the mixin class as:: 41 | 42 | >>> from sklearn.base import TransformerMixin 43 | 44 | Therefore, when creating a transformer, you need to create a class which 45 | inherits from both :class:`sklearn.base.BaseEstimator` and 46 | :class:`sklearn.base.TransformerMixin`. The scikit-learn API requires ``fit`` to 47 | return ``self``. This allows ``fit`` and ``transform`` to be chained, as done in 48 | the ``fit_transform`` method provided by :class:`sklearn.base.TransformerMixin`. The 49 | ``fit`` method is expected to have ``X`` and ``y`` as inputs.
Note that 50 | ``transform`` takes only ``X`` as input and is expected to return the 51 | transformed version of ``X``:: 52 | 53 | >>> class MyOwnTransformer(BaseEstimator, TransformerMixin): 54 | ... def fit(self, X, y=None): 55 | ... return self 56 | ... def transform(self, X): 57 | ... return X 58 | 59 | We build a basic example to show that our :class:`MyOwnTransformer` is working 60 | within a scikit-learn ``pipeline``:: 61 | 62 | >>> from sklearn.datasets import load_iris 63 | >>> from sklearn.pipeline import make_pipeline 64 | >>> from sklearn.linear_model import LogisticRegression 65 | >>> X, y = load_iris(return_X_y=True) 66 | >>> pipe = make_pipeline(MyOwnTransformer(), 67 | ... LogisticRegression(random_state=10, 68 | ... solver='lbfgs')) 69 | >>> pipe.fit(X, y) # doctest: +ELLIPSIS 70 | Pipeline(...) 71 | >>> pipe.predict(X) # doctest: +ELLIPSIS 72 | array([...]) 73 | 74 | Predictor 75 | --------- 76 | 77 | Regressor 78 | ~~~~~~~~~ 79 | 80 | Similarly, regressors are scikit-learn estimators which implement a ``predict`` 81 | method. The use case is the following: 82 | 83 | * at ``fit``, some parameters can be learned from ``X`` and ``y``; 84 | * at ``predict``, predictions will be computed using ``X`` using the parameters 85 | learned during ``fit``. 86 | 87 | In addition, scikit-learn provides a mixin_, i.e. 88 | :class:`sklearn.base.RegressorMixin`, which implements the ``score`` method 89 | which computes the :math:`R^2` score of the predictions. 90 | 91 | One can import the mixin as:: 92 | 93 | >>> from sklearn.base import RegressorMixin 94 | 95 | Therefore, we create a regressor, :class:`MyOwnRegressor` which inherits from 96 | both :class:`sklearn.base.BaseEstimator` and 97 | :class:`sklearn.base.RegressorMixin`. The method ``fit`` gets ``X`` and ``y`` 98 | as input and should return ``self``. 
It should implement the ``predict`` 99 | function which should output the predictions of your regressor:: 100 | 101 | >>> import numpy as np 102 | >>> class MyOwnRegressor(BaseEstimator, RegressorMixin): 103 | ... def fit(self, X, y): 104 | ... return self 105 | ... def predict(self, X): 106 | ... return np.mean(X, axis=1) 107 | 108 | We illustrate that this regressor is working within a scikit-learn pipeline:: 109 | 110 | >>> from sklearn.datasets import load_diabetes 111 | >>> X, y = load_diabetes(return_X_y=True) 112 | >>> pipe = make_pipeline(MyOwnTransformer(), MyOwnRegressor()) 113 | >>> pipe.fit(X, y) # doctest: +ELLIPSIS 114 | Pipeline(...) 115 | >>> pipe.predict(X) # doctest: +ELLIPSIS 116 | array([...]) 117 | 118 | Since we inherit from the :class:`sklearn.base.RegressorMixin`, we can call 119 | the ``score`` method which will return the :math:`R^2` score:: 120 | 121 | >>> pipe.score(X, y) # doctest: +ELLIPSIS 122 | -3.9... 123 | 124 | Classifier 125 | ~~~~~~~~~~ 126 | 127 | Similarly to regressors, classifiers implement ``predict``. In addition, they 128 | output the probabilities of the prediction using the ``predict_proba`` method: 129 | 130 | * at ``fit``, some parameters can be learned from ``X`` and ``y``; 131 | * at ``predict``, predictions will be computed using ``X`` using the parameters 132 | learned during ``fit``. The output corresponds to the predicted class for each sample; 133 | * ``predict_proba`` will give a 2D matrix where each column corresponds to the 134 | class and each entry will be the probability of the associated class. 135 | 136 | In addition, scikit-learn provides a mixin, i.e. 137 | :class:`sklearn.base.ClassifierMixin`, which implements the ``score`` method 138 | which computes the accuracy score of the predictions. 
139 | 140 | One can import this mixin as:: 141 | 142 | >>> from sklearn.base import ClassifierMixin 143 | 144 | Therefore, we create a classifier, :class:`MyOwnClassifier` which inherits 145 | from both :class:`sklearn.base.BaseEstimator` and 146 | :class:`sklearn.base.ClassifierMixin`. The method ``fit`` gets ``X`` and ``y`` 147 | as input and should return ``self``. It should implement the ``predict`` 148 | function which should output the class inferred by the classifier. 149 | ``predict_proba`` will output some probabilities instead:: 150 | 151 | >>> class MyOwnClassifier(BaseEstimator, ClassifierMixin): 152 | ... def fit(self, X, y): 153 | ... self.classes_ = np.unique(y) 154 | ... return self 155 | ... def predict(self, X): 156 | ... return np.random.randint(0, self.classes_.size, 157 | ... size=X.shape[0]) 158 | ... def predict_proba(self, X): 159 | ... pred = np.random.rand(X.shape[0], self.classes_.size) 160 | ... return pred / np.sum(pred, axis=1)[:, np.newaxis] 161 | 162 | We illustrate that this classifier is working within a scikit-learn pipeline:: 163 | 164 | >>> X, y = load_iris(return_X_y=True) 165 | >>> pipe = make_pipeline(MyOwnTransformer(), MyOwnClassifier()) 166 | >>> pipe.fit(X, y) # doctest: +ELLIPSIS 167 | Pipeline(...) 168 | 169 | Then, you can call ``predict`` and ``predict_proba``:: 170 | 171 | >>> pipe.predict(X) # doctest: +ELLIPSIS 172 | array([...]) 173 | >>> pipe.predict_proba(X) # doctest: +ELLIPSIS 174 | array([...]) 175 | 176 | Since our classifier inherits from :class:`sklearn.base.ClassifierMixin`, we 177 | can compute the accuracy by calling the ``score`` method:: 178 | 179 | >>> pipe.score(X, y) # doctest: +ELLIPSIS 180 | 0... 181 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line.
5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf latexpdfja text man texinfo info gettext changes xml pseudoxml linkcheck doctest 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo
files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudo-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | -rm -rf $(BUILDDIR)/* 51 | -rm -rf auto_examples/ 52 | -rm -rf generated/* 53 | -rm -rf modules/generated/* 54 | 55 | html: 56 | # These two lines make the build a bit more lengthy, and the 57 | # embedding of images more robust 58 | rm -rf $(BUILDDIR)/html/_images 59 | #rm -rf _build/doctrees/ 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp."
89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run \"qcollectiongenerator\" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/project-template" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "platex/dvipdfmx finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 152 | $(MAKE) -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
185 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 
53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 
116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 
181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 
239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /notebooks/intoCNF.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sympy import * 3 | from sympy.logic import simplify_logic 4 | from sympy.parsing.sympy_parser import parse_expr 5 | 6 | import itertools 7 | import sys 8 | 9 | 10 | 11 | 12 | 13 | 14 | def act(x): 15 | return (1 / (1 + np.exp(-x))) 16 | 17 | def cartesian_product(*arrays): 18 | la = len(arrays) 19 | dtype = np.result_type(*arrays) 20 | arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype) 21 | for i, a in enumerate(np.ix_(*arrays)): 22 | arr[..., i] = a 23 | return arr.reshape(-1, la) 24 | 25 | def combine(m, n): 26 | a = np.shape(m)[0] 27 | c = list() 28 | count = 0 29 | for i in range(a): 30 | if(m[i] == n[i]): 31 | c.append(m[i]) 32 | elif(m[i] != n[i]): 33 | c.append(2) 34 | count += 1 35 | 36 | if(count > 1): 37 | return None 38 | else: 39 | return c 40 | 41 | 42 | 43 | ############################################################### 44 | #Writing a Psi_i(with sympy) 45 | def calc_form(X,y,form,pos): 46 | X1 = X[np.nonzero(y)] 47 | if np.shape(X1)[0] == 0: 48 | return "False" 49 | if np.shape(X1)[0] == np.shape(X)[0]: 50 | return "True" 51 | formula = "" 52 | size = np.shape(X1) 53 | for i in range(size[0]): 54 | if formula != "": 55 | formula = formula + "|" 56 | formula = formula + "(" 57 | for j in range(size[1]): 58 | if X1[i][j] == 0: 59 | formula = formula + "~" + form[pos[j]] +"&" 60 | if X1[i][j] == 1: 61 | formula = formula + form[pos[j]] +"&" 62 | if X1[i][size[1]-1] == 0: 63 | formula = formula + "~" + form[pos[size[1]-1]]+")" 64 | if X1[i][size[1]-1] == 1: 65 | formula = formula + form[pos[size[1]-1]]+")" 66 | 67 | form_cnf = to_cnf(formula,True) 68 | formula=str(form_cnf) 69 | return formula 70 | 71 | 72 | 73 | 74 | 
######################################################################## 75 | 76 | 77 | def booleanConstraint(weights,bias): 78 | 79 | #weights contains the list of weight matrices with shape h_i+1 X h_i 80 | #bias contains the list of bias vectors with shape h_i X 1 81 | 82 | #parameters 83 | num_layers = len(weights) 84 | h = np.zeros(num_layers, dtype=int) 85 | for j in range(num_layers): 86 | h[j] = np.shape(weights[j])[0] 87 | rel_pos = [[np.nonzero(weights[j][i]) for i in range(h[j])] for j in range(num_layers)] 88 | 89 | 90 | #inputs 91 | fan_in = np.count_nonzero((weights[0])[0,:]) 92 | array = np.repeat([[0,1]], fan_in, axis=0) 93 | X = cartesian_product(*array) 94 | num_inputs = np.shape(X)[0] 95 | 96 | #outputs 97 | y = list() 98 | for j in range(num_layers): 99 | weights_active = np.zeros((h[j],fan_in)) 100 | output = np.zeros((h[j],num_inputs)) 101 | for i in range(h[j]): 102 | weights_active[i] = (weights[j])[i, np.nonzero(weights[j][i])] 103 | a = np.matmul(weights_active,np.transpose(X)) 104 | b = np.reshape(np.repeat(bias[j],np.shape(X)[0],axis=0),np.shape(a)) 105 | output = act(a+b) 106 | crisp=np.where(output < 0.5, 0, 1) 107 | y.append(crisp) 108 | 109 | 110 | 111 | #list of formulas to be combined 112 | form_labels =list() 113 | for k in range(np.shape(weights[0])[1]): 114 | form_labels.append("f" + str(k+1)) 115 | for j in range(num_layers): 116 | formula = list() 117 | print(form_labels) 118 | for i in range(h[j]): 119 | formula.append("("+calc_form(X,y[j][i],form_labels,rel_pos[j][i][0])+")") 120 | form_labels=formula 121 | return formula 122 | 123 | 124 | def conjunt_list(psi): 125 | N = len(psi) 126 | c=0 127 | psi_list = list() 128 | for i in range(N): 129 | if psi[i]=="&": 130 | psi_list.append(psi[c:i]) 131 | c=i+1 132 | psi_list.append(psi[c:N]) 133 | return psi_list 134 | 135 | #psi represents the list of psi_i rules extracted and already in CNF 136 | def global_rules(psi_list, tractable=False): 137 | try: 138 | N = len(psi_list) 139 
| psi_list_list = [conjunt_list(psi_list[i]) for i in range(N)] 140 | # print(psi_list_list) 141 | 142 | ############################################ 143 | if tractable: 144 | #global rule psi (possible explosion) 145 | psi1 ="(" 146 | for i in range(N-1): 147 | psi1 = psi1 + psi_list[i]+") | (" 148 | psi1 = psi1 + psi_list[N-1]+")" 149 | psi_cnf1 = to_cnf(psi1,True) 150 | psi_cnf1 = str(psi_cnf1) 151 | return psi_cnf1 152 | ############################################ 153 | 154 | psi_conjuncts = list() 155 | for element in itertools.product(*psi_list_list): 156 | form_temp = element[0] 157 | for i in range(1, N): 158 | form_temp = form_temp + " | " + element[i] 159 | par = simplify_logic(parse_expr(form_temp)) 160 | par = str(par) 161 | psi_conjuncts.append(par) 162 | psi = "("+psi_conjuncts[0] + ")" 163 | for i in range(1,len(psi_conjuncts)): 164 | psi = psi + " & (" + psi_conjuncts[i] + ")" 165 | # print(psi) 166 | psi_cnf = to_cnf(psi,True) 167 | psi_cnf = str(psi_cnf) 168 | except: 169 | print("Oops!", sys.exc_info(), "occured.") 170 | psi_cnf = "" 171 | for psi in psi_list: 172 | psi_cnf = psi_cnf + " | (" + psi + ")" 173 | # print(psi_cnf) 174 | # psi_conjuncts = [] 175 | 176 | return psi_cnf 177 | 178 | 179 | 180 | 181 | 182 | 183 | #######################################################################################test examples to CNF 184 | # a=["x","y","z"] 185 | # b=["x|y","y|z","k"] 186 | # d=["x","w"] 187 | # c=[a,b,d] 188 | # N= len(c) 189 | # 190 | # for element in itertools.product(*lista_psi): 191 | # # for element in itertools.product(a,b): 192 | # print(element) 193 | # form_temp = element[0] 194 | # for i in range(1,N): 195 | # form_temp = form_temp+"|"+element[i] 196 | # print(form_temp) 197 | # par = simplify_logic(parse_expr(form_temp)) 198 | # print(par) 199 | # form_list.append(par) 200 | 201 | 202 | 203 | 204 | 205 | # l=["(Eight & odd & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Five & even & 
~Eight & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Five & odd & ~Eight & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Nine & even & ~Eight & ~Five & ~Four & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Nine & odd & ~Eight & ~Five & ~Four & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (One & odd & ~Eight & ~Five & ~Four & ~Nine & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Seven & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Six & ~Three & ~Two & ~Zero & ~even) | (Six & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Three & ~Two & ~Zero & ~even) | (Three & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Two & ~Zero & ~even) | (Two & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Zero & ~even) | (Zero & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~even) | (odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even)", "(Eight & even & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Eight & odd & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Five & even & ~Eight & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Four & even & ~Eight & ~Five & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Nine & even & ~Eight & ~Five & ~Four & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Seven & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Six & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Three & ~Two & ~Zero & ~odd) | (Three & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Two & ~Zero & ~odd) | (Two & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Zero & ~odd) | (Two & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Zero & ~even) | (Zero & even & 
~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~odd) | (even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd)"] 206 | # 207 | # l=["~TAIL & (BICYCLE | HANDLEBAR)", "ROOFSIDE & ~HEAD & ~TORSO", "SOFA & (~NECK | ~SCREEN)", "CAP & ~WING", "~POTTEDPLANT & (DOG | ~HEAD)", "BIRD & (~BODY | ~LEG)", "STERN & ~EYE & ~RIGHTSIDE", "~HEAD & (STERN | ~TORSO)", "HEAD & (HO | HORSE)", "CHAINWHEEL & ~HEAD", "~AEROPLANE & (BODY | ~HAND)", "BUS | (~CAR & ~HEAD)", "~HEAD & ~TORSO", "TABLE & ~EYE & ~LEG", "STERN & (~HEAD | ~TORSO)", "CHAIR", "~CHAIR & (MOTORBIKE | ~TORSO)", "~BODY & (TRAIN | ~TORSO)", "SHEEP & ~BOAT & ~PLANT", "STERN & ~HEAD", "CAT & ~FOOT", "~LEG & (BACKSIDE | ~HEADLIGHT)", "SCREEN | (~HEAD & ~PERSON)", "~FRONTSIDE & ~HORSE & ~TVMONITOR"] 208 | # 209 | # 210 | # lista=[str(to_cnf(i,True)) for i in l] 211 | # print(lista) 212 | # 213 | # 214 | # 215 | # 216 | # # lista=["x & y","(y|z)&(y|x)","k & w"] 217 | # # lista=["x & y","(y|z)&(y|x)","k & w"] 218 | # # print(lista) 219 | # 220 | # 221 | # 222 | # psi = global_rules(lista) 223 | # print(psi) 224 | # 225 | # 226 | # exit() 227 | # 228 | 229 | 230 | #TESTING EXAMPLES psi_i 231 | # # 1 232 | # w1 = np.array([[0.,1.1,-2.1],[3.8,-1.9,0.],[1.5,0.,-1.3]]) 233 | w2 = np.array([[0.,1.5,2.1]]) 234 | # b1 = [1.2,0.,-1.9] 235 | b2 = [5.1] 236 | # 237 | w = [w2] 238 | b = [b2] 239 | 240 | # #2 241 | # w1 = np.array([[0,1,-2,0,0],[0,0,3,-1,0],[0,1,0,-1,0]]) 242 | # w2 = np.array([[0,1,2]]) 243 | # b1 = [1,0,-1] 244 | # b2 = [5] 245 | # # 246 | # w = [w1,w2] 247 | # b = [b1,b2] 248 | 249 | 250 | #3 251 | # w1 = np.array([[0.,1.2,-2.1,0.,0.],[0.,0.,3.1,-1.9,0.],[0.,1.4,0.,-1.2,0.],[0.,1.4,0.,-1.5,0.]]) 252 | # w2 = np.array([[0.,1.5,2.3,0.],[0.,-2.1,0.,3.],[-2.9,1.1,0.,0.]]) 253 | # w3 = np.array([[2.1,2.2,0.]]) 254 | # b1 = [1.1,0.,-1.2,1.1] 255 | # b2 = [1.1,-2.2,1.] 
256 | # b3 = [0.1] 257 | # 258 | # 259 | # w = [w1,w2,w3] 260 | # b = [b1,b2,b3] 261 | 262 | #4 263 | # w1 = np.array([[0.,1.2,-2.1,0.9,0.],[1.8,0.,2.1,-1.9,0.],[0.,1.4,1.1,-1.2,0.],[0.,1.4,0.,-1.5,-1.]]) 264 | # w2 = np.array([[1.1,1.5,2.3,0.],[0.,-2.1,0.9,3.],[-2.9,1.1,0.,1.5]]) 265 | # w3 = np.array([[2.1,-2.2,0.9]]) 266 | # b1 = [1.1,0.,-1.2,1.1] 267 | # b2 = [1.1,-2.2,1.] 268 | # b3 = [0.1] 269 | # 270 | # w = [w1,w2,w3] 271 | # b = [b1,b2,b3] 272 | 273 | 274 | 275 | # w1 = np.array([[-1.1,1.2,0.,0.,0.,0.],[0.,0.9,0.6,0.,0.,0.],[0.,0.,0.,1.4,-1.1,0.],[0.,0.,0.,0.,1.2,-1.]]) 276 | # w2 = np.array([[-1.1,1.5,0.,0.],[0.,-2.1,0.9,0.],[0.,0.,0.9,-1.2]]) 277 | # # w3 = np.array([[1.,1.,1.]]) 278 | # b1 = [0.1,-0.1,0.3,0.05] 279 | # b2 = [0.2,-0.1,0.1] 280 | # # b3 = [0.1] 281 | # 282 | # w = [w1,w2] 283 | # b = [b1,b2] 284 | 285 | # a=parse_expr("~(x | True)") 286 | # b=parse_expr("~(x & False)") 287 | # c=parse_expr("~(True)") 288 | # d=parse_expr("~(False)") 289 | # print(a) 290 | # print(b) 291 | # print(c) 292 | # print(d) 293 | # # print(e) 294 | # exit() 295 | 296 | # 297 | # 298 | # w1 = np.array([[1.6,0.,-0.9,0.],[0.,0.8,0.,-0.5],[0.,0.3,0.4,0.]]) 299 | # w2 = np.array([[-1.1,0.,1.1],[0.,-0.9,1.1]]) 300 | # w3 = np.array([[0.5,-0.6,0.]]) 301 | # b1 = [0.1,0.1,-0.1] 302 | # b2 = [0.1,-0.1] 303 | # b3 = [0.1] 304 | # 305 | # w = [w1,w2,w3] 306 | # b = [b1,b2,b3] 307 | # 308 | # 309 | # #parsing and simplyfing 310 | # 311 | f=booleanConstraint(w,b) 312 | print("psi_i: ",f) 313 | # h="" 314 | # for i in range(len(f)-1): 315 | # h=h+f[i]+"|" 316 | # h=h+f[len(f)-1] 317 | # g = parse_expr(h) 318 | # # g = parse_expr(f) 319 | # print("psi_parsata: ",g) 320 | # print("psi_simple: ",simplify_logic(g)) 321 | # 322 | # 323 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /notebooks/intoCNF_with_prints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sympy import * 3 | from sympy.logic import simplify_logic 4 | from sympy.parsing.sympy_parser import parse_expr 5 | 6 | import itertools 7 | import sys 8 | 9 | 10 | def act(x): 11 | return (1 / (1 + np.exp(-x))) 12 | 13 | 14 | def cartesian_product(*arrays): 15 | la = len(arrays) 16 | dtype = np.result_type(*arrays) 17 | arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype) 18 | for i, a in enumerate(np.ix_(*arrays)): 19 | arr[..., i] = a 20 | return arr.reshape(-1, la) 21 | 22 | 23 | def combine(m, n): 24 | a = np.shape(m)[0] 25 | c = list() 26 | count = 0 27 | for i in range(a): 28 | if (m[i] == n[i]): 29 | c.append(m[i]) 30 | elif (m[i] != n[i]): 31 | c.append(2) 32 | count += 1 33 | 34 | if (count > 1): 35 | return None 36 | else: 37 | return c 38 | 39 | 40 | ############################################################### 41 | #Writing a Psi_i(with sympy) 42 | def calc_form(X, y, form, 
pos): 43 | X1 = X[np.nonzero(y)] 44 | if np.shape(X1)[0] == 0: 45 | return "False" 46 | if np.shape(X1)[0] == np.shape(X)[0]: 47 | return "True" 48 | formula = "" 49 | size = np.shape(X1) 50 | for i in range(size[0]): 51 | if formula != "": 52 | formula = formula + "|" 53 | formula = formula + "(" 54 | for j in range(size[1]): 55 | if X1[i][j] == 0: 56 | formula = formula + "~" + form[pos[j]] + "&" 57 | if X1[i][j] == 1: 58 | formula = formula + form[pos[j]] + "&" 59 | # Code below duplicates for, can't we just close the bracket, e.g. formula[-1]=')' 60 | if X1[i][size[1] - 1] == 0: 61 | formula = formula + "~" + form[pos[size[1] - 1]] + ")" 62 | if X1[i][size[1] - 1] == 1: 63 | formula = formula + form[pos[size[1] - 1]] + ")" 64 | 65 | form_cnf = to_cnf(formula, True) 66 | formula = str(form_cnf) 67 | return formula 68 | 69 | 70 | ######################################################################## 71 | 72 | 73 | def booleanConstraint(weights, bias): 74 | print("w", weights, "b", bias) 75 | 76 | #weights contains the list of weight matrices with shape h_i+1 X h_i 77 | #bias contains the list of bias vectors with shape h_i X 1 78 | 79 | #parameters 80 | num_layers = len(weights) 81 | print("num_layers", num_layers) 82 | h = np.zeros(num_layers, dtype=int) 83 | for j in range(num_layers): 84 | h[j] = np.shape(weights[j])[0] 85 | print('h[',j,']=np.shape(', weights[j], ')[0]=', h[j]) 86 | rel_pos = [[np.nonzero(weights[j][i]) for i in range(h[j])] 87 | for j in range(num_layers)] 88 | print('rel_pos', rel_pos) 89 | 90 | #inputs 91 | fan_in = np.count_nonzero((weights[0])[0, :]) 92 | print('fan_in', fan_in, '= count nonzero(', (weights[0])[0, :], ')') 93 | array = np.repeat([[0, 1]], fan_in, axis=0) 94 | print('array', array) 95 | X = cartesian_product(*array) 96 | print('X', X) 97 | num_inputs = np.shape(X)[0] 98 | print('num_inputs', num_inputs) 99 | 100 | #outputs 101 | y = list() 102 | for j in range(num_layers): 103 | print("Layer", j, 'h', h[j]) 104 | 
weights_active = np.zeros((h[j], fan_in)) 105 | print(" weights_active", weights_active.shape) 106 | output = np.zeros((h[j], num_inputs)) 107 | print(" output", output.shape) 108 | for i in range(h[j]): 109 | print(" Neuron", i) 110 | weights_active[i] = (weights[j])[i, np.nonzero(weights[j][i])] 111 | print(" weights_active[", i, "]=", (weights[j]), "[", i, ",", np.nonzero(weights[j][i]), "]") 112 | a = np.matmul(weights_active, np.transpose(X)) 113 | print(" a=np.matmul(", weights_active, ",", np.transpose(X), ")=", a) 114 | b = np.reshape(np.repeat(bias[j], np.shape(X)[0], axis=0), np.shape(a)) 115 | output = act(a + b) 116 | print("output", output) 117 | crisp = np.where(output < 0.5, 0, 1) 118 | print("crisp", crisp) 119 | y.append(crisp) 120 | 121 | #list of formulas to be combined 122 | form_labels = list() 123 | for k in range(np.shape(weights[0])[1]): 124 | print(k, "form_labels", form_labels, " --> ", end='') 125 | form_labels.append("f" + str(k + 1)) 126 | print(form_labels) 127 | for j in range(num_layers): 128 | print("Layer ", j) 129 | formula = list() 130 | print(form_labels) 131 | for i in range(h[j]): 132 | print(" i", i, formula, " --> ", end='') 133 | formula.append( 134 | "(" + calc_form(X, y[j][i], form_labels, rel_pos[j][i][0]) + 135 | ")") 136 | print(formula) 137 | print("calc_form(", X.tolist(), y[j][i], form_labels, rel_pos[j][i][0], ')') 138 | form_labels = formula 139 | return formula 140 | 141 | 142 | def conjunt_list(psi): 143 | N = len(psi) 144 | c = 0 145 | psi_list = list() 146 | for i in range(N): 147 | if psi[i] == "&": 148 | psi_list.append(psi[c:i]) 149 | c = i + 1 150 | psi_list.append(psi[c:N]) 151 | return psi_list 152 | 153 | 154 | #psi represents the list of psi_i rules extracted and already in CNF 155 | def global_rules(psi_list, tractable=False): 156 | try: 157 | N = len(psi_list) 158 | psi_list_list = [conjunt_list(psi_list[i]) for i in range(N)] 159 | # print(psi_list_list) 160 | 161 | 
############################################ 162 | if tractable: 163 | #global rule psi (possible explosion) 164 | psi1 = "(" 165 | for i in range(N - 1): 166 | psi1 = psi1 + psi_list[i] + ") | (" 167 | psi1 = psi1 + psi_list[N - 1] + ")" 168 | psi_cnf1 = to_cnf(psi1, True) 169 | psi_cnf1 = str(psi_cnf1) 170 | return psi_cnf1 171 | ############################################ 172 | 173 | psi_conjuncts = list() 174 | for element in itertools.product(*psi_list_list): 175 | form_temp = element[0] 176 | for i in range(1, N): 177 | form_temp = form_temp + " | " + element[i] 178 | par = simplify_logic(parse_expr(form_temp)) 179 | par = str(par) 180 | psi_conjuncts.append(par) 181 | psi = "(" + psi_conjuncts[0] + ")" 182 | for i in range(1, len(psi_conjuncts)): 183 | psi = psi + " & (" + psi_conjuncts[i] + ")" 184 | # print(psi) 185 | psi_cnf = to_cnf(psi, True) 186 | psi_cnf = str(psi_cnf) 187 | except: 188 | print("Oops!", sys.exc_info(), "occured.") 189 | psi_cnf = "" 190 | for psi in psi_list: 191 | psi_cnf = psi_cnf + " | (" + psi + ")" 192 | # print(psi_cnf) 193 | # psi_conjuncts = [] 194 | 195 | return psi_cnf 196 | 197 | if __name__ == '__main__': 198 | #######################################################################################test examples to CNF 199 | # a=["x","y","z"] 200 | # b=["x|y","y|z","k"] 201 | # d=["x","w"] 202 | # c=[a,b,d] 203 | # N= len(c) 204 | # 205 | # for element in itertools.product(*lista_psi): 206 | # # for element in itertools.product(a,b): 207 | # print(element) 208 | # form_temp = element[0] 209 | # for i in range(1,N): 210 | # form_temp = form_temp+"|"+element[i] 211 | # print(form_temp) 212 | # par = simplify_logic(parse_expr(form_temp)) 213 | # print(par) 214 | # form_list.append(par) 215 | 216 | # l=["(Eight & odd & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Five & even & ~Eight & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Five & odd & ~Eight & ~Four & ~Nine 
& ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Nine & even & ~Eight & ~Five & ~Four & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Nine & odd & ~Eight & ~Five & ~Four & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (One & odd & ~Eight & ~Five & ~Four & ~Nine & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Seven & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Six & ~Three & ~Two & ~Zero & ~even) | (Six & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Three & ~Two & ~Zero & ~even) | (Three & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Two & ~Zero & ~even) | (Two & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Zero & ~even) | (Zero & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~even) | (odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even)", "(Eight & even & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Eight & odd & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~even) | (Five & even & ~Eight & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Four & even & ~Eight & ~Five & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Nine & even & ~Eight & ~Five & ~Four & ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Seven & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Six & ~Three & ~Two & ~Zero & ~odd) | (Six & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Three & ~Two & ~Zero & ~odd) | (Three & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Two & ~Zero & ~odd) | (Two & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Zero & ~odd) | (Two & odd & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Zero & ~even) | (Zero & even & ~Eight & ~Five & ~Four & ~Nine & ~One & ~Seven & ~Six & ~Three & ~Two & ~odd) | (even & ~Eight & ~Five & ~Four & ~Nine 
& ~One & ~Seven & ~Six & ~Three & ~Two & ~Zero & ~odd)"] 217 | # 218 | # l=["~TAIL & (BICYCLE | HANDLEBAR)", "ROOFSIDE & ~HEAD & ~TORSO", "SOFA & (~NECK | ~SCREEN)", "CAP & ~WING", "~POTTEDPLANT & (DOG | ~HEAD)", "BIRD & (~BODY | ~LEG)", "STERN & ~EYE & ~RIGHTSIDE", "~HEAD & (STERN | ~TORSO)", "HEAD & (HO | HORSE)", "CHAINWHEEL & ~HEAD", "~AEROPLANE & (BODY | ~HAND)", "BUS | (~CAR & ~HEAD)", "~HEAD & ~TORSO", "TABLE & ~EYE & ~LEG", "STERN & (~HEAD | ~TORSO)", "CHAIR", "~CHAIR & (MOTORBIKE | ~TORSO)", "~BODY & (TRAIN | ~TORSO)", "SHEEP & ~BOAT & ~PLANT", "STERN & ~HEAD", "CAT & ~FOOT", "~LEG & (BACKSIDE | ~HEADLIGHT)", "SCREEN | (~HEAD & ~PERSON)", "~FRONTSIDE & ~HORSE & ~TVMONITOR"] 219 | # 220 | # 221 | # lista=[str(to_cnf(i,True)) for i in l] 222 | # print(lista) 223 | # 224 | # 225 | # 226 | # 227 | # # lista=["x & y","(y|z)&(y|x)","k & w"] 228 | # # lista=["x & y","(y|z)&(y|x)","k & w"] 229 | # # print(lista) 230 | # 231 | # 232 | # 233 | # psi = global_rules(lista) 234 | # print(psi) 235 | # 236 | # 237 | # exit() 238 | # 239 | 240 | #TESTING EXAMPLES psi_i 241 | # # 1 242 | # w1 = np.array([[0.,1.1,-2.1],[3.8,-1.9,0.],[1.5,0.,-1.3]]) 243 | w2 = np.array([[3.8,-1.9,0.]]) 244 | # b1 = [1.2,0.,-1.9] 245 | b2 = [5.1] 246 | # 247 | w = [w2] 248 | b = [b2] 249 | 250 | # #2 251 | w1 = np.array([[0,1,-2,0,0],[0,0,3,-1,0],[0,1,0,-1,0]]) 252 | w2 = np.array([[0,1,2]]) 253 | b1 = [1,0,-1] 254 | b2 = [5] 255 | # 256 | w = [w1,w2] 257 | b = [b1,b2] 258 | 259 | #3 260 | # w1 = np.array([[0.,1.2,-2.1,0.,0.],[0.,0.,3.1,-1.9,0.],[0.,1.4,0.,-1.2,0.],[0.,1.4,0.,-1.5,0.]]) 261 | # w2 = np.array([[0.,1.5,2.3,0.],[0.,-2.1,0.,3.],[-2.9,1.1,0.,0.]]) 262 | # w3 = np.array([[2.1,2.2,0.]]) 263 | # b1 = [1.1,0.,-1.2,1.1] 264 | # b2 = [1.1,-2.2,1.] 
265 | # b3 = [0.1] 266 | # 267 | # 268 | # w = [w1,w2,w3] 269 | # b = [b1,b2,b3] 270 | 271 | #4 272 | # w1 = np.array([[0.,1.2,-2.1,0.9,0.],[1.8,0.,2.1,-1.9,0.],[0.,1.4,1.1,-1.2,0.],[0.,1.4,0.,-1.5,-1.]]) 273 | # w2 = np.array([[1.1,1.5,2.3,0.],[0.,-2.1,0.9,3.],[-2.9,1.1,0.,1.5]]) 274 | # w3 = np.array([[2.1,-2.2,0.9]]) 275 | # b1 = [1.1,0.,-1.2,1.1] 276 | # b2 = [1.1,-2.2,1.] 277 | # b3 = [0.1] 278 | # 279 | # w = [w1,w2,w3] 280 | # b = [b1,b2,b3] 281 | 282 | # w1 = np.array([[-1.1,1.2,0.,0.,0.,0.],[0.,0.9,0.6,0.,0.,0.],[0.,0.,0.,1.4,-1.1,0.],[0.,0.,0.,0.,1.2,-1.]]) 283 | # w2 = np.array([[-1.1,1.5,0.,0.],[0.,-2.1,0.9,0.],[0.,0.,0.9,-1.2]]) 284 | # # w3 = np.array([[1.,1.,1.]]) 285 | # b1 = [0.1,-0.1,0.3,0.05] 286 | # b2 = [0.2,-0.1,0.1] 287 | # # b3 = [0.1] 288 | # 289 | # w = [w1,w2] 290 | # b = [b1,b2] 291 | 292 | # a=parse_expr("~(x | True)") 293 | # b=parse_expr("~(x & False)") 294 | # c=parse_expr("~(True)") 295 | # d=parse_expr("~(False)") 296 | # print(a) 297 | # print(b) 298 | # print(c) 299 | # print(d) 300 | # # print(e) 301 | # exit() 302 | 303 | # 304 | # 305 | # w1 = np.array([[1.6,0.,-0.9,0.],[0.,0.8,0.,-0.5],[0.,0.3,0.4,0.]]) 306 | # w2 = np.array([[-1.1,0.,1.1],[0.,-0.9,1.1]]) 307 | # w3 = np.array([[0.5,-0.6,0.]]) 308 | # b1 = [0.1,0.1,-0.1] 309 | # b2 = [0.1,-0.1] 310 | # b3 = [0.1] 311 | # 312 | # w = [w1,w2,w3] 313 | # b = [b1,b2,b3] 314 | # 315 | # 316 | # #parsing and simplyfing 317 | # 318 | f = booleanConstraint(w, b) 319 | print("psi_i: ", f) 320 | # h="" 321 | # for i in range(len(f)-1): 322 | # h=h+f[i]+"|" 323 | # h=h+f[len(f)-1] 324 | # g = parse_expr(h) 325 | # # g = parse_expr(f) 326 | # print("psi_parsata: ",g) 327 | # print("psi_simple: ",simplify_logic(g)) 328 | # 329 | # 330 | --------------------------------------------------------------------------------