├── .circleci └── config.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── appveyor.yml ├── doc ├── Makefile ├── _static │ ├── css │ │ └── project-template.css │ ├── img │ │ ├── cover.png │ │ ├── dag1.png │ │ ├── dag2.png │ │ ├── dag2a.png │ │ ├── dag3.png │ │ ├── dag3a.png │ │ ├── skdag-banner.png │ │ ├── skdag-dark.png │ │ └── stack.png │ └── js │ │ └── copybutton.js ├── api.rst ├── conf.py ├── index.rst ├── make.bat ├── quick_start.rst └── user_guide.rst ├── environment.yml ├── examples └── README.txt ├── img ├── skdag-banner.png ├── skdag-dark-fill.png ├── skdag-dark.kra ├── skdag-dark.png ├── skdag-fill.png ├── skdag.kra └── skdag.png ├── requirements.txt ├── requirements_doc.txt ├── requirements_full.txt ├── requirements_test.txt ├── setup.cfg ├── setup.py └── skdag ├── __init__.py ├── _version.py ├── dag ├── __init__.py ├── _builder.py ├── _dag.py ├── _render.py ├── _utils.py └── tests │ ├── __init__.py │ ├── test_builder.py │ ├── test_dag.py │ └── utils.py └── exceptions.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build: 5 | docker: 6 | - image: circleci/python:3.6.1 7 | working_directory: ~/repo 8 | steps: 9 | - checkout 10 | - run: 11 | name: install dependencies 12 | command: | 13 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 14 | chmod +x miniconda.sh && ./miniconda.sh -b -p ~/miniconda 15 | export PATH="~/miniconda/bin:$PATH" 16 | conda update --yes --quiet conda 17 | conda create -n testenv --yes --quiet python=3 18 | source activate testenv 19 | conda install --yes pip numpy scipy scikit-learn matplotlib sphinx sphinx_rtd_theme numpydoc pillow 20 | pip install sphinx-gallery 21 | pip install . 22 | cd doc 23 | make html 24 | - store_artifacts: 25 | path: doc/_build/html/ 26 | destination: doc 27 | - store_artifacts: 28 | path: ~/log.txt 29 | - run: ls -ltrh doc/_build/html 30 | filters: 31 | branches: 32 | ignore: gh-pages 33 | 34 | workflows: 35 | version: 2 36 | workflow: 37 | jobs: 38 | - build 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # scikit-learn specific 10 | doc/_build/ 11 | doc/auto_examples/ 12 | doc/modules/generated/ 13 | doc/datasets/generated/ 14 | 15 | # Distribution / packaging 16 | 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | 62 | # Sphinx documentation 63 | doc/_build/ 64 | doc/generated/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter 70 | .ipynb_checkpoints 71 | 72 | # Misc 73 | *.swp 74 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | formats: 2 | - none 3 | requirements_file: requirements.txt 4 | python: 5 | pip_install: true 6 | extra_requirements: 7 | - test 8 | - doc 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 big-o 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.* 2 | include requirements*.txt 3 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. -*- mode: rst -*- 2 | 3 | |AppVeyor|_ |Codecov|_ |ReadTheDocs|_ 4 | 5 | .. |AppVeyor| image:: https://ci.appveyor.com/api/projects/status/github/big-o/skdag?branch=main&svg=true 6 | .. _AppVeyor: https://ci.appveyor.com/project/big-o/skdag 7 | 8 | .. |Codecov| image:: https://codecov.io/gh/big-o/skdag/branch/main/graph/badge.svg 9 | .. _Codecov: https://codecov.io/gh/big-o/skdag 10 | 11 | .. |ReadTheDocs| image:: https://readthedocs.org/projects/skdag/badge/?version=latest 12 | .. _ReadTheDocs: https://skdag.readthedocs.io/en/latest/?badge=latest 13 | 14 | skdag - A more flexible alternative to scikit-learn Pipelines 15 | ============================================================= 16 | 17 | .. image:: img/skdag-banner.png 18 | 19 | scikit-dag (``skdag``) is an open-source, MIT-licensed library that provides advanced 20 | workflow management for any machine learning operations that follow 21 | scikit-learn_ conventions.
Installation is simple: 22 | 23 | .. code-block:: bash 24 | 25 | pip install skdag 26 | 27 | It works by introducing Directed Acyclic Graphs as a drop-in replacement for the traditional 28 | scikit-learn ``Pipeline``. This gives you a simple interface for a range of use cases 29 | including complex pre-processing, model stacking and benchmarking. 30 | 31 | .. code-block:: python 32 | 33 | from skdag import DAGBuilder 34 | from sklearn.decomposition import PCA 35 | from sklearn.ensemble import RandomForestRegressor 36 | from sklearn.impute import SimpleImputer 37 | from sklearn.linear_model import LinearRegression 38 | from sklearn.neighbors import KNeighborsRegressor 39 | from sklearn.svm import SVR 40 | 41 | dag = ( 42 | DAGBuilder(infer_dataframe=True) 43 | .add_step("impute", SimpleImputer()) 44 | .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]}) 45 | .add_step( 46 | "blood", 47 | PCA(n_components=2, random_state=0), 48 | deps={"impute": ["s1", "s2", "s3", "s4", "s5", "s6"]} 49 | ) 50 | .add_step( 51 | "rf", 52 | RandomForestRegressor(max_depth=5, random_state=0), 53 | deps=["blood", "vitals"] 54 | ) 55 | .add_step("svm", SVR(C=0.7), deps=["blood", "vitals"]) 56 | .add_step( 57 | "knn", 58 | KNeighborsRegressor(n_neighbors=5), 59 | deps=["blood", "vitals"] 60 | ) 61 | .add_step("meta", LinearRegression(), deps=["rf", "svm", "knn"]) 62 | .make_dag(n_jobs=2) 63 | ) 64 | 65 | dag.show(detailed=True) 66 | 67 | .. image:: doc/_static/img/cover.png 68 | 69 | The above DAG imputes missing values, runs PCA on the columns relating to blood test 70 | results and leaves the other columns as they are. Then they get passed to three 71 | different regressors before being passed on to a final meta-estimator. Because DAGs 72 | (unlike pipelines) allow predictors in the middle of a workflow, you can use them to 73 | implement model stacking. We also chose to run the DAG steps in parallel wherever 74 | possible. 75 | 76 | After building our DAG, we can treat it as any other estimator: 77 | 78 | .. code-block:: python 79 | 80 | from sklearn import datasets 81 | from sklearn.model_selection import train_test_split 82 | 83 | X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) 84 | X_train, X_test, y_train, y_test = train_test_split( 85 | X, y, test_size=0.2, random_state=0 86 | ) 87 | 88 | dag.fit(X_train, y_train) 89 | dag.predict(X_test) 90 | 91 | Just like a pipeline, you can optimise it with a grid search, pickle it, etc. 92 | 93 | Note that this package does not deal with things like delayed dependencies and 94 | distributed architectures - consider an `established `_ 95 | `solution `_ for such use cases. ``skdag`` is just for building and 96 | executing local ensembles from estimators. 97 | 98 | `Read on <https://skdag.readthedocs.io/en/latest/>`_ to learn more about ``skdag``... 99 | 100 | .. _scikit-learn: https://scikit-learn.org 101 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | build: false 2 | 3 | environment: 4 | matrix: 5 | - APPVEYOR_BUILD_WORKER_IMAGE: Ubuntu 6 | APPVEYOR_YML_DISABLE_PS_LINUX: true 7 | 8 | stack: python 3.8 9 | 10 | install: | 11 | if [[ "${APPVEYOR_BUILD_WORKER_IMAGE}" == "Ubuntu" ]]; then 12 | sudo apt update 13 | sudo apt install -y graphviz libgraphviz-dev 14 | elif [[ "${APPVEYOR_BUILD_WORKER_IMAGE}" == "macOS" ]]; then 15 | brew update 16 | brew install graphviz 17 | fi 18 | pip install --upgrade pip 19 | for f in $(find . -maxdepth 1 -name 'requirements*.txt'); do 20 | pip install -r ${f} 21 | done 22 | pip install .
23 | 24 | test_script: 25 | - mkdir for_test 26 | - cd for_test 27 | - pytest -v --cov=skdag --pyargs skdag 28 | 29 | after_test: 30 | - cp .coverage ${APPVEYOR_BUILD_FOLDER} 31 | - cd ${APPVEYOR_BUILD_FOLDER} 32 | - curl -Os https://uploader.codecov.io/latest/linux/codecov 33 | - chmod +x codecov 34 | - ./codecov 35 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | -rm -rf $(BUILDDIR)/* 51 | -rm -rf auto_examples/ 52 | -rm -rf generated/* 53 | -rm -rf modules/generated/* 54 | 55 | html: 56 | # These two lines make the build a bit more lengthy, and the 57 | # embedding of images more robust 58 | rm -rf $(BUILDDIR)/html/_images 59 | #rm -rf _build/doctrees/ 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp." 89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/project-template" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 
152 | make -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 185 | -------------------------------------------------------------------------------- /doc/_static/css/project-template.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .highlight a { 4 | text-decoration: underline; 5 | } 6 | 7 | .deprecated p { 8 | padding: 10px 7px 10px 10px; 9 | color: #b94a48; 10 | background-color: #F3E5E5; 11 | border: 1px solid #eed3d7; 12 | } 13 | 14 | .deprecated p span.versionmodified { 15 | font-weight: bold; 16 | } 17 | 18 | img.logo { 19 | max-height: 100px; 20 | } 21 | -------------------------------------------------------------------------------- /doc/_static/img/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/cover.png -------------------------------------------------------------------------------- /doc/_static/img/dag1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag1.png -------------------------------------------------------------------------------- /doc/_static/img/dag2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag2.png -------------------------------------------------------------------------------- /doc/_static/img/dag2a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag2a.png -------------------------------------------------------------------------------- /doc/_static/img/dag3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag3.png -------------------------------------------------------------------------------- /doc/_static/img/dag3a.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag3a.png -------------------------------------------------------------------------------- /doc/_static/img/skdag-banner.png: -------------------------------------------------------------------------------- 1 | ../../../img/skdag-banner.png -------------------------------------------------------------------------------- /doc/_static/img/skdag-dark.png: -------------------------------------------------------------------------------- 1 | ../../../img/skdag-dark.png -------------------------------------------------------------------------------- /doc/_static/img/stack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/stack.png -------------------------------------------------------------------------------- /doc/_static/js/copybutton.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | /* Add a [>>>] button on the top-right corner of code samples to hide 3 | * the >>> and ... prompts and the output and thus make the code 4 | * copyable. */ 5 | var div = $('.highlight-python .highlight,' + 6 | '.highlight-python3 .highlight,' + 7 | '.highlight-pycon .highlight,' + 8 | '.highlight-default .highlight') 9 | var pre = div.find('pre'); 10 | 11 | // get the styles from the current theme 12 | pre.parent().parent().css('position', 'relative'); 13 | var hide_text = 'Hide the prompts and output'; 14 | var show_text = 'Show the prompts and output'; 15 | var border_width = pre.css('border-top-width'); 16 | var border_style = pre.css('border-top-style'); 17 | var border_color = pre.css('border-top-color'); 18 | var button_styles = { 19 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 20 | 'border-color': border_color, 'border-style': border_style, 21 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 22 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 23 | 'border-radius': '0 3px 0 0' 24 | } 25 | 26 | // create and add the button to all the code blocks that contain >>> 27 | div.each(function(index) { 28 | var jthis = $(this); 29 | if (jthis.find('.gp').length > 0) { 30 | var button = $('<span class="copybutton">&gt;&gt;&gt;</span>'); 31 | button.css(button_styles) 32 | button.attr('title', hide_text); 33 | button.data('hidden', 'false'); 34 | jthis.prepend(button); 35 | } 36 | // tracebacks (.gt) contain bare text elements that need to be 37 | // wrapped in a span to work with .nextUntil() (see later) 38 | jthis.find('pre:has(.gt)').contents().filter(function() { 39 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 40 | }).wrap('<span>'); 41 | }); 42 | 43 | // define the behavior of the button when it's clicked 44 | $('.copybutton').click(function(e){ 45 | e.preventDefault(); 46 | var button = $(this); 47 | if (button.data('hidden') === 'false') { 48 | // hide the code output 49 | button.parent().find('.go, .gp, .gt').hide(); 50 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 51 | button.css('text-decoration', 'line-through'); 52 | button.attr('title', show_text); 53 | button.data('hidden', 'true'); 54 | } else { 55 | // show the code output 56 | button.parent().find('.go, .gp, .gt').show(); 57 |
button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 58 | button.css('text-decoration', 'none'); 59 | button.attr('title', hide_text); 60 | button.data('hidden', 'false'); 61 | } 62 | }); 63 | }); 64 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | ######### 2 | skdag API 3 | ######### 4 | 5 | See below for detailed descriptions of the ``skdag`` interface. 6 | 7 | .. currentmodule:: skdag 8 | 9 | Estimator 10 | ========= 11 | 12 | .. autosummary:: 13 | :toctree: generated/ 14 | :template: class.rst 15 | 16 | DAG 17 | 18 | Exceptions 19 | ========== 20 | 21 | .. autosummary:: 22 | :toctree: generated/ 23 | :template: class.rst 24 | 25 | exceptions.DAGError 26 | 27 | Utilities 28 | ========= 29 | 30 | .. autosummary:: 31 | :toctree: generated/ 32 | :template: class.rst 33 | 34 | DAGBuilder -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # project-template documentation build configuration file, created by 4 | # sphinx-quickstart on Mon Jan 18 14:44:12 2016. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | 18 | import sphinx_gallery 19 | import sphinx_rtd_theme 20 | 21 | # Add to sys.path the top-level directory where the package is located. 22 | sys.path.insert(0, os.path.abspath("..")) 23 | 24 | # If extensions (or modules to document with autodoc) are in another directory, 25 | # add these directories to sys.path here. If the directory is relative to the 26 | # documentation root, use os.path.abspath to make it absolute, like shown here. 27 | # sys.path.insert(0, os.path.abspath('.')) 28 | 29 | # -- General configuration ------------------------------------------------ 30 | 31 | # If your documentation needs a minimal Sphinx version, state it here. 32 | # needs_sphinx = '1.0' 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | "sphinx.ext.autodoc", 39 | "sphinx.ext.autosummary", 40 | "sphinx.ext.doctest", 41 | "sphinx.ext.intersphinx", 42 | "sphinx.ext.viewcode", 43 | "numpydoc", 44 | "sphinx_gallery.gen_gallery", 45 | ] 46 | 47 | # this is needed for some reason... 48 | # see https://github.com/numpy/numpydoc/issues/69 49 | numpydoc_show_class_members = False 50 | 51 | # pngmath / imgmath compatibility layer for different sphinx versions 52 | import sphinx 53 | from distutils.version import LooseVersion 54 | 55 | if LooseVersion(sphinx.__version__) < LooseVersion("1.4"): 56 | extensions.append("sphinx.ext.pngmath") 57 | else: 58 | extensions.append("sphinx.ext.imgmath") 59 | 60 | autodoc_default_flags = ["members", "inherited-members"] 61 | 62 | # Add any paths that contain templates here, relative to this directory. 63 | templates_path = ["_templates"] 64 | 65 | # generate autosummary even if no references 66 | autosummary_generate = True 67 | 68 | # The suffix of source filenames. 
69 | source_suffix = ".rst" 70 | 71 | # The encoding of source files. 72 | # source_encoding = 'utf-8-sig' 73 | 74 | # Generate the plots for the gallery 75 | plot_gallery = False 76 | 77 | # The master toctree document. 78 | master_doc = "index" 79 | 80 | # General information about the project. 81 | project = "skdag" 82 | copyright = "2022, big-o (github)" 83 | 84 | # The version info for the project you're documenting, acts as replacement for 85 | # |version| and |release|, also used in various other places throughout the 86 | # built documents. 87 | # 88 | # The short X.Y version. 89 | from skdag import __version__ 90 | 91 | version = __version__ 92 | # The full version, including alpha/beta/rc tags. 93 | release = __version__ 94 | 95 | # The language for content autogenerated by Sphinx. Refer to documentation 96 | # for a list of supported languages. 97 | # language = None 98 | 99 | # There are two options for replacing |today|: either, you set today to some 100 | # non-false value, then it is used: 101 | # today = '' 102 | # Else, today_fmt is used as the format for a strftime call. 103 | # today_fmt = '%B %d, %Y' 104 | 105 | # List of patterns, relative to source directory, that match files and 106 | # directories to ignore when looking for source files. 107 | exclude_patterns = ["_build", "_templates"] 108 | 109 | # The reST default role (used for this markup: `text`) to use for all 110 | # documents. 111 | # default_role = None 112 | 113 | # If true, '()' will be appended to :func: etc. cross-reference text. 114 | # add_function_parentheses = True 115 | 116 | # If true, the current module name will be prepended to all description 117 | # unit titles (such as .. function::). 118 | # add_module_names = True 119 | 120 | # If true, sectionauthor and moduleauthor directives will be shown in the 121 | # output. They are ignored by default. 122 | # show_authors = False 123 | 124 | # The name of the Pygments (syntax highlighting) style to use. 125 | pygments_style = "sphinx" 126 | 127 | # Custom style 128 | html_style = "css/project-template.css" 129 | 130 | # A list of ignored prefixes for module index sorting. 131 | # modindex_common_prefix = [] 132 | 133 | # If true, keep warnings as "system message" paragraphs in the built documents. 134 | # keep_warnings = False 135 | 136 | 137 | # -- Options for HTML output ---------------------------------------------- 138 | 139 | # The theme to use for HTML and HTML Help pages. See the documentation for 140 | # a list of builtin themes. 141 | html_theme = "sphinx_rtd_theme" 142 | 143 | # Theme options are theme-specific and customize the look and feel of a theme 144 | # further. For a list of options available for each theme, see the 145 | # documentation. 146 | html_theme_options = { 147 | "logo_only": True, 148 | } 149 | 150 | # Add any paths that contain custom themes here, relative to this directory. 151 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 152 | 153 | # The name for this set of Sphinx documents. If None, it defaults to 154 | # " v documentation". 155 | # html_title = None 156 | 157 | # A shorter title for the navigation bar. Default is the same as html_title. 158 | # html_short_title = None 159 | 160 | # The name of an image file (relative to this directory) to place at the top 161 | # of the sidebar. 162 | html_logo = "_static/img/skdag-dark.png" 163 | 164 | # The name of an image file (within the static path) to use as favicon of the 165 | # docs. 
This file should be a Windows icon file (.ico) being 16x16 or 32x32 166 | # pixels large. 167 | # html_favicon = None 168 | 169 | # Add any paths that contain custom static files (such as style sheets) here, 170 | # relative to this directory. They are copied after the builtin static files, 171 | # so a file named "default.css" will overwrite the builtin "default.css". 172 | html_static_path = ["_static"] 173 | 174 | # Add any extra paths that contain custom files (such as robots.txt or 175 | # .htaccess) here, relative to this directory. These files are copied 176 | # directly to the root of the documentation. 177 | # html_extra_path = [] 178 | 179 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 180 | # using the given strftime format. 181 | # html_last_updated_fmt = '%b %d, %Y' 182 | 183 | # If true, SmartyPants will be used to convert quotes and dashes to 184 | # typographically correct entities. 185 | # html_use_smartypants = True 186 | 187 | # Custom sidebar templates, maps document names to template names. 188 | # html_sidebars = {} 189 | 190 | # Additional templates that should be rendered to pages, maps page names to 191 | # template names. 192 | # html_additional_pages = {} 193 | 194 | # If false, no module index is generated. 195 | # html_domain_indices = True 196 | 197 | # If false, no index is generated. 198 | # html_use_index = True 199 | 200 | # If true, the index is split into individual pages for each letter. 201 | # html_split_index = False 202 | 203 | # If true, links to the reST sources are added to the pages. 204 | # html_show_sourcelink = True 205 | 206 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 207 | # html_show_sphinx = True 208 | 209 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 210 | # html_show_copyright = True 211 | 212 | # If true, an OpenSearch description file will be output, and all pages will 213 | # contain a tag referring to it. The value of this option must be the 214 | # base URL from which the finished HTML is served. 215 | # html_use_opensearch = '' 216 | 217 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 218 | # html_file_suffix = None 219 | 220 | # Output file base name for HTML help builder. 221 | htmlhelp_basename = "project-templatedoc" 222 | 223 | 224 | # -- Options for LaTeX output --------------------------------------------- 225 | 226 | latex_elements = { 227 | # The paper size ('letterpaper' or 'a4paper'). 228 | #'papersize': 'letterpaper', 229 | # The font size ('10pt', '11pt' or '12pt'). 230 | #'pointsize': '10pt', 231 | # Additional stuff for the LaTeX preamble. 232 | #'preamble': '', 233 | } 234 | 235 | # Grouping the document tree into LaTeX files. List of tuples 236 | # (source start file, target name, title, 237 | # author, documentclass [howto, manual, or own class]). 238 | latex_documents = [ 239 | ( 240 | "index", 241 | "project-template.tex", 242 | "project-template Documentation", 243 | "big-o", 244 | "manual", 245 | ), 246 | ] 247 | 248 | # The name of an image file (relative to this directory) to place at the top of 249 | # the title page. 250 | # latex_logo = None 251 | 252 | # For "manual" documents, if this is true, then toplevel headings are parts, 253 | # not chapters. 254 | # latex_use_parts = False 255 | 256 | # If true, show page references after internal links. 257 | # latex_show_pagerefs = False 258 | 259 | # If true, show URL addresses after external links. 
260 | # latex_show_urls = False 261 | 262 | # Documents to append as an appendix to all manuals. 263 | # latex_appendices = [] 264 | 265 | # If false, no module index is generated. 266 | # latex_domain_indices = True 267 | 268 | 269 | # -- Options for manual page output --------------------------------------- 270 | 271 | # One entry per manual page. List of tuples 272 | # (source start file, name, description, authors, manual section). 273 | man_pages = [ 274 | ( 275 | "index", 276 | "project-template", 277 | "project-template Documentation", 278 | ["big-o"], 279 | 1, 280 | ) 281 | ] 282 | 283 | # If true, show URL addresses after external links. 284 | # man_show_urls = False 285 | 286 | 287 | # -- Options for Texinfo output ------------------------------------------- 288 | 289 | # Grouping the document tree into Texinfo files. List of tuples 290 | # (source start file, target name, title, author, 291 | # dir menu entry, description, category) 292 | texinfo_documents = [ 293 | ( 294 | "index", 295 | "project-template", 296 | "project-template Documentation", 297 | "big-o", 298 | "project-template", 299 | "One line description of project.", 300 | "Miscellaneous", 301 | ), 302 | ] 303 | 304 | # Documents to append as an appendix to all manuals. 305 | # texinfo_appendices = [] 306 | 307 | # If false, no module index is generated. 308 | # texinfo_domain_indices = True 309 | 310 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 311 | # texinfo_show_urls = 'footnote' 312 | 313 | # If true, do not generate a @detailmenu in the "Top" node's menu. 314 | # texinfo_no_detailmenu = False 315 | 316 | 317 | # Example configuration for intersphinx: refer to the Python standard library. 318 | # intersphinx configuration 319 | intersphinx_mapping = { 320 | "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), 321 | "numpy": ("https://numpy.org/doc/stable", None), 322 | "scipy": ("https://docs.scipy.org/doc/scipy", None), 323 | "matplotlib": ("https://matplotlib.org/stable", None), 324 | "sklearn": ("https://scikit-learn.org/stable", None), 325 | } 326 | 327 | # sphinx-gallery configuration 328 | sphinx_gallery_conf = { 329 | "doc_module": "skdag", 330 | "backreferences_dir": os.path.join("generated"), 331 | "reference_url": {"skdag": None}, 332 | } 333 | 334 | 335 | def setup(app): 336 | # a copy button to copy snippet of code from the documentation 337 | app.add_js_file("js/copybutton.js") 338 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | skdag - scikit-learn workflow management 2 | ============================================ 3 | 4 | scikit-dag (``skdag``) is an open-source, MIT-licensed library that provides advanced 5 | workflow management for any machine learning operations that follow 6 | :mod:`sklearn` conventions. It does this by introducing Directed Acyclic 7 | Graphs (:class:`skdag.dag.DAG`) as a drop-in replacement for the traditional scikit-learn 8 | :class:`sklearn.pipeline.Pipeline`. This gives you a simple interface for a range of use 9 | cases including complex pre-processing, model stacking and benchmarking. 10 | 11 | ..
code-block:: python 12 | 13 | from skdag import DAGBuilder 14 | from sklearn.decomposition import PCA 15 | from sklearn.ensemble import RandomForestRegressor 16 | from sklearn.impute import SimpleImputer 17 | from sklearn.linear_model import LinearRegression 18 | from sklearn.neighbors import KNeighborsRegressor 19 | from sklearn.svm import SVR 20 | 21 | dag = ( 22 | DAGBuilder(infer_dataframe=True) 23 | .add_step("impute", SimpleImputer()) 24 | .add_step( 25 | "vitals", 26 | "passthrough", 27 | deps={"impute": ["age", "sex", "bmi", "bp"]}, 28 | ) 29 | .add_step( 30 | "blood", 31 | PCA(n_components=2, random_state=0), 32 | deps={"impute": ["s1", "s2", "s3", "s4", "s5", "s6"]}, 33 | ) 34 | .add_step( 35 | "rf", 36 | RandomForestRegressor(max_depth=5, random_state=0), 37 | deps=["blood", "vitals"], 38 | ) 39 | .add_step("svm", SVR(C=0.7), deps=["blood", "vitals"]) 40 | .add_step( 41 | "knn", 42 | KNeighborsRegressor(n_neighbors=5), 43 | deps=["blood", "vitals"], 44 | ) 45 | .add_step("meta", LinearRegression(), deps=["rf", "svm", "knn"]) 46 | .make_dag(n_jobs=2, verbose=True) 47 | ) 48 | 49 | dag.show(detailed=True) 50 | 51 | .. image:: _static/img/cover.png 52 | 53 | The above DAG imputes missing values, runs PCA on the columns relating to blood test 54 | results and leaves the other columns as they are. Then they get passed to three 55 | different regressors before being passed on to a final meta-estimator. Because DAGs 56 | (unlike pipelines) allow predictors in the middle of a workflow, you can use them to 57 | implement model stacking. We also chose to run the DAG steps in parallel wherever 58 | possible. 59 | 60 | After building our DAG, we can treat it as any other estimator: 61 | 62 | .. code-block:: python 63 | 64 | from sklearn import datasets 65 | from sklearn.model_selection import train_test_split 66 | 67 | X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) 68 | X_train, X_test, y_train, y_test = train_test_split( 69 | X, y, test_size=0.2, random_state=0 70 | ) 71 | 72 | dag.fit(X_train, y_train) 73 | dag.predict(X_test) 74 | 75 | Just like a pipeline, you can optimise it with a grid search, pickle it, etc. 76 | 77 | Note that this package does not deal with things like delayed dependencies and 78 | distributed architectures - consider an `established `_ 79 | `solution `_ for such use cases. ``skdag`` is just for building and 80 | executing local ensembles from estimators. 81 | 82 | :ref:`Read on <quickstart>` to learn more about ``skdag``... 83 | 84 | .. toctree:: 85 | :maxdepth: 2 86 | :hidden: 87 | :caption: Getting Started 88 | 89 | quick_start 90 | 91 | .. toctree:: 92 | :maxdepth: 2 93 | :hidden: 94 | :caption: Documentation 95 | 96 | user_guide 97 | 98 | .. toctree:: 99 | :maxdepth: 2 100 | :hidden: 101 | :caption: API 102 | 103 | api 104 | 105 | .. toctree:: 106 | :maxdepth: 2 107 | :hidden: 108 | :caption: Tutorial - Examples 109 | 110 | auto_examples/index 111 | 112 | `Getting started <quick_start.html>`_ 113 | ------------------------------------- 114 | 115 | A practical introduction to DAGs for scikit-learn. 116 | 117 | `User Guide <user_guide.html>`_ 118 | ------------------------------- 119 | 120 | Details of the full functionality provided by ``skdag``. 121 | 122 | `API Documentation <api.html>`_ 123 | ------------------------------- 124 | 125 | Detailed API documentation. 126 | 127 | `Examples <auto_examples/index.html>`_ 128 | -------------------------------------- 129 | 130 | Further examples that complement the `User Guide <user_guide.html>`_. 131 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% .
11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^<target^>` where ^<target^> is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo.
116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 
231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /doc/quick_start.rst: -------------------------------------------------------------------------------- 1 | .. _quickstart: 2 | 3 | ###################### 4 | Quick Start with skdag 5 | ###################### 6 | 7 | The following tutorial shows you how to write some simple directed acyclic graphs (DAGs) 8 | with ``skdag``. 9 | 10 | Installation 11 | ============ 12 | 13 | Installing skdag is simple: 14 | 15 | .. code-block:: bash 16 | 17 | pip install skdag 18 | 19 | Note that to visualise graphs you need to install the graphviz libraries too. See the 20 | `pygraphviz documentation <https://pygraphviz.github.io/>`_ for installation guidance. 21 | 22 | Creating a DAG 23 | ============== 24 | 25 | The simplest DAGs are just a chain of singular dependencies. These DAGs may be 26 | created from the :meth:`skdag.dag.DAG.from_pipeline` method in the same way as a 27 | ``Pipeline``: 28 | 29 | .. code-block:: python 30 | 31 | >>> from skdag import DAGBuilder 32 | >>> from sklearn.decomposition import PCA 33 | >>> from sklearn.impute import SimpleImputer 34 | >>> from sklearn.linear_model import LogisticRegression 35 | >>> dag = DAGBuilder().from_pipeline( 36 | ... steps=[ 37 | ... ("impute", SimpleImputer()), 38 | ... ("pca", PCA()), 39 | ... ("lr", LogisticRegression()) 40 | ... ] 41 | ... ).make_dag() 42 | >>> dag.show() 43 | o impute 44 | | 45 | o pca 46 | | 47 | o lr 48 | 49 | 50 | .. image:: _static/img/dag1.png 51 | 52 | For more complex DAGs, it is recommended to use a :class:`skdag.dag.DAGBuilder`, 53 | which allows you to define the graph by specifying the dependencies of each new 54 | estimator: 55 | 56 | .. code-block:: python 57 | 58 | >>> dag = ( 59 | ... DAGBuilder(infer_dataframe=True) 60 | ... .add_step("impute", SimpleImputer()) 61 | ... .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]}) 62 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)}) 63 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"]) 64 | ... .make_dag() 65 | ... ) 66 | >>> dag.show() 67 | o impute 68 | |\ 69 | o o blood,vitals 70 | |/ 71 | o lr 72 | 73 | 74 | .. image:: _static/img/dag2a.png 75 | 76 | In the above examples we pass the first four columns directly to a model, but 77 | the remaining columns have dimensionality reduction applied first before being 78 | passed to the same model as extra input columns. 79 | 80 | In this DAG, the ``deps`` option controls not only which estimators feed in to 81 | other estimators, but also which columns are used (and ignored) by each step. For more detail 82 | on how to control this behaviour, see the `User Guide <user_guide.html>`_. 83 | 84 | The DAG may now be used as an estimator in its own right: 85 | 86 | >>> from sklearn import datasets 87 | >>> X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) 88 | >>> type(dag.fit_predict(X, y)) 89 | <class 'pandas.core.series.Series'> 90 | 91 | In an extension to the scikit-learn estimator interface, DAGs also support multiple 92 | inputs and multiple outputs.
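Multiple inputs can be supplied by passing ``X`` (and, where needed, ``y``) as ``dict``-like objects keyed by step name. The sketch below is purely illustrative: the two root step names are hypothetical, and the final ``fit`` call is marked ``# doctest: +SKIP`` because the exact input-routing behaviour should be checked against the API docs:

>>> # assumption: each key of X is routed to the matching entry-point step
>>> multi = (
...     DAGBuilder(infer_dataframe=True)
...     .add_step("vitals_in", "passthrough")
...     .add_step("blood_in", PCA(n_components=2, random_state=0))
...     .add_step("lr2", LogisticRegression(random_state=0), deps=["vitals_in", "blood_in"])
...     .make_dag()
... )
>>> multi.fit({"vitals_in": X.iloc[:, :4], "blood_in": X.iloc[:, 4:]}, y)  # doctest: +SKIP
DAG(...

Multiple outputs work in a similar way.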
Let's say we want to compare two different classifiers: 93 | 94 | >>> from sklearn.ensemble import RandomForestClassifier 95 | >>> cal = DAGBuilder(infer_dataframe=True).from_pipeline( 96 | ... [("rf", RandomForestClassifier(random_state=0))] 97 | ... ).make_dag() 98 | >>> dag2 = dag.join(cal, edges=[("blood", "rf"), ("vitals", "rf")]) 99 | >>> dag2.show() 100 | o impute 101 | |\ 102 | o o blood,vitals 103 | |x| 104 | o o lr,rf 105 | 106 | 107 | .. image:: _static/img/dag3a.png 108 | 109 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are 110 | returned as a :class:`sklearn.utils.Bunch`: 111 | 112 | >>> y_pred = dag2.fit_predict(X, y) 113 | >>> type(y_pred.lr) 114 | <class 'pandas.core.series.Series'> 115 | >>> type(y_pred.rf) 116 | <class 'pandas.core.series.Series'> 117 | 118 | Similarly, multiple inputs are also acceptable and can be provided by 119 | specifying ``X`` and ``y`` as ``dict``-like objects, as sketched above. -------------------------------------------------------------------------------- /doc/user_guide.rst: -------------------------------------------------------------------------------- 1 | .. title:: User guide : contents 2 | 3 | .. _user_guide: 4 | 5 | ######################## 6 | Composing Estimator DAGs 7 | ######################## 8 | 9 | The following tutorial shows you how to write some simple directed acyclic graphs (DAGs) 10 | with ``skdag``. 11 | 12 | Creating your first DAG 13 | ======================= 14 | 15 | The simplest DAGs are just a chain of singular dependencies, which is equivalent to a 16 | scikit-learn :class:`~sklearn.pipeline.Pipeline`. These DAGs may be created from the 17 | :meth:`~skdag.dag.DAG.from_pipeline` method in the same way as a ``Pipeline``: 18 | 19 | .. code-block:: python 20 | 21 | >>> from skdag import DAGBuilder 22 | >>> from sklearn.decomposition import PCA 23 | >>> from sklearn.impute import SimpleImputer 24 | >>> from sklearn.linear_model import LogisticRegression 25 | >>> dag = DAGBuilder(infer_dataframe=True).from_pipeline( 26 | ... steps=[ 27 | ... ("impute", SimpleImputer()), 28 | ... ("pca", PCA()), 29 | ... ("lr", LogisticRegression()) 30 | ... ] 31 | ... ).make_dag() 32 | 33 | You may view a diagram of the DAG with the :meth:`~skdag.dag.DAG.show` method. In a 34 | notebook environment this will display an image, whereas in a terminal it will generate 35 | ASCII text: 36 | 37 | .. code-block:: python 38 | 39 | >>> dag.show() 40 | o impute 41 | | 42 | o pca 43 | | 44 | o lr 45 | 46 | .. image:: _static/img/dag1.png 47 | 48 | Note that we also provided an extra option, ``infer_dataframe``. This is entirely 49 | optional, but if set the DAG will ensure that dataframe inputs have column and index 50 | information preserved (or inferred), and the output of the pipeline will also be a 51 | dataframe. This is useful if you wish to filter down the inputs for one particular step 52 | to only include certain columns; something we shall see in action later. 53 | 54 | For more complex DAGs, it is recommended to use a :class:`skdag.dag.DAGBuilder`, 55 | which allows you to define the graph by specifying the dependencies of each new 56 | estimator: 57 | 58 | .. code-block:: python 59 | 60 | >>> from skdag import DAGBuilder 61 | >>> from sklearn.compose import make_column_selector 62 | >>> dag = ( 63 | ... DAGBuilder(infer_dataframe=True) 64 | ... .add_step("impute", SimpleImputer()) 65 | ... .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]}) 66 | ...
.add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": make_column_selector("s[0-9]+")}) 67 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"]) 68 | ... .make_dag() 69 | ... ) 70 | >>> dag.show() 71 | o impute 72 | |\ 73 | o o blood,vitals 74 | |/ 75 | o lr 76 | 77 | .. image:: _static/img/dag2.png 78 | 79 | In the above examples we pass the first four columns directly to a regressor, but 80 | the remaining columns have dimensionality reduction applied first before being 81 | passed to the same regressor. Note that we can define our graph edges in two 82 | different ways: as a dict (if we need to select only certain columns from the source 83 | node) or as a simple list (if we want to simply grab all columns from all input 84 | nodes). Columns may be specified as any kind of iterable (list, slice etc.) or a column 85 | selector function that conforms to :meth:`sklearn.compose.make_column_selector`. 86 | 87 | If you wish to specify string column names for dependencies, ensure you provide the 88 | ``infer_dataframe=True`` option when you create a dag. This will ensure that all 89 | estimator outputs are coerced into dataframes. Where possible column names will be 90 | inferred, otherwise the column names will just be the name of the estimator step with an 91 | appended index number. If you do not specify ``infer_dataframe=True``, the dag will 92 | leave the outputs unmodified, which in most cases will mean numpy arrays that only 93 | support numeric column indices. 94 | 95 | The DAG may now be used as an estimator in its own right: 96 | 97 | .. code-block:: python 98 | 99 | >>> from sklearn import datasets 100 | >>> X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) 101 | >>> y_hat = dag.fit_predict(X, y) 102 | >>> type(y_hat) 103 | 104 | 105 | In an extension to the scikit-learn estimator interface, DAGs also support multiple 106 | inputs and multiple outputs. Let's say we want to compare two different classifiers: 107 | 108 | .. code-block:: python 109 | 110 | >>> from sklearn.ensemble import RandomForestClassifier 111 | >>> rf = DAGBuilder().from_pipeline( 112 | ... [("rf", RandomForestClassifier(random_state=0))] 113 | ... ).make_dag() 114 | >>> dag2 = dag.join(rf, edges=[("blood", "rf"), ("vitals", "rf")]) 115 | >>> dag2.show() 116 | o impute 117 | |\ 118 | o o blood,vitals 119 | |x| 120 | o o lr,rf 121 | 122 | .. image:: _static/img/dag3.png 123 | 124 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are 125 | returned as a :class:`sklearn.utils.Bunch`: 126 | 127 | .. code-block:: python 128 | 129 | >>> y_pred = dag2.fit_predict(X, y) 130 | >>> type(y_pred.lr) 131 | 132 | >>> type(y_pred.rf) 133 | 134 | 135 | Note that we have different types of output here because ``LogisticRegression`` natively 136 | supports dataframe input whereas ``RandomForestClassifier`` does not. We could fix this 137 | by specifying ``infer_dataframe=True`` when we createed our ``rf`` DAG extension. 138 | 139 | Similarly, multiple inputs are also acceptable and inputs can be provided by 140 | specifying ``X`` and ``y`` as ``dict``-like objects. 141 | 142 | ######## 143 | Stacking 144 | ######## 145 | 146 | Unlike Pipelines, DAGs do not require only the final step to be an estimator. This 147 | allows DAGs to be used for model stacking. 148 | 149 | Stacking is an ensemble method, like bagging or boosting, that allows multiple models 150 | to be combined into a single, more robust estimator. 
94 | 95 | The DAG may now be used as an estimator in its own right: 96 | 97 | .. code-block:: python 98 | 99 | >>> from sklearn import datasets 100 | >>> X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) 101 | >>> y_hat = dag.fit_predict(X, y) 102 | >>> type(y_hat) 103 | <class 'pandas.core.series.Series'> 104 | 105 | In an extension to the scikit-learn estimator interface, DAGs also support multiple 106 | inputs and multiple outputs. Let's say we want to compare two different classifiers: 107 | 108 | .. code-block:: python 109 | 110 | >>> from sklearn.ensemble import RandomForestClassifier 111 | >>> rf = DAGBuilder().from_pipeline( 112 | ... [("rf", RandomForestClassifier(random_state=0))] 113 | ... ).make_dag() 114 | >>> dag2 = dag.join(rf, edges=[("blood", "rf"), ("vitals", "rf")]) 115 | >>> dag2.show() 116 | o impute 117 | |\ 118 | o o blood,vitals 119 | |x| 120 | o o lr,rf 121 | 122 | .. image:: _static/img/dag3.png 123 | 124 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are 125 | returned as a :class:`sklearn.utils.Bunch`: 126 | 127 | .. code-block:: python 128 | 129 | >>> y_pred = dag2.fit_predict(X, y) 130 | >>> type(y_pred.lr) 131 | <class 'pandas.core.series.Series'> 132 | >>> type(y_pred.rf) 133 | <class 'numpy.ndarray'> 134 | 135 | Note that we have different types of output here because ``LogisticRegression`` natively 136 | supports dataframe input whereas ``RandomForestClassifier`` does not. We could fix this 137 | by specifying ``infer_dataframe=True`` when we created our ``rf`` DAG extension. 138 | 139 | Similarly, multiple inputs are also acceptable and inputs can be provided by 140 | specifying ``X`` and ``y`` as ``dict``-like objects. 141 | 142 | ######## 143 | Stacking 144 | ######## 145 | 146 | Unlike Pipelines, DAGs do not require only the final step to be an estimator. This 147 | allows DAGs to be used for model stacking. 148 | 149 | Stacking is an ensemble method, like bagging or boosting, that allows multiple models 150 | to be combined into a single, more robust estimator. In stacking, predictions from 151 | multiple models are passed to a final `meta-estimator`: a simple model that combines the 152 | previous predictions into a final output. Like other ensemble methods, stacking can help 153 | to improve the performance and robustness of individual models. 154 | 155 | ``skdag`` implements stacking in a simple way. If an estimator without a ``transform()`` 156 | method is placed in a non-leaf step of the DAG, then the output of 157 | :meth:`predict_proba`, :meth:`decision_function` or :meth:`predict` will be passed to 158 | the next step(s). 159 | 160 | .. code-block:: python 161 | 162 | >>> from sklearn import datasets 163 | >>> from sklearn.linear_model import LinearRegression 164 | >>> from sklearn.model_selection import train_test_split 165 | >>> from sklearn.neighbors import KNeighborsRegressor 166 | >>> from sklearn.svm import SVR 167 | >>> X, y = datasets.load_diabetes(return_X_y=True) 168 | >>> X_train, X_test, y_train, y_test = train_test_split( 169 | ... X, y, test_size=0.2, random_state=0 170 | ... ) 171 | >>> knn = KNeighborsRegressor(3) 172 | >>> svr = SVR(C=1.0) 173 | >>> stack = ( 174 | ... DAGBuilder() 175 | ... .add_step("pass", "passthrough") 176 | ... .add_step("knn", knn, deps=["pass"]) 177 | ... .add_step("svr", svr, deps=["pass"]) 178 | ... .add_step("meta", LinearRegression(), deps=["knn", "svr"]) 179 | ... .make_dag() 180 | ... ) 181 | >>> stack.fit(X_train, y_train) 182 | DAG(... 183 | 184 | .. image:: _static/img/stack.png 185 | 186 | Note that the passthrough is not strictly necessary but it is convenient as it ensures 187 | the stack has a single entry point, which makes it simpler to use. 188 | 189 | The DAG infers that :meth:`predict` should be called for the two intermediate 190 | estimators. Our meta-estimator then simply takes the predictions of each of these 191 | regressors as its input features. 192 | 193 | As we can now see, the stacking ensemble method gives us a boost in performance: 194 | 195 | .. code-block:: python 196 | 197 | >>> stack.score(X_test, y_test) 198 | 0.145... 199 | >>> knn.score(X_test, y_test) 200 | 0.138... 201 | >>> svr.score(X_test, y_test) 202 | 0.128... 203 | 204 | Note that for binary classifiers you probably need to specify that only the positive 205 | class probability is used as input by the meta-estimator. The DAG will automatically 206 | infer that :meth:`predict_proba` should be called, but you will need to manually tell 207 | the DAG which column to take. To do this, you can simply specify your step dependencies 208 | as a dictionary of step name to column indices instead (for :class:`~sklearn.svm.SVC` this also means enabling ``probability=True`` so that :meth:`predict_proba` is available): 209 | 210 | .. code-block:: python 211 | 212 | >>> from sklearn.ensemble import RandomForestClassifier 213 | >>> from sklearn.svm import SVC 214 | >>> clf_stack = ( 215 | ... DAGBuilder(infer_dataframe=True) 216 | ... .add_step("pass", "passthrough") 217 | ... .add_step("rf", RandomForestClassifier(), deps=["pass"]) 218 | ... .add_step("svc", SVC(probability=True), deps=["pass"]) 219 | ... .add_step("meta", LinearRegression(), deps={"rf": 1, "svc": 1}) 220 | ... .make_dag() 221 | ... ) 222 | 223 | Stacking works best when a diverse range of algorithms are used to provide predictions, 224 | which are then fed into a very simple meta-estimator. To minimize overfitting, 225 | cross-validation should be considered when using stacking.
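Since the fitted DAG behaves like any single estimator, one straightforward way to apply cross-validation is to hand the whole stack to scikit-learn's standard model selection utilities. A minimal sketch, reusing ``stack``, ``X`` and ``y`` from the regression example above:

.. code-block:: python

    # Score the whole stack with 5-fold cross-validation. Each fold clones and
    # refits every estimator in the DAG, keeping train/validation data separate.
    from sklearn.model_selection import cross_val_score

    cv_scores = cross_val_score(stack, X, y, cv=5)
    print(cv_scores.mean(), cv_scores.std())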
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: project-template 2 | dependencies: 3 | - networkx>2.8 4 | - numpy 5 | - scipy 6 | - scikit-learn 7 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | General examples 4 | ================ 5 | 6 | Introductory examples. 7 | -------------------------------------------------------------------------------- /img/skdag-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-banner.png -------------------------------------------------------------------------------- /img/skdag-dark-fill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-dark-fill.png -------------------------------------------------------------------------------- /img/skdag-dark.kra: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-dark.kra -------------------------------------------------------------------------------- /img/skdag-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-dark.png -------------------------------------------------------------------------------- /img/skdag-fill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-fill.png -------------------------------------------------------------------------------- /img/skdag.kra: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag.kra -------------------------------------------------------------------------------- /img/skdag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | joblib 3 | networkx>=2.6 4 | numpy 5 | scipy 6 | scikit-learn 7 | stackeddag 8 | -------------------------------------------------------------------------------- /requirements_doc.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpydoc 3 | sphinx 4 | sphinx-gallery 5 | sphinx_rtd_theme 6 | -------------------------------------------------------------------------------- /requirements_full.txt: -------------------------------------------------------------------------------- 1 | pygraphviz 2 | -------------------------------------------------------------------------------- 
/requirements_test.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | pytest 3 | pytest-cov 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | 4 | [aliases] 5 | test = pytest 6 | 7 | [tool:pytest] 8 | doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS 9 | testpaths = . 10 | addopts = 11 | -s 12 | --doctest-modules 13 | --doctest-glob="*.rst" 14 | --cov=skdag 15 | --ignore setup.py 16 | --ignore doc/_build 17 | --ignore doc/_templates 18 | --no-cov-on-fail 19 | 20 | [coverage:run] 21 | branch = True 22 | source = skdag 23 | include = */skdag/* 24 | omit = 25 | */tests/* 26 | *_test.py 27 | test_*.py 28 | */setup.py 29 | 30 | [coverage:report] 31 | exclude_lines = 32 | pragma: no cover 33 | def __repr__ 34 | if self.debug: 35 | if settings.DEBUG 36 | raise AssertionError 37 | raise NotImplementedError 38 | if 0: 39 | if __name__ == .__main__.: 40 | if self.verbose: 41 | show_missing = True 42 | 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """A flexible alternative to scikit-learn Pipelines""" 3 | 4 | import codecs 5 | import os 6 | 7 | from setuptools import find_packages, setup 8 | 9 | 10 | def parse_requirements(filename): 11 | # Copy dependencies from requirements file 12 | with open(filename, encoding="utf-8") as f: 13 | requirements = [line.strip() for line in f.read().splitlines()] 14 | requirements = [ 15 | line.split("#")[0].strip() 16 | for line in requirements 17 | if not line.startswith("#") 18 | ] 19 | 20 | return requirements 21 | 22 | 23 | # get __version__ from _version.py 24 | ver_file = os.path.join("skdag", "_version.py") 25 | with open(ver_file) as f: 26 | exec(f.read()) 27 | 28 | DISTNAME = "skdag" 29 | DESCRIPTION = "A flexible alternative to scikit-learn Pipelines" 30 | 31 | with codecs.open("README.rst", encoding="utf-8") as f: 32 | LONG_DESCRIPTION = f.read() 33 | 34 | MAINTAINER = "big-o" 35 | MAINTAINER_EMAIL = "big-o@users.noreply.github.com" 36 | URL = "https://github.com/big-o/skdag" 37 | LICENSE = "new BSD" 38 | DOWNLOAD_URL = "https://github.com/scikit-learn-contrib/project-template" 39 | VERSION = __version__ 40 | INSTALL_REQUIRES = parse_requirements("requirements.txt") 41 | CLASSIFIERS = [ 42 | "Intended Audience :: Science/Research", 43 | "Intended Audience :: Developers", 44 | "License :: OSI Approved", 45 | "Programming Language :: Python", 46 | "Topic :: Software Development", 47 | "Topic :: Scientific/Engineering", 48 | "Operating System :: Microsoft :: Windows", 49 | "Operating System :: POSIX", 50 | "Operating System :: Unix", 51 | "Operating System :: MacOS", 52 | "Programming Language :: Python :: 3.7", 53 | "Programming Language :: Python :: 3.8", 54 | "Programming Language :: Python :: 3.9", 55 | ] 56 | EXTRAS_REQUIRE = { 57 | tgt: parse_requirements(f"requirements_{tgt}.txt") 58 | for tgt in ["test", "doc"] 59 | } 60 | 61 | setup( 62 | name=DISTNAME, 63 | maintainer=MAINTAINER, 64 | maintainer_email=MAINTAINER_EMAIL, 65 | description=DESCRIPTION, 66 | license=LICENSE, 67 | url=URL, 68 | version=VERSION, 69 | download_url=DOWNLOAD_URL, 70 | long_description=LONG_DESCRIPTION, 71 | zip_safe=False, # the package can run out of an .egg file 72 | 
classifiers=CLASSIFIERS, 73 | packages=find_packages(), 74 | install_requires=INSTALL_REQUIRES, 75 | extras_require=EXTRAS_REQUIRE, 76 | ) 77 | -------------------------------------------------------------------------------- /skdag/__init__.py: -------------------------------------------------------------------------------- 1 | from skdag.dag import * 2 | 3 | from skdag._version import __version__ 4 | 5 | __all__ = [ 6 | "DAG", 7 | "DAGBuilder", 8 | "DAGRenderer", 9 | "DAGStep", 10 | "__version__", 11 | ] 12 | -------------------------------------------------------------------------------- /skdag/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.7" 2 | -------------------------------------------------------------------------------- /skdag/dag/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`skdag.dag` module implements utilities to build a composite 3 | estimator, as a directed acyclic graph (DAG). A DAG may have one or more inputs (roots) 4 | and also one or more outputs (leaves), but may not contain any cyclic processing paths. 5 | """ 6 | # Author: big-o (github) 7 | # License: BSD 8 | 9 | from skdag.dag._dag import * 10 | from skdag.dag._builder import * 11 | from skdag.dag._render import * 12 | -------------------------------------------------------------------------------- /skdag/dag/_builder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | 3 | import networkx as nx 4 | from skdag.dag._dag import DAG, DAGStep 5 | from skdag.exceptions import DAGError 6 | 7 | __all__ = ["DAGBuilder"] 8 | 9 | 10 | class DAGBuilder: 11 | """ 12 | Helper utility for creating a :class:`skdag.DAG`. 13 | 14 | ``DAGBuilder`` allows a graph to be defined incrementally by specifying one node 15 | (step) at a time. Graph edges are defined by providing optional dependency lists 16 | that reference each step by name. Note that steps must be defined before they are 17 | used as dependencies. 18 | 19 | Parameters 20 | ---------- 21 | 22 | infer_dataframe : bool, default = False 23 | If True, assume ``dataframe_columns="infer"`` every time :meth:`.add_step` is 24 | called, if ``dataframe_columns`` is set to ``None``. This effectively makes the 25 | resulting DAG always try to coerce output into pandas DataFrames wherever 26 | possible. 27 | 28 | See Also 29 | -------- 30 | :class:`skdag.DAG` : The estimator DAG created by this utility. 31 | 32 | Examples 33 | -------- 34 | 35 | >>> from skdag import DAGBuilder 36 | >>> from sklearn.decomposition import PCA 37 | >>> from sklearn.impute import SimpleImputer 38 | >>> from sklearn.linear_model import LogisticRegression 39 | >>> dag = ( 40 | ... DAGBuilder() 41 | ... .add_step("impute", SimpleImputer()) 42 | ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)}) 43 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)}) 44 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"]) 45 | ... .make_dag() 46 | ... ) 47 | >>> print(dag.draw().strip()) 48 | o impute 49 | |\\ 50 | o o blood,vitals 51 | |/ 52 | o lr 53 | """ 54 | 55 | def __init__(self, infer_dataframe=False): 56 | self.graph = nx.DiGraph() 57 | self.infer_dataframe = infer_dataframe 58 | 59 | def from_pipeline(self, steps, **kwargs): 60 | """ 61 | Construct a DAG from a simple linear sequence of steps. 
The resulting DAG will 62 | be equivalent to a :class:`~sklearn.pipeline.Pipeline`. 63 | 64 | Parameters 65 | ---------- 66 | 67 | steps : sequence of (str, estimator) 68 | An ordered sequence of pipeline steps. A step is simply a pair of 69 | ``(name, estimator)``, just like a scikit-learn Pipeline. 70 | 71 | Note that ``infer_dataframe`` is not a parameter of this method: the setting 72 | passed to the :class:`DAGBuilder` constructor is applied to every step that 73 | ``from_pipeline`` creates, coercing step outputs into pandas DataFrames 74 | wherever possible if it is enabled. 75 | 76 | 77 | kwargs : kwargs 78 | Any other hyperparameters that are accepted by :class:`~skdag.dag.DAG`'s 79 | constructor. 80 | """ 81 | if hasattr(steps, "steps"): 82 | pipe = steps 83 | steps = pipe.steps 84 | if hasattr(pipe, "get_params"): 85 | kwargs = { 86 | **{ 87 | k: v 88 | for k, v in pipe.get_params().items() 89 | if k in ("memory", "verbose") 90 | }, 91 | **kwargs, 92 | } 93 | 94 | dfcols = "infer" if self.infer_dataframe else None 95 | 96 | for i in range(len(steps)): 97 | name, estimator = steps[i] 98 | self._validate_name(name) 99 | deps = {} 100 | if i > 0: 101 | dep = steps[i - 1][0] 102 | deps[dep] = None 103 | self._validate_deps(deps) 104 | 105 | step = DAGStep(name, estimator, deps, dataframe_columns=dfcols) 106 | self.graph.add_node(name, step=step) 107 | if deps: 108 | self.graph.add_edge(dep, name) 109 | 110 | self._validate_graph() 111 | 112 | return self 113 | 114 | def add_step(self, name, est, deps=None, dataframe_columns=None): 115 | self._validate_name(name) 116 | if isinstance(deps, Sequence): 117 | deps = {dep: None for dep in deps} 118 | 119 | if deps is not None: 120 | self._validate_deps(deps) 121 | else: 122 | deps = {} 123 | 124 | if dataframe_columns is None and self.infer_dataframe: 125 | dfcols = "infer" 126 | else: 127 | dfcols = dataframe_columns 128 | 129 | step = DAGStep(name, est, deps=deps, dataframe_columns=dfcols) 130 | self.graph.add_node(name, step=step) 131 | 132 | for dep in deps: 133 | # Since node is new, edges will never form a cycle. 134 | self.graph.add_edge(dep, name) 135 | 136 | self._validate_graph() 137 | 138 | return self 139 | 140 | def _validate_name(self, name): 141 | if not isinstance(name, str): 142 | raise KeyError(f"step names must be strings, got '{type(name)}'") 143 | 144 | if name in self.graph.nodes: 145 | raise KeyError(f"step with name '{name}' already exists") 146 | 147 | def _validate_deps(self, deps): 148 | if not isinstance(deps, Mapping) or not all( 149 | [isinstance(dep, str) for dep in deps] 150 | ): 151 | raise ValueError( 152 | "deps parameter must be a map of labels to indices, " 153 | f"got '{type(deps)}'." 154 | ) 155 | 156 | missing = [dep for dep in deps if dep not in self.graph] 157 | if missing: 158 | raise ValueError(f"unresolvable dependencies: {', '.join(sorted(missing))}") 159 | 160 | def _validate_graph(self): 161 | if not nx.algorithms.dag.is_directed_acyclic_graph(self.graph): 162 | raise DAGError("Workflow is not a DAG.") 163 | 164 | def make_dag(self, **kwargs): 165 | self._validate_graph() 166 | # Give the DAG a read-only view of the graph.
167 | return DAG(graph=self.graph.copy(as_view=True), **kwargs) 168 | 169 | def _repr_html_(self): 170 | return self.make_dag()._repr_html_() 171 | -------------------------------------------------------------------------------- /skdag/dag/_dag.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directed Acyclic Graphs (DAGs) may be used to construct complex workflows for 3 | scikit-learn estimators. As the name suggests, data may only flow in one 4 | direction and can't go back on itself to a previously run step. 5 | """ 6 | from collections import UserDict 7 | from copy import deepcopy 8 | from inspect import signature 9 | from itertools import chain 10 | from typing import Iterable 11 | 12 | import networkx as nx 13 | import numpy as np 14 | from joblib import Parallel, delayed 15 | from scipy.sparse import dok_matrix, issparse 16 | from skdag.dag._render import DAGRenderer 17 | from skdag.dag._utils import ( 18 | _format_output, 19 | _in_notebook, 20 | _is_pandas, 21 | _is_passthrough, 22 | _is_predictor, 23 | _is_transformer, 24 | _stack, 25 | ) 26 | from sklearn.base import clone 27 | from sklearn.exceptions import NotFittedError 28 | from sklearn.utils import Bunch, _safe_indexing, deprecated 29 | from sklearn.utils._tags import _safe_tags 30 | from sklearn.utils.metaestimators import _BaseComposition, available_if 31 | from sklearn.utils.validation import check_is_fitted, check_memory 32 | 33 | try: 34 | from sklearn.utils import _print_elapsed_time 35 | except ImportError: 36 | from sklearn.utils._user_interface import _print_elapsed_time 37 | 38 | __all__ = ["DAG", "DAGStep"] 39 | 40 | 41 | def _get_columns(X, dep, cols, is_root, dep_is_passthrough, axis=1): 42 | if callable(cols): 43 | # sklearn.compose.make_column_selector 44 | cols = cols(X) 45 | 46 | if not is_root and not dep_is_passthrough: 47 | # The DAG will prepend output columns with the step name, so add this in to any 48 | # dep columns if missing. This helps keep user-provided deps readable. 49 | if isinstance(cols, str): 50 | cols = cols if cols.startswith(f"{dep}__") else f"{dep}__{cols}" 51 | elif isinstance(cols, Iterable): 52 | orig = cols 53 | cols = [] 54 | for col in orig: 55 | if isinstance(col, str): 56 | cols.append(col if col.startswith(f"{dep}__") else f"{dep}__{col}") 57 | else: 58 | cols.append(col) 59 | 60 | return _safe_indexing(X, cols, axis=axis) 61 | 62 | 63 | def _stack_inputs(dag, X, node): 64 | # For root nodes, the dependency is just the node name itself. 65 | deps = {node.name: None} if node.is_root else node.deps 66 | 67 | cols = [ 68 | _get_columns( 69 | X[dep], 70 | dep, 71 | cols, 72 | node.is_root, 73 | _is_passthrough(dag.graph_.nodes[dep]["step"].estimator), 74 | axis=1, 75 | ) 76 | for dep, cols in deps.items() 77 | ] 78 | 79 | to_stack = [ 80 | # If we sliced a single column from an input, reshape it to a 2d array. 81 | col.reshape(-1, 1) 82 | if col is not None and deps[dep] is not None and col.ndim < 2 83 | else col 84 | for col, dep in zip(cols, deps) 85 | ] 86 | 87 | X_stacked = _stack(to_stack, axis=node.axis) 88 | 89 | return X_stacked 90 | 91 | 92 | def _leaf_estimators_have(attr, how="all"): 93 | """Check that leaves have `attr`. 
94 | Used together with `available_if` in `DAG`.""" 95 | 96 | def check_leaves(self): 97 | # raises `AttributeError` with all details if `attr` does not exist 98 | failed = [] 99 | for leaf in self.leaves_: 100 | try: 101 | _is_passthrough(leaf.estimator) or getattr(leaf.estimator, attr) 102 | except AttributeError: 103 | failed.append(leaf.estimator) 104 | 105 | if (how == "all" and failed) or ( 106 | how == "any" and len(failed) == len(self.leaves_)  # no leaf has attr 107 | ): 108 | raise AttributeError( 109 | f"{', '.join([repr(type(est)) for est in failed])} " 110 | f"object(s) has no attribute '{attr}'" 111 | ) 112 | return True 113 | 114 | return check_leaves 115 | 116 | 117 | def _transform_one(transformer, X, weight, allow_predictor=True, **fit_params): 118 | if _is_passthrough(transformer): 119 | res = X 120 | elif allow_predictor and not hasattr(transformer, "transform"): 121 | for fn in ["predict_proba", "decision_function", "predict"]: 122 | if hasattr(transformer, fn): 123 | res = getattr(transformer, fn)(X) 124 | if res.ndim < 2: 125 | res = res.reshape(-1, 1) 126 | break 127 | else: 128 | raise AttributeError( 129 | f"'{type(transformer).__name__}' object has no attribute 'transform'" 130 | ) 131 | else: 132 | res = transformer.transform(X) 133 | # if we have a weight for this transformer, multiply output 134 | if weight is not None: 135 | res = res * weight 136 | 137 | return res 138 | 139 | 140 | def _fit_transform_one( 141 | transformer, 142 | X, 143 | y, 144 | weight, 145 | message_clsname="", 146 | message=None, 147 | allow_predictor=True, 148 | **fit_params, 149 | ): 150 | """ 151 | Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned 152 | with the fitted transformer. If ``weight`` is not ``None``, the result will 153 | be multiplied by ``weight``. 154 | """ 155 | with _print_elapsed_time(message_clsname, message): 156 | failed = False 157 | if _is_passthrough(transformer): 158 | res = X 159 | elif hasattr(transformer, "fit_transform"): 160 | res = transformer.fit_transform(X, y, **fit_params) 161 | elif hasattr(transformer, "transform"): 162 | res = transformer.fit(X, y, **fit_params).transform(X) 163 | elif allow_predictor: 164 | for fn in ["predict_proba", "decision_function", "predict"]: 165 | if hasattr(transformer, fn): 166 | res = getattr(transformer.fit(X, y, **fit_params), fn)(X) 167 | if res.ndim < 2: 168 | res = res.reshape(-1, 1) 169 | break 170 | else: 171 | failed = True 172 | res = None 173 | 174 | if res is not None and res.ndim < 2: 175 | res = res.reshape(-1, 1) 176 | else: 177 | failed = True 178 | 179 | if failed: 180 | raise AttributeError( 181 | f"'{type(transformer).__name__}' object has no attribute 'transform'" 182 | ) 183 | 184 | if weight is not None: 185 | res = res * weight 186 | 187 | return res, transformer 188 | 189 | 190 | def _parallel_fit(dag, step, Xin, Xs, y, fit_transform_fn, memory, **fit_params): 191 | transformer = step.estimator 192 | 193 | if step.deps: 194 | X = _stack_inputs(dag, Xs, step) 195 | else: 196 | # For root nodes, the destination rather than the source is 197 | # specified.
198 | # X = Xin[step.name] 199 | X = _stack_inputs(dag, Xin, step) 200 | 201 | clsname = type(dag).__name__ 202 | with _print_elapsed_time(clsname, dag._log_message(step)): 203 | if transformer is None or transformer == "passthrough": 204 | Xt, fitted_transformer = X, transformer 205 | else: 206 | if hasattr(memory, "location") and memory.location is None: 207 | # we do not clone when caching is disabled to 208 | # preserve backward compatibility 209 | cloned_transformer = transformer 210 | else: 211 | cloned_transformer = clone(transformer) 212 | 213 | # Fit or load from cache the current transformer 214 | Xt, fitted_transformer = fit_transform_fn( 215 | cloned_transformer, 216 | X, 217 | y, 218 | None, 219 | message_clsname=clsname, 220 | message=dag._log_message(step), 221 | **fit_params, 222 | ) 223 | 224 | Xt = _format_output(Xt, X, step) 225 | 226 | return Xt, fitted_transformer 227 | 228 | 229 | def _parallel_transform(dag, step, Xin, Xs, transform_fn, **fn_params): 230 | transformer = step.estimator 231 | if step.deps: 232 | X = _stack_inputs(dag, Xs, step) 233 | else: 234 | # For root nodes, the destination rather than the source is 235 | # specified. 236 | X = _stack_inputs(dag, Xin, step) 237 | # X = Xin[step.name] 238 | 239 | clsname = type(dag).__name__ 240 | with _print_elapsed_time(clsname, dag._log_message(step)): 241 | if transformer is None or transformer == "passthrough": 242 | Xt = X 243 | else: 244 | # Fit or load from cache the current transformer 245 | Xt = transform_fn( 246 | transformer, 247 | X, 248 | None, 249 | message_clsname=clsname, 250 | message=dag._log_message(step), 251 | **fn_params, 252 | ) 253 | 254 | Xt = _format_output(Xt, X, step) 255 | 256 | return Xt 257 | 258 | 259 | def _parallel_fit_leaf(dag, leaf, Xts, y, **fit_params): 260 | with _print_elapsed_time(type(dag).__name__, dag._log_message(leaf)): 261 | if leaf.estimator == "passthrough": 262 | fitted_estimator = leaf.estimator 263 | else: 264 | Xt = _stack_inputs(dag, Xts, leaf) 265 | fitted_estimator = leaf.estimator.fit(Xt, y, **fit_params) 266 | 267 | return fitted_estimator 268 | 269 | 270 | def _parallel_execute( 271 | dag, leaf, fn, Xts, y=None, fit_first=False, fit_params=None, fn_params=None 272 | ): 273 | with _print_elapsed_time("DAG", dag._log_message(leaf)): 274 | Xt = _stack_inputs(dag, Xts, leaf) 275 | fit_params = fit_params or {} 276 | fn_params = fn_params or {} 277 | if leaf.estimator == "passthrough": 278 | Xout = Xt 279 | elif fit_first and hasattr(leaf.estimator, f"fit_{fn}"): 280 | Xout = getattr(leaf.estimator, f"fit_{fn}")(Xt, y, **fit_params) 281 | else: 282 | if fit_first: 283 | leaf.estimator.fit(Xt, y, **fit_params) 284 | 285 | est_fn = getattr(leaf.estimator, fn) 286 | if "y" in signature(est_fn).parameters: 287 | Xout = est_fn(Xt, y=y, **fn_params) 288 | else: 289 | Xout = est_fn(Xt, **fn_params) 290 | 291 | Xout = _format_output(Xout, Xt, leaf) 292 | 293 | fitted_estimator = leaf.estimator 294 | 295 | return Xout, fitted_estimator 296 | 297 | 298 | class DAGStep: 299 | """ 300 | A single estimator step in a DAG. 301 | 302 | Parameters 303 | ---------- 304 | name : str 305 | The reference name for this step. 306 | estimator : estimator-like 307 | The estimator (transformer or predictor) that will be executed by this step. 308 | deps : dict 309 | A map of dependency names to columns. If columns is ``None``, then all input 310 | columns will be selected. 
311 | dataframe_columns : list of str or "infer" (optional) 312 | Either a hard-coded list of column names to apply to any output data, or the 313 | string "infer", which means the column outputs will be assumed to match the 314 | column inputs if the output is 2d and not already a dataframe, the estimator is 315 | a transformer, and the final axis dimensions match the inputs. Otherwise the 316 | column names will be assumed to be the step name + index if the output is not 317 | already a dataframe. If set to ``None`` or inference is not possible, the 318 | outputs will be left unmodified. 319 | axis : int, default = 1 320 | The strategy for merging inputs if there is more than one upstream dependency. 321 | ``axis=0`` will assume all inputs have the same features and stack the rows 322 | together; ``axis=1`` will assume each input provides different features for the 323 | same samples. 324 | """ 325 | 326 | def __init__(self, name, estimator, deps, dataframe_columns, axis=1): 327 | self.name = name 328 | self.estimator = estimator 329 | self.deps = deps 330 | self.dataframe_columns = dataframe_columns 331 | self.axis = axis 332 | self.index = None 333 | self.is_root = False 334 | self.is_leaf = False 335 | self.is_fitted = False 336 | 337 | def __repr__(self): 338 | return f"{type(self).__name__}({repr(self.name)}, {repr(self.estimator)})" 339 | 340 | 341 | class DAG(_BaseComposition): 342 | """ 343 | A Directed Acyclic Graph (DAG) of estimators that itself implements the estimator 344 | interface. 345 | 346 | A DAG may consist of a simple chain of estimators (being exactly equivalent to a 347 | :class:`sklearn.pipeline.Pipeline`) or a more complex path of dependencies. But as the 348 | name suggests, it may not contain any cyclic dependencies and data may only flow 349 | from one or more start points (roots) to one or more endpoints (leaves). 350 | 351 | Parameters 352 | ---------- 353 | 354 | graph : :class:`networkx.DiGraph` 355 | A directed graph with string node IDs indicating the step name. Each node must 356 | have a ``step`` attribute, which contains a :class:`skdag.dag.DAGStep`. 357 | 358 | memory : str or object with the joblib.Memory interface, default=None 359 | Used to cache the fitted transformers of the DAG. By default, no caching is 360 | performed. If a string is given, it is the path to the caching directory. 361 | Enabling caching triggers a clone of the transformers before fitting. Therefore, 362 | the transformer instance given to the DAG cannot be inspected directly. Use the 363 | attribute ``named_steps`` or ``steps`` to inspect estimators within the 364 | pipeline. Caching the transformers is advantageous when fitting is time 365 | consuming. 366 | 367 | n_jobs : int, default=None 368 | Number of jobs to run in parallel. ``None`` means 1 unless in a 369 | :obj:`joblib.parallel_backend` context. 370 | 371 | verbose : bool, default=False 372 | If True, the time elapsed while fitting each step will be printed as it is 373 | completed. 374 | 375 | Attributes 376 | ---------- 377 | 378 | graph_ : :class:`networkx.DiGraph` 379 | A read-only view of the workflow. 380 | 381 | classes_ : ndarray of shape (n_classes,) 382 | The class labels. Only exists if the last step of the pipeline is a 383 | classifier. 384 | 385 | n_features_in_ : int 386 | Number of features seen during :term:`fit`. Only defined if all of the 387 | underlying root estimators in `graph_` expose such an attribute when fit.
388 | 389 | feature_names_in_ : ndarray of shape (`n_features_in_`,) 390 | Names of features seen during :term:`fit`. Only defined if the underlying 391 | estimators expose such an attribute when fit. 392 | 393 | See Also 394 | -------- 395 | :class:`skdag.DAGBuilder` : Convenience utility for simplified DAG construction. 396 | 397 | Examples 398 | -------- 399 | 400 | The simplest DAGs are just a chain of singular dependencies. These DAGs may be 401 | created from the :meth:`skdag.dag.DAG.from_pipeline` method in the same way as a 402 | Pipeline: 403 | 404 | >>> from sklearn.decomposition import PCA 405 | >>> from sklearn.impute import SimpleImputer 406 | >>> from sklearn.linear_model import LogisticRegression 407 | >>> dag = DAG.from_pipeline( 408 | ... steps=[ 409 | ... ("impute", SimpleImputer()), 410 | ... ("pca", PCA()), 411 | ... ("lr", LogisticRegression()) 412 | ... ] 413 | ... ) 414 | >>> print(dag.draw().strip()) 415 | o impute 416 | | 417 | o pca 418 | | 419 | o lr 420 | 421 | For more complex DAGs, it is recommended to use a :class:`skdag.dag.DAGBuilder`, 422 | which allows you to define the graph by specifying the dependencies of each new 423 | estimator: 424 | 425 | >>> from skdag import DAGBuilder 426 | >>> dag = ( 427 | ... DAGBuilder() 428 | ... .add_step("impute", SimpleImputer()) 429 | ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)}) 430 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)}) 431 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"]) 432 | ... .make_dag() 433 | ... ) 434 | >>> print(dag.draw().strip()) 435 | o impute 436 | |\\ 437 | o o blood,vitals 438 | |/ 439 | o lr 440 | 441 | In the above examples we pass the first four columns directly to a classifier, but 442 | the remaining columns have dimensionality reduction applied first before being 443 | passed to the same classifier. Note that we can define our graph edges in two 444 | different ways: as a dict (if we need to select only certain columns from the source 445 | node) or as a simple list (if we want to simply grab all columns from all input 446 | nodes). 447 | 448 | The DAG may now be used as an estimator in its own right: 449 | 450 | >>> from sklearn import datasets 451 | >>> X, y = datasets.load_diabetes(return_X_y=True) 452 | >>> dag.fit_predict(X, y) 453 | array([... 454 | 455 | In an extension to the scikit-learn estimator interface, DAGs also support multiple 456 | inputs and multiple outputs. Let's say we want to compare two different classifiers: 457 | 458 | >>> from sklearn.ensemble import RandomForestClassifier 459 | >>> rf = DAG.from_pipeline( 460 | ... [("rf", RandomForestClassifier(random_state=0))] 461 | ... ) 462 | >>> dag2 = dag.join(rf, edges=[("blood", "rf"), ("vitals", "rf")]) 463 | >>> print(dag2.draw().strip()) 464 | o impute 465 | |\\ 466 | o o blood,vitals 467 | |x| 468 | o o lr,rf 469 | 470 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are 471 | returned as a :class:`sklearn.utils.Bunch`: 472 | 473 | >>> y_pred = dag2.fit_predict(X, y) 474 | >>> y_pred.lr 475 | array([... 476 | >>> y_pred.rf 477 | array([... 478 | 479 | Similarly, multiple inputs are also acceptable and inputs can be provided by 480 | specifying ``X`` and ``y`` as a ``dict``-like object.
481 | """ 482 | 483 | # BaseEstimator interface 484 | _required_parameters = ["graph"] 485 | 486 | @classmethod 487 | @deprecated( 488 | "DAG.from_pipeline is deprecated in 0.0.3 and will be removed in a future " 489 | "release. Please use DAGBuilder.from_pipeline instead." 490 | ) 491 | def from_pipeline(cls, steps, **kwargs): 492 | from skdag.dag._builder import DAGBuilder 493 | 494 | return DAGBuilder().from_pipeline(steps, **kwargs).make_dag() 495 | 496 | def __init__(self, graph, *, memory=None, n_jobs=None, verbose=False): 497 | self.graph = graph 498 | self.memory = memory 499 | self.verbose = verbose 500 | self.n_jobs = n_jobs 501 | 502 | def get_params(self, deep=True): 503 | """ 504 | Get parameters for this metaestimator. 505 | 506 | Returns the parameters given in the constructor as well as the 507 | estimators contained within the `steps_` of the `DAG`. 508 | 509 | Parameters 510 | ---------- 511 | deep : bool, default=True 512 | If True, will return the parameters for this estimator and 513 | contained subobjects that are estimators. 514 | 515 | Returns 516 | ------- 517 | params : mapping of string to any 518 | Parameter names mapped to their values. 519 | """ 520 | return self._get_params("steps_", deep=deep) 521 | 522 | def set_params(self, **params): 523 | """ 524 | Set the parameters of this metaestimator. 525 | 526 | Valid parameter keys can be listed with ``get_params()``. Note that 527 | you can directly set the parameters of the estimators contained in 528 | `steps_`. 529 | 530 | Parameters 531 | ---------- 532 | **params : dict 533 | Parameters of this metaestimator or parameters of estimators contained 534 | in `steps`. Parameters of the steps may be set using its name and 535 | the parameter name separated by a '__'. 536 | 537 | Returns 538 | ------- 539 | self : object 540 | DAG class instance. 541 | """ 542 | step_names = set(self.step_names) 543 | for param in list(params.keys()): 544 | if "__" not in param and param in step_names: 545 | self.graph_.nodes[param]["step"].estimator = params.pop(param) 546 | 547 | super().set_params(**params) 548 | return self 549 | 550 | def _log_message(self, step): 551 | if not self.verbose: 552 | return None 553 | 554 | return f"(step {step.name}: {step.index} of {len(self.graph_)}) Processing {step.name}" 555 | 556 | def _iter(self, with_leaves=True, filter_passthrough=True): 557 | """ 558 | Generate stage lists from self.graph_. 559 | When filter_passthrough is True, 'passthrough' and None transformers 560 | are filtered out. 561 | """ 562 | for stage in nx.topological_generations(self.graph_): 563 | stage = [self.graph_.nodes[step]["step"] for step in stage] 564 | if not with_leaves: 565 | stage = [step for step in stage if not step.is_leaf] 566 | 567 | if filter_passthrough: 568 | stage = [ 569 | step 570 | for step in stage 571 | if step.estimator is not None and step.estimator != "passthough" 572 | ] 573 | 574 | if len(stage) == 0: 575 | continue 576 | 577 | yield stage 578 | 579 | def __len__(self): 580 | """ 581 | Returns the size of the DAG 582 | """ 583 | return len(self.graph_) 584 | 585 | def __getitem__(self, name): 586 | """ 587 | Retrieve a named estimator. 
588 | """ 589 | return self.graph_.nodes[name]["step"].estimator 590 | 591 | def _fit(self, X, y=None, **fit_params_steps): 592 | # Setup the memory 593 | memory = check_memory(self.memory) 594 | 595 | fit_transform_one_cached = memory.cache(_fit_transform_one) 596 | 597 | root_names = set([root.name for root in self.roots_]) 598 | Xin = self._resolve_inputs(X) 599 | Xs = {} 600 | with Parallel(n_jobs=self.n_jobs) as parallel: 601 | for stage in self._iter(with_leaves=False, filter_passthrough=False): 602 | stage_names = [step.name for step in stage] 603 | outputs, fitted_transformers = zip( 604 | *parallel( 605 | delayed(_parallel_fit)( 606 | self, 607 | step, 608 | Xin, 609 | Xs, 610 | y, 611 | fit_transform_one_cached, 612 | memory, 613 | **fit_params_steps[step.name], 614 | ) 615 | for step in stage 616 | ) 617 | ) 618 | 619 | for step, fitted_transformer in zip(stage, fitted_transformers): 620 | # Replace the transformer of the step with the fitted 621 | # transformer. This is necessary when loading the transformer 622 | # from the cache. 623 | step.estimator = fitted_transformer 624 | step.is_fitted = True 625 | 626 | Xs.update(dict(zip(stage_names, outputs))) 627 | 628 | # If all of a dep's dependents are now complete, we can free up some 629 | # memory. 630 | root_names = root_names - set(stage_names) 631 | for dep in {dep for step in stage for dep in step.deps}: 632 | dependents = self.graph_.successors(dep) 633 | if all(d in Xs and d not in root_names for d in dependents): 634 | del Xs[dep] 635 | 636 | # If a root node is also a leaf, it hasn't been fit yet and we need to pass on 637 | # its input for later. 638 | Xs.update({name: Xin[name] for name in root_names}) 639 | return Xs 640 | 641 | def _transform(self, X, **fn_params_steps): 642 | # Setup the memory 643 | memory = check_memory(self.memory) 644 | 645 | transform_one_cached = memory.cache(_transform_one) 646 | 647 | root_names = set([root.name for root in self.roots_]) 648 | Xin = self._resolve_inputs(X) 649 | Xs = {} 650 | with Parallel(n_jobs=self.n_jobs) as parallel: 651 | for stage in self._iter(with_leaves=False, filter_passthrough=False): 652 | stage_names = [step.name for step in stage] 653 | outputs = parallel( 654 | delayed(_parallel_transform)( 655 | self, 656 | step, 657 | Xin, 658 | Xs, 659 | transform_one_cached, 660 | **fn_params_steps[step.name], 661 | ) 662 | for step in stage 663 | ) 664 | 665 | Xs.update(dict(zip(stage_names, outputs))) 666 | 667 | # If all of a dep's dependents are now complete, we can free up some 668 | # memory. 669 | root_names = root_names - set(stage_names) 670 | for dep in {dep for step in stage for dep in step.deps}: 671 | dependents = self.graph_.successors(dep) 672 | if all(d in Xs and d not in root_names for d in dependents): 673 | del Xs[dep] 674 | 675 | # If a root node is also a leaf, it hasn't been fit yet and we need to pass on 676 | # its input for later. 677 | Xs.update({name: Xin[name] for name in root_names}) 678 | return Xs 679 | 680 | def _resolve_inputs(self, X): 681 | if isinstance(X, (dict, Bunch, UserDict)) and not isinstance(X, dok_matrix): 682 | inputs = sorted(X.keys()) 683 | if inputs != sorted(root.name for root in self.roots_): 684 | raise ValueError( 685 | "Input dicts must contain one key per entry node. " 686 | f"Entry nodes are {self.roots_}, got {inputs}." 687 | ) 688 | else: 689 | if len(self.roots_) != 1: 690 | raise ValueError( 691 | "Must provide a dictionary of inputs for a DAG with multiple entry " 692 | "points." 
693 | ) 694 | X = {self.roots_[0].name: X} 695 | 696 | X = { 697 | step: x if issparse(x) or _is_pandas(x) else np.asarray(x) 698 | for step, x in X.items() 699 | } 700 | 701 | return X 702 | 703 | def _match_input_format(self, Xin, Xout): 704 | if len(self.leaves_) == 1 and ( 705 | not isinstance(Xin, (dict, Bunch, UserDict)) or isinstance(Xin, dok_matrix) 706 | ): 707 | return Xout[self.leaves_[0].name] 708 | return Bunch(**Xout) 709 | 710 | def fit(self, X, y=None, **fit_params): 711 | """ 712 | Fit the model. 713 | 714 | Fit all the transformers one after the other and transform the 715 | data. Finally, fit the transformed data using the final estimators. 716 | 717 | Parameters 718 | ---------- 719 | X : iterable 720 | Training data. Must fulfill input requirements of first step of the 721 | DAG. 722 | y : iterable, default=None 723 | Training targets. Must fulfill label requirements for all steps of 724 | the DAG. 725 | **fit_params : dict of string -> object 726 | Parameters passed to the ``fit`` method of each step, where 727 | each parameter name is prefixed such that parameter ``p`` for step 728 | ``s`` has key ``s__p``. 729 | 730 | Returns 731 | ------- 732 | self : object 733 | DAG fitted steps. 734 | """ 735 | self._validate_graph() 736 | fit_params_steps = self._check_fit_params(**fit_params) 737 | Xts = self._fit(X, y, **fit_params_steps) 738 | fitted_estimators = Parallel(n_jobs=self.n_jobs)( 739 | [ 740 | delayed(_parallel_fit_leaf)( 741 | self, leaf, Xts, y, **fit_params_steps[leaf.name] 742 | ) 743 | for leaf in self.leaves_ 744 | ] 745 | ) 746 | for est, leaf in zip(fitted_estimators, self.leaves_): 747 | leaf.estimator = est 748 | leaf.is_fitted = True 749 | 750 | # If we have a single root, mirror certain attributes in the DAG. 
751 | if len(self.roots_) == 1: 752 | root = self.roots_[0].estimator 753 | for attr in ["n_features_in_", "feature_names_in_"]: 754 | if hasattr(root, attr): 755 | setattr(self, attr, getattr(root, attr)) 756 | 757 | return self 758 | 759 | def _fit_execute(self, fn, X, y=None, **fit_params): 760 | self._validate_graph() 761 | fit_params_steps = self._check_fit_params(**fit_params) 762 | Xts = self._fit(X, y, **fit_params_steps) 763 | Xout = {} 764 | 765 | leaf_names = [leaf.name for leaf in self.leaves_] 766 | outputs, fitted_estimators = zip( 767 | *Parallel(n_jobs=self.n_jobs)( 768 | delayed(_parallel_execute)( 769 | self, 770 | leaf, 771 | fn, 772 | Xts, 773 | y, 774 | fit_first=True, 775 | fit_params=fit_params_steps[leaf.name], 776 | ) 777 | for leaf in self.leaves_ 778 | ) 779 | ) 780 | 781 | Xout = dict(zip(leaf_names, outputs)) 782 | for step, fitted_estimator in zip(self.leaves_, fitted_estimators): 783 | step.estimator = fitted_estimator 784 | step.is_fitted = True 785 | 786 | return self._match_input_format(X, Xout) 787 | 788 | def _execute(self, fn, X, y=None, **fn_params): 789 | Xout = {} 790 | fn_params_steps = self._check_fn_params(**fn_params) 791 | Xts = self._transform(X, **fn_params_steps) 792 | 793 | leaf_names = [leaf.name for leaf in self.leaves_] 794 | outputs, _ = zip( 795 | *Parallel(n_jobs=self.n_jobs)( 796 | delayed(_parallel_execute)( 797 | self, 798 | leaf, 799 | fn, 800 | Xts, 801 | y, 802 | fit_first=False, 803 | fn_params=fn_params_steps[leaf.name], 804 | ) 805 | for leaf in self.leaves_ 806 | ) 807 | ) 808 | 809 | Xout = dict(zip(leaf_names, outputs)) 810 | 811 | return self._match_input_format(X, Xout) 812 | 813 | @available_if(_leaf_estimators_have("transform")) 814 | def fit_transform(self, X, y=None, **fit_params): 815 | """ 816 | Fit the model and transform with the final estimator. 817 | 818 | Fits all the transformers one after the other and transform the 819 | data. Then uses `fit_transform` on transformed data with the final 820 | estimator. 821 | 822 | Parameters 823 | ---------- 824 | X : iterable 825 | Training data. Must fulfill input requirements of first step of the 826 | DAG. 827 | y : iterable, default=None 828 | Training targets. Must fulfill label requirements for all steps of 829 | the DAG. 830 | **fit_params : dict of string -> object 831 | Parameters passed to the ``fit`` method of each step, where 832 | each parameter name is prefixed such that parameter ``p`` for step 833 | ``s`` has key ``s__p``. 834 | 835 | Returns 836 | ------- 837 | Xt : ndarray of shape (n_samples, n_transformed_features) 838 | Transformed samples. 839 | """ 840 | return self._fit_execute("transform", X, y, **fit_params) 841 | 842 | @available_if(_leaf_estimators_have("transform")) 843 | def transform(self, X): 844 | """ 845 | Transform the data, and apply `transform` with the final estimator. 846 | 847 | Call `transform` of each transformer in the DAG. The transformed 848 | data are finally passed to the final estimator that calls 849 | `transform` method. Only valid if the final estimator 850 | implements `transform`. 851 | 852 | This also works where final estimator is `None` in which case all prior 853 | transformations are applied. 854 | 855 | Parameters 856 | ---------- 857 | X : iterable 858 | Data to transform. Must fulfill input requirements of first step 859 | of the DAG. 860 | 861 | Returns 862 | ------- 863 | Xt : ndarray of shape (n_samples, n_transformed_features) 864 | Transformed data. 
865 | """ 866 | return self._execute("transform", X) 867 | 868 | @available_if(_leaf_estimators_have("predict")) 869 | def fit_predict(self, X, y=None, **fit_params): 870 | """ 871 | Transform the data, and apply `fit_predict` with the final estimator. 872 | 873 | Call `fit_transform` of each transformer in the DAG. The transformed data are 874 | finally passed to the final estimator that calls `fit_predict` method. Only 875 | valid if the final estimators implement `fit_predict`. 876 | 877 | Parameters 878 | ---------- 879 | X : iterable 880 | Training data. Must fulfill input requirements of first step of 881 | the DAG. 882 | y : iterable, default=None 883 | Training targets. Must fulfill label requirements for all steps 884 | of the DAG. 885 | **fit_params : dict of string -> object 886 | Parameters passed to the ``fit`` method of each step, where 887 | each parameter name is prefixed such that parameter ``p`` for step 888 | ``s`` has key ``s__p``. 889 | 890 | Returns 891 | ------- 892 | y_pred : ndarray 893 | Result of calling `fit_predict` on the final estimator. 894 | """ 895 | return self._fit_execute("predict", X, y, **fit_params) 896 | 897 | @available_if(_leaf_estimators_have("predict")) 898 | def predict(self, X, **predict_params): 899 | """ 900 | Transform the data, and apply `predict` with the final estimator. 901 | 902 | Call `transform` of each transformer in the DAG. The transformed 903 | data are finally passed to the final estimator that calls `predict` 904 | method. Only valid if the final estimators implement `predict`. 905 | 906 | Parameters 907 | ---------- 908 | X : iterable 909 | Data to predict on. Must fulfill input requirements of first step 910 | of the DAG. 911 | **predict_params : dict of string -> object 912 | Parameters to the ``predict`` called at the end of all 913 | transformations in the DAG. Note that while this may be 914 | used to return uncertainties from some models with return_std 915 | or return_cov, uncertainties that are generated by the 916 | transformations in the DAG are not propagated to the 917 | final estimator. 918 | 919 | Returns 920 | ------- 921 | y_pred : ndarray 922 | Result of calling `predict` on the final estimator. 923 | """ 924 | return self._execute("predict", X, **predict_params) 925 | 926 | @available_if(_leaf_estimators_have("predict_proba")) 927 | def predict_proba(self, X, **predict_proba_params): 928 | """ 929 | Transform the data, and apply `predict_proba` with the final estimator. 930 | 931 | Call `transform` of each transformer in the DAG. The transformed 932 | data are finally passed to the final estimator that calls 933 | `predict_proba` method. Only valid if the final estimators implement 934 | `predict_proba`. 935 | 936 | Parameters 937 | ---------- 938 | X : iterable 939 | Data to predict on. Must fulfill input requirements of first step 940 | of the DAG. 941 | **predict_proba_params : dict of string -> object 942 | Parameters to the `predict_proba` called at the end of all 943 | transformations in the DAG. 944 | 945 | Returns 946 | ------- 947 | y_proba : ndarray of shape (n_samples, n_classes) 948 | Result of calling `predict_proba` on the final estimator. 949 | """ 950 | return self._execute("predict_proba", X, **predict_proba_params) 951 | 952 | @available_if(_leaf_estimators_have("decision_function")) 953 | def decision_function(self, X): 954 | """ 955 | Transform the data, and apply `decision_function` with the final estimator. 956 | 957 | Call `transform` of each transformer in the DAG. 
The transformed 958 | data are finally passed to the final estimator that calls 959 | `decision_function` method. Only valid if the final estimators 960 | implement `decision_function`. 961 | 962 | Parameters 963 | ---------- 964 | X : iterable 965 | Data to predict on. Must fulfill input requirements of first step 966 | of the DAG. 967 | 968 | Returns 969 | ------- 970 | y_score : ndarray of shape (n_samples, n_classes) 971 | Result of calling `decision_function` on the final estimator. 972 | """ 973 | return self._execute("decision_function", X) 974 | 975 | @available_if(_leaf_estimators_have("score_samples")) 976 | def score_samples(self, X): 977 | """ 978 | Transform the data, and apply `score_samples` with the final estimator. 979 | 980 | Call `transform` of each transformer in the DAG. The transformed 981 | data are finally passed to the final estimator that calls 982 | `score_samples` method. Only valid if the final estimators implement 983 | `score_samples`. 984 | 985 | Parameters 986 | ---------- 987 | X : iterable 988 | Data to predict on. Must fulfill input requirements of first step 989 | of the DAG. 990 | 991 | Returns 992 | ------- 993 | y_score : ndarray of shape (n_samples,) 994 | Result of calling `score_samples` on the final estimator. 995 | """ 996 | return self._execute("score_samples", X) 997 | 998 | @available_if(_leaf_estimators_have("score")) 999 | def score(self, X, y=None, **score_params): 1000 | """ 1001 | Transform the data, and apply `score` with the final estimator. 1002 | 1003 | Call `transform` of each transformer in the DAG. The transformed 1004 | data are finally passed to the final estimator that calls 1005 | `score` method. Only valid if the final estimators implement `score`. 1006 | 1007 | Parameters 1008 | ---------- 1009 | X : iterable 1010 | Data to predict on. Must fulfill input requirements of first step 1011 | of the DAG. 1012 | y : iterable, default=None 1013 | Targets used for scoring. Must fulfill label requirements for all 1014 | steps of the DAG. 1015 | sample_weight : array-like, default=None 1016 | If not None, this argument is passed as ``sample_weight`` keyword 1017 | argument to the ``score`` method of the final estimator. 1018 | 1019 | Returns 1020 | ------- 1021 | score : float 1022 | Result of calling `score` on the final estimator. 1023 | """ 1024 | return self._execute("score", X, y, **score_params) 1025 | 1026 | @available_if(_leaf_estimators_have("predict_log_proba")) 1027 | def predict_log_proba(self, X, **predict_log_proba_params): 1028 | """ 1029 | Transform the data, and apply `predict_log_proba` with the final estimator. 1030 | 1031 | Call `transform` of each transformer in the DAG. The transformed 1032 | data are finally passed to the final estimator that calls 1033 | `predict_log_proba` method. Only valid if the final estimator 1034 | implements `predict_log_proba`. 1035 | 1036 | Parameters 1037 | ---------- 1038 | X : iterable 1039 | Data to predict on. Must fulfill input requirements of first step 1040 | of the DAG. 1041 | **predict_log_proba_params : dict of string -> object 1042 | Parameters to the ``predict_log_proba`` called at the end of all 1043 | transformations in the DAG. 1044 | 1045 | Returns 1046 | ------- 1047 | y_log_proba : ndarray of shape (n_samples, n_classes) 1048 | Result of calling `predict_log_proba` on the final estimator. 
1049 | """ 1050 | return self._execute("predict_log_proba", X, **predict_log_proba_params) 1051 | 1052 | def _check_fit_params(self, **fit_params): 1053 | fit_params_steps = { 1054 | name: {} for (name, step) in self.steps_ if step is not None 1055 | } 1056 | for pname, pval in fit_params.items(): 1057 | if pval is None: 1058 | continue 1059 | 1060 | if "__" not in pname: 1061 | raise ValueError( 1062 | f"DAG.fit does not accept the {pname} parameter. " 1063 | "You can pass parameters to specific steps of your " 1064 | "DAG using the stepname__parameter format, e.g. " 1065 | "`DAG.fit(X, y, logisticregression__sample_weight" 1066 | "=sample_weight)`." 1067 | ) 1068 | step, param = pname.split("__", 1) 1069 | fit_params_steps[step][param] = pval 1070 | return fit_params_steps 1071 | 1072 | def _check_fn_params(self, **fn_params): 1073 | global_params = {} 1074 | fn_params_steps = {name: {} for (name, step) in self.steps_ if step is not None} 1075 | for pname, pval in fn_params.items(): 1076 | if pval is None: 1077 | continue 1078 | 1079 | if "__" not in pname: 1080 | global_params[pname] = pval 1081 | else: 1082 | step, param = pname.split("__", 1) 1083 | fn_params_steps[step][param] = pval 1084 | 1085 | for step in fn_params_steps: 1086 | fn_params_steps[step].update(global_params) 1087 | 1088 | return fn_params_steps 1089 | 1090 | def _validate_graph(self): 1091 | if len(self.graph_) == 0: 1092 | raise ValueError("DAG has no nodes.") 1093 | 1094 | for i, (name, est) in enumerate(self.steps_): 1095 | step = self.graph_.nodes[name]["step"] 1096 | step.index = i 1097 | 1098 | # validate names 1099 | self._validate_names([name for (name, step) in self.steps_]) 1100 | 1101 | # validate transformers 1102 | for step in self.roots_ + self.branches_: 1103 | if step in self.leaves_: 1104 | # This will get validated later 1105 | continue 1106 | 1107 | est = step.estimator 1108 | # Unlike pipelines we also allow predictors to be used as a transformer, to support 1109 | # model stacking. 1110 | if ( 1111 | not _is_passthrough(est) 1112 | and not _is_transformer(est) 1113 | and not _is_predictor(est) 1114 | ): 1115 | raise TypeError( 1116 | "All intermediate steps should be " 1117 | "transformers and implement fit and transform " 1118 | "or be the string 'passthrough' " 1119 | f"'{est}' (type {type(est)}) doesn't" 1120 | ) 1121 | 1122 | # Validate final estimator(s) 1123 | for step in self.leaves_: 1124 | est = step.estimator 1125 | if not _is_passthrough(est) and not hasattr(est, "fit"): 1126 | raise TypeError( 1127 | "Leaf nodes of a DAG should implement fit " 1128 | "or be the string 'passthrough'. " 1129 | f"'{est}' (type {type(est)}) doesn't" 1130 | ) 1131 | 1132 | @property 1133 | def graph_(self): 1134 | if not hasattr(self, "_graph"): 1135 | # Read-only view of the graph. We should not modify 1136 | # the original graph. 
1137 | self._graph = self.graph.copy(as_view=True) 1138 | 1139 | return self._graph 1140 | 1141 | @property 1142 | def leaves_(self): 1143 | if not hasattr(self, "_leaves"): 1144 | self._leaves = [node for node in self.nodes_ if node.is_leaf] 1145 | 1146 | return self._leaves 1147 | 1148 | @property 1149 | def branches_(self): 1150 | if not hasattr(self, "_branches"): 1151 | self._branches = [ 1152 | node for node in self.nodes_ if not node.is_leaf and not node.is_root 1153 | ] 1154 | 1155 | return self._branches 1156 | 1157 | @property 1158 | def roots_(self): 1159 | if not hasattr(self, "_roots"): 1160 | self._roots = [node for node in self.nodes_ if node.is_root] 1161 | 1162 | return self._roots 1163 | 1164 | @property 1165 | def nodes_(self): 1166 | if not hasattr(self, "_nodes"): 1167 | self._nodes = [] 1168 | for name, estimator in self.steps_: 1169 | step = self.graph_.nodes[name]["step"] 1170 | if self.graph_.out_degree(name) == 0: 1171 | step.is_leaf = True 1172 | if self.graph_.in_degree(name) == 0: 1173 | step.is_root = True 1174 | self._nodes.append(step) 1175 | 1176 | return self._nodes 1177 | 1178 | @property 1179 | def steps_(self): 1180 | "return list of (name, estimator) tuples to conform with Pipeline interface." 1181 | if not hasattr(self, "_steps"): 1182 | self._steps = [ 1183 | (node, self.graph_.nodes[node]["step"].estimator) 1184 | for node in nx.lexicographical_topological_sort(self.graph_) 1185 | ] 1186 | 1187 | return self._steps 1188 | 1189 | def join(self, other, edges, **kwargs): 1190 | """ 1191 | Create a new DAG by joining this DAG to another one, according to the edges 1192 | specified. 1193 | 1194 | Parameters 1195 | ---------- 1196 | 1197 | other : :class:`skdag.dag.DAG` 1198 | The other DAG to connect to. 1199 | edges : (str, str) or (str, str, index-like) 1200 | ``(u, v)`` edges that connect the two DAGs. ``u`` and ``v`` should be the 1201 | names of steps in the first and second DAG respectively. Optionally a third 1202 | parameter may be included to specify which columns to pass along the edge. 1203 | **kwargs : keyword params 1204 | Any other parameters to pass to the new DAG's constructor. 1205 | 1206 | Returns 1207 | ------- 1208 | dag : :class:`skdag.DAG` 1209 | A new DAG, containing a copy of each of the input DAGs, joined by the 1210 | specified edges. Note that the original input dags are unmodified. 1211 | 1212 | Examples 1213 | -------- 1214 | 1215 | >>> from sklearn.decomposition import PCA 1216 | >>> from sklearn.impute import SimpleImputer 1217 | >>> from sklearn.linear_model import LogisticRegression 1218 | >>> from sklearn.calibration import CalibratedClassifierCV 1219 | >>> from skdag.dag import DAGBuilder 1220 | >>> dag1 = ( 1221 | ... DAGBuilder() 1222 | ... .add_step("impute", SimpleImputer()) 1223 | ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)}) 1224 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)}) 1225 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"]) 1226 | ... .make_dag() 1227 | ... ) 1228 | >>> print(dag1.draw().strip()) 1229 | o impute 1230 | |\\ 1231 | o o blood,vitals 1232 | |/ 1233 | o lr 1234 | >>> dag2 = ( 1235 | ... DAGBuilder() 1236 | ... .add_step( 1237 | ... "calib", 1238 | ... CalibratedClassifierCV(LogisticRegression(random_state=0), cv=5), 1239 | ... ) 1240 | ... .make_dag() 1241 | ... 
) 1242 | >>> print(dag2.draw().strip()) 1243 | o calib 1244 | >>> dag3 = dag1.join(dag2, edges=[("blood", "calib"), ("vitals", "calib")]) 1245 | >>> print(dag3.draw().strip()) 1246 | o impute 1247 | |\\ 1248 | o o blood,vitals 1249 | |x| 1250 | o o calib,lr 1251 | """ 1252 | if set(self.step_names) & set(other.step_names): 1253 | raise ValueError("DAGs with overlapping step names cannot be combined.") 1254 | 1255 | newgraph = deepcopy(self.graph_).copy() 1256 | for edge in edges: 1257 | if len(edge) == 2: 1258 | u, v, idx = *edge, None 1259 | else: 1260 | u, v, idx = edge 1261 | 1262 | if u not in self.graph_: 1263 | raise KeyError(u) 1264 | if v not in other.graph_: 1265 | raise KeyError(v) 1266 | 1267 | # source node can no longer be a leaf 1268 | ustep = newgraph.nodes[u]["step"] 1269 | if ustep.is_leaf: 1270 | ustep.is_leaf = False 1271 | 1272 | vnode = other.graph_.nodes[v] 1273 | old_step = vnode["step"] 1274 | vstep = DAGStep( 1275 | name=old_step.name, 1276 | estimator=old_step.estimator, 1277 | deps=old_step.deps, 1278 | dataframe_columns=old_step.dataframe_columns, 1279 | axis=old_step.axis, 1280 | ) 1281 | 1282 | if u not in vstep.deps: 1283 | vstep.deps[u] = idx 1284 | 1285 | vnode["step"] = vstep 1286 | 1287 | newgraph.add_node(v, **vnode) 1288 | newgraph.add_edge(u, v) 1289 | 1290 | return DAG(newgraph, **kwargs) 1291 | 1292 | def draw( 1293 | self, filename=None, style=None, detailed=False, format=None, layout="dot" 1294 | ): 1295 | """ 1296 | Render a graphical view of the DAG. 1297 | 1298 | By default the rendered output is returned as a string or bytes. However, if an 1299 | output file is provided then the output will be saved to file instead. 1300 | 1301 | Parameters 1302 | ---------- 1303 | 1304 | filename : str, optional 1305 | The file to write the image to. If None, the rendered output is returned 1306 | instead of being written to file. 1307 | style : str, optional, choice of ['light', 'dark'] 1308 | Draw the image in light or dark mode. 1309 | detailed : bool, default = False 1310 | If True, show extra details in the node labels such as the estimator 1311 | signature. 1312 | format : str, choice of ['svg', 'png', 'jpg', 'txt'] 1313 | The rendering format to use. May be omitted if the format can be inferred 1314 | from the filename. 1315 | layout : str, default = 'dot' 1316 | The program to use for generating a graph layout. 1317 | 1318 | See Also 1319 | -------- 1320 | 1321 | :meth:`skdag.dag.DAG.show`, for use in interactive notebooks. 1322 | 1323 | Returns 1324 | ------- 1325 | 1326 | output : str, bytes or None 1327 | If a filename is provided the output is written to file and `None` is 1328 | returned. Otherwise, the output is returned as a string (for textual formats 1329 | like ascii or svg) or bytes.
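Examples
--------

A minimal sketch of text-mode rendering, assuming the optional pygraphviz
dependency is installed (the step name ``clf`` is just an illustration):

>>> from skdag.dag import DAGBuilder
>>> dag = DAGBuilder().add_step("clf", "passthrough").make_dag()
>>> print(dag.draw(format="txt").strip())
o clf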
1330 | """ 1331 | if filename is None and format is None: 1332 | try: 1333 | from IPython import get_ipython 1334 | 1335 | rich = type(get_ipython()).__name__ == "ZMQInteractiveShell" 1336 | except (ModuleNotFoundError, NameError): 1337 | rich = False 1338 | 1339 | format = "svg" if rich else "txt" 1340 | 1341 | if format is None: 1342 | format = filename.split(".")[-1] 1343 | 1344 | if format not in ["svg", "png", "jpg", "txt"]: 1345 | raise ValueError(f"Unsupported file format '{format}'") 1346 | 1347 | render = DAGRenderer(self.graph_, detailed=detailed, style=style).draw( 1348 | format=format, layout=layout 1349 | ) 1350 | if filename is None: 1351 | return render 1352 | else: 1353 | mode = "wb" if isinstance(render, bytes) else "w" 1354 | with open(filename, mode) as fp: 1355 | fp.write(render) 1356 | 1357 | def show(self, style=None, detailed=False, format=None, layout="dot"): 1358 | """ 1359 | Display a graphical representation of the DAG in an interactive notebook 1360 | environment. 1361 | 1362 | DAGs will be shown when displayed in a notebook, but calling this method 1363 | directly allows more options to be passed to customise the appearance more. 1364 | 1365 | Arguments are as for :meth`.draw`. 1366 | 1367 | Returns 1368 | ------- 1369 | 1370 | ``None`` 1371 | 1372 | See Also 1373 | -------- 1374 | 1375 | :meth:`skdag.DAG.draw` 1376 | """ 1377 | if format is None: 1378 | format = "svg" if _in_notebook() else "txt" 1379 | 1380 | data = self.draw(style=style, detailed=detailed, format=format, layout=layout) 1381 | if format == "svg": 1382 | from IPython.display import SVG, display 1383 | 1384 | display(SVG(data)) 1385 | elif format == "txt": 1386 | print(data) 1387 | elif format in ("jpg", "png"): 1388 | from IPython.display import Image, display 1389 | 1390 | display(Image(data)) 1391 | else: 1392 | raise ValueError(f"'{format}' format not supported.") 1393 | 1394 | def _repr_svg_(self): 1395 | return self.draw(format="svg") 1396 | 1397 | def _repr_png_(self): 1398 | return self.draw(format="png") 1399 | 1400 | def _repr_jpeg_(self): 1401 | return self.draw(format="jpg") 1402 | 1403 | def _repr_html_(self): 1404 | return self.draw(format="svg") 1405 | 1406 | def _repr_pretty_(self, p, cycle): 1407 | if cycle: 1408 | p.text(repr(self)) 1409 | else: 1410 | p.text(str(self)) 1411 | 1412 | def _repr_mimebundle_(self, **kwargs): 1413 | # Don't render yet... 1414 | renderers = { 1415 | "image/svg+xml": self._repr_svg_, 1416 | "image/png": self._repr_png_, 1417 | "image/jpeg": self._repr_jpeg_, 1418 | "text/plain": self.__str__, 1419 | "text/html": self._repr_html_, 1420 | } 1421 | 1422 | include = kwargs.get("include") 1423 | if include: 1424 | renderers = {k: v for k, v in renderers.items() if k in include} 1425 | 1426 | exclude = kwargs.get("exclude") 1427 | if exclude: 1428 | renderers = {k: v for k, v in renderers.items() if k not in exclude} 1429 | 1430 | # Now render any remaining options. 1431 | return {k: v() for k, v in renderers.items()} 1432 | 1433 | @property 1434 | def named_steps(self): 1435 | """ 1436 | Access the steps by name. 1437 | 1438 | Read-only attribute to access any step by given name. 1439 | Keys are steps names and values are the steps objects. 
1440 | """ 1441 | # Use Bunch object to improve autocomplete 1442 | return Bunch(**dict(self.steps_)) 1443 | 1444 | @property 1445 | def step_names(self): 1446 | return list(self.graph_.nodes) 1447 | 1448 | @property 1449 | def edges(self): 1450 | return self.graph_.edges 1451 | 1452 | def _get_leaf_attr(self, attr): 1453 | if len(self.leaves_) == 1: 1454 | return getattr(self.leaves_[0].estimator, attr) 1455 | else: 1456 | return Bunch( 1457 | **{leaf.name: getattr(leaf.estimator, attr) for leaf in self.leaves_} 1458 | ) 1459 | 1460 | @property 1461 | def _estimator_type(self): 1462 | return self._get_leaf_attr("_estimator_type") 1463 | 1464 | @property 1465 | def classes_(self): 1466 | """The classes labels. Only exist if the leaf steps are classifiers.""" 1467 | return self._get_leaf_attr("classes_") 1468 | 1469 | def __sklearn_is_fitted__(self): 1470 | """Indicate whether DAG has been fit.""" 1471 | try: 1472 | # check if the last steps of the DAG are fitted 1473 | # we only check the last steps since if the last steps are fit, it 1474 | # means the previous steps should also be fit. This is faster than 1475 | # checking if every step of the DAG is fit. 1476 | for leaf in self._leaves: 1477 | check_is_fitted(leaf.estimator) 1478 | return True 1479 | except NotFittedError: 1480 | return False 1481 | 1482 | def _more_tags(self): 1483 | tags = {} 1484 | 1485 | # We assume the DAG can handle NaN if *all* the steps can. 1486 | tags["allow_nan"] = all( 1487 | _safe_tags(node.estimator, "allow_nan") for node in self.nodes_ 1488 | ) 1489 | 1490 | # Check if all *root* nodes expect pairwise input. 1491 | tags["pairwise"] = all( 1492 | _safe_tags(root.estimator, "pairwise") for root in self.roots_ 1493 | ) 1494 | 1495 | # CHeck if all *leaf* notes support multioutput 1496 | tags["multioutput"] = all( 1497 | _safe_tags(leaf.estimator, "multioutput") for leaf in self.leaves_ 1498 | ) 1499 | 1500 | return tags 1501 | -------------------------------------------------------------------------------- /skdag/dag/_render.py: -------------------------------------------------------------------------------- 1 | import html 2 | from typing import Iterable 3 | 4 | import black 5 | from matplotlib.pyplot import isinteractive 6 | import networkx as nx 7 | import stackeddag.core as sd 8 | from skdag.dag._utils import _is_passthrough 9 | 10 | __all__ = ["DAGRenderer"] 11 | 12 | 13 | class DAGRenderer: 14 | _EMPTY = "[empty]" 15 | _STYLES = { 16 | "light": { 17 | "node__color": "black", 18 | "node__fontcolor": "black", 19 | "edge__color": "black", 20 | "graph__bgcolor": "white", 21 | }, 22 | "dark": { 23 | "node__color": "white", 24 | "node__fontcolor": "white", 25 | "edge__color": "white", 26 | "graph__bgcolor": "none", 27 | }, 28 | } 29 | 30 | def __init__(self, dag, detailed=False, style=None): 31 | self.dag = dag 32 | if style is not None and style not in self._STYLES: 33 | raise ValueError(f"Unknown style: '{style}'.") 34 | self.style = style 35 | self.agraph = self.to_agraph(detailed) 36 | 37 | def _get_node_shape(self, estimator): 38 | if _is_passthrough(estimator): 39 | return "hexagon" 40 | if any(hasattr(estimator, attr) for attr in ["fit_predict", "predict"]): 41 | return "ellipse" 42 | return "box" 43 | 44 | def _is_empty(self): 45 | return len(self.dag) == 0 46 | 47 | def to_agraph(self, detailed): 48 | G = self.dag 49 | if self._is_empty(): 50 | G = nx.DiGraph() 51 | G.add_node( 52 | self._EMPTY, 53 | shape="box", 54 | ) 55 | 56 | try: 57 | A = nx.nx_agraph.to_agraph(G) 58 | except (ImportError, 
ModuleNotFoundError) as err: # pragma: no cover 59 | raise ImportError( 60 | "DAG visualisation requires pygraphviz to be installed. " 61 | "See http://pygraphviz.github.io/ for guidance." 62 | ) from err 63 | 64 | A.graph_attr["rankdir"] = "LR" 65 | if self.style is not None: 66 | A.graph_attr.update( 67 | { 68 | key.replace("graph__", ""): val 69 | for key, val in self._STYLES[self.style].items() 70 | if key.startswith("graph__") 71 | } 72 | ) 73 | 74 | mode = black.FileMode() 75 | mode.line_length = 16 76 | 77 | for v in G.nodes: 78 | anode = A.get_node(v) 79 | gnode = G.nodes[v] 80 | anode.attr["fontname"] = gnode.get("fontname", "SANS") 81 | if self.style is not None: 82 | anode.attr.update( 83 | { 84 | key.replace("node__", ""): val 85 | for key, val in self._STYLES[self.style].items() 86 | if key.startswith("node__") 87 | } 88 | ) 89 | if "step" in gnode: 90 | estimator = gnode["step"].estimator 91 | anode.attr["tooltip"] = repr(estimator) 92 | 93 | if detailed: 94 | estimator_str = html.escape( 95 | black.format_str(repr(estimator), mode=mode) 96 | ).replace("\n", '
<br align="left"/>') 97 | 98 | anode.attr["label"] = ( 99 | '<<table border="0" cellborder="0" cellspacing="0">' 100 | f'<tr><td><b>{v}</b></td></tr>' 101 | f'<tr><td>{estimator_str}</td></tr>' 102 | "</table>
>" 103 | ) 104 | 105 | anode.attr["shape"] = gnode.get( 106 | "shape", self._get_node_shape(estimator) 107 | ) 108 | if gnode["step"].is_fitted: 109 | anode.attr["peripheries"] = 2 110 | 111 | for u, v in G.edges: 112 | aedge = A.get_edge(u, v) 113 | if self.style is not None: 114 | aedge.attr.update( 115 | { 116 | key.replace("edge__", ""): val 117 | for key, val in self._STYLES[self.style].items() 118 | if key.startswith("edge__") 119 | } 120 | ) 121 | cols = G.nodes[v]["step"].deps[u] 122 | if cols: 123 | if isinstance(cols, Iterable): 124 | cols = list(cols) 125 | 126 | if len(cols) > 5: 127 | colrepr = f"[{repr(cols[0])}, ..., {repr(cols[-1])}]" 128 | else: 129 | colrepr = repr(cols) 130 | elif callable(cols): 131 | selector = cols 132 | cols = {} 133 | for attr in ["pattern", "dtype_include", "dtype_exclude"]: 134 | if hasattr(selector, attr): 135 | val = getattr(selector, attr) 136 | if val is not None: 137 | cols[attr] = val 138 | if cols: 139 | selrepr = ", ".join( 140 | f"{key}={repr(val)}" for key, val in cols.items() 141 | ) 142 | colrepr = f"column_selector({selrepr})" 143 | else: 144 | colrepr = f"{selector.__name__}()" 145 | else: 146 | colrepr = repr(cols) 147 | 148 | aedge.attr.update( 149 | {"label": colrepr, "fontsize": "8pt", "fontname": "SANS"} 150 | ) 151 | 152 | A.layout() 153 | return A 154 | 155 | def draw(self, format="svg", layout="dot"): 156 | A = self.agraph 157 | 158 | if format == "txt": 159 | if self._is_empty(): 160 | return self._EMPTY 161 | 162 | # Edge case: single node, no edges. 163 | if len(A.edges()) == 0: 164 | return f"o {next(A.nodes_iter())}" 165 | 166 | la = [] 167 | for node in A.nodes(): 168 | la.append((node, node)) 169 | 170 | ed = [] 171 | for src, dst in A.edges(): 172 | ed.append((src, [dst])) 173 | 174 | return sd.edgesToText(sd.mkLabels(la), sd.mkEdges(ed)).strip() 175 | 176 | data = A.draw(format=format, prog=layout) 177 | 178 | if format == "svg": 179 | return data.decode(self.agraph.encoding) 180 | return data 181 | -------------------------------------------------------------------------------- /skdag/dag/_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | try: 5 | import pandas as pd 6 | except ImportError: 7 | pd = None 8 | 9 | 10 | def _is_passthrough(estimator): 11 | return estimator is None or estimator == "passthrough" 12 | 13 | 14 | def _is_transformer(estimator): 15 | return ( 16 | hasattr(estimator, "fit") or hasattr(estimator, "fit_transform") 17 | ) and hasattr(estimator, "transform") 18 | 19 | 20 | def _is_predictor(estimator): 21 | return (hasattr(estimator, "fit") or hasattr(estimator, "fit_predict")) and hasattr( 22 | estimator, "predict" 23 | ) 24 | 25 | 26 | def _in_notebook(): 27 | try: 28 | from IPython import get_ipython 29 | 30 | if "IPKernelApp" not in get_ipython().config: # pragma: no cover 31 | return False 32 | except ImportError: 33 | return False 34 | except AttributeError: 35 | return False 36 | return True 37 | 38 | 39 | def _stack(Xs, axis=0): 40 | """ 41 | Where an estimator has multiple upstream dependencies, this method defines the 42 | strategy for merging the upstream outputs into a single input. ``axis=0`` is 43 | equivalent to a vstack (combination of samples) and ``axis=1`` is equivalent to a 44 | hstack (combination of features). Higher axes are only supported for non-sparse 45 | data sources. 
46 | """ 47 | if any(sparse.issparse(x) for x in Xs): 48 | if -2 <= axis < 2: 49 | axis = axis % 2 50 | else: 51 | raise NotImplementedError( 52 | f"Stacking is not supported for sparse inputs on axis > 1." 53 | ) 54 | 55 | if axis == 0: 56 | Xs = sparse.vstack(Xs).tocsr() 57 | elif axis == 1: 58 | Xs = sparse.hstack(Xs).tocsr() 59 | elif pd and all(_is_pandas(x) for x in Xs): 60 | Xs = pd.concat(Xs, axis=axis) 61 | else: 62 | if axis == 1: 63 | Xs = np.hstack(Xs) 64 | else: 65 | Xs = np.stack(Xs, axis=axis) 66 | 67 | return Xs 68 | 69 | 70 | def _is_pandas(X): 71 | "Check if X is a DataFrame or Series" 72 | return hasattr(X, "iloc") 73 | 74 | 75 | def _get_feature_names(estimator): 76 | try: 77 | feature_names = estimator.get_feature_names_out() 78 | except AttributeError: 79 | try: 80 | feature_names = estimator.get_feature_names() 81 | except AttributeError: 82 | feature_names = None 83 | 84 | return feature_names 85 | 86 | 87 | def _format_output(X, input, node): 88 | if node.dataframe_columns is None or pd is None or _is_pandas(X): 89 | return X 90 | 91 | outshape = np.asarray(X).shape 92 | outdim = len(outshape) 93 | if outdim > 2 or outdim < 1: 94 | return X 95 | 96 | if node.dataframe_columns == "infer": 97 | if outdim == 1: 98 | columns = [node.name] 99 | else: 100 | feature_names = _get_feature_names(node.estimator) 101 | if feature_names is None: 102 | feature_names = range(outshape[1]) 103 | 104 | columns = [f"{node.name}__{f}" for f in feature_names] 105 | else: 106 | columns = node.dataframe_columns 107 | 108 | if hasattr(input, "index"): 109 | index = input.index 110 | else: 111 | index = None 112 | 113 | if outdim == 2: 114 | df = pd.DataFrame(X, columns=columns, index=index) 115 | else: 116 | df = pd.Series(X, name=columns[0], index=index) 117 | 118 | return df 119 | -------------------------------------------------------------------------------- /skdag/dag/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/skdag/dag/tests/__init__.py -------------------------------------------------------------------------------- /skdag/dag/tests/test_builder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from skdag import DAG, DAGBuilder 3 | from skdag.dag.tests.utils import Mult, Transf 4 | from skdag.exceptions import DAGError 5 | from sklearn.pipeline import Pipeline 6 | 7 | 8 | def test_builder_basics(): 9 | builder = DAGBuilder() 10 | builder.add_step("tr", Transf()) 11 | builder.add_step("ident", "passthrough", deps=["tr"]) 12 | builder.add_step("est", Mult(), deps=["ident"]) 13 | 14 | dag = builder.make_dag() 15 | assert isinstance(dag, DAG) 16 | assert dag._repr_html_() == builder._repr_html_() 17 | 18 | with pytest.raises(ValueError): 19 | builder.add_step("est2", Mult(), deps=["missing"]) 20 | 21 | with pytest.raises(ValueError): 22 | builder.add_step("est2", Mult(), deps=[1]) 23 | 24 | with pytest.raises(KeyError): 25 | # Step names *must* be strings. 26 | builder.add_step(2, Mult()) 27 | 28 | with pytest.raises(KeyError): 29 | # Step names *must* be unique. 30 | builder.add_step("est", Mult()) 31 | 32 | with pytest.raises(DAGError): 33 | # Cycles not allowed. 
34 | builder.graph.add_edge("est", "ident") 35 | builder.make_dag() 36 | 37 | 38 | def test_pipeline(): 39 | steps = [("tr", Transf()), ("ident", "passthrough"), ("est", Mult())] 40 | 41 | builder = DAGBuilder() 42 | prev = None 43 | for name, est in steps: 44 | builder.add_step(name, est, deps=[prev] if prev else None) 45 | prev = name 46 | 47 | dag = builder.make_dag() 48 | 49 | assert dag.steps_ == steps 50 | -------------------------------------------------------------------------------- /skdag/dag/tests/test_dag.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the DAG module. 3 | """ 4 | import re 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import pytest 9 | from skdag import DAGBuilder 10 | from skdag.dag.tests.utils import FitParamT, Mult, NoFit, NoTrans, Transf 11 | from sklearn import datasets 12 | from sklearn.base import clone 13 | from sklearn.compose import make_column_selector 14 | from sklearn.decomposition import PCA 15 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 16 | from sklearn.feature_selection import SelectKBest, f_classif 17 | from sklearn.impute import SimpleImputer 18 | from sklearn.linear_model import LinearRegression, LogisticRegression 19 | from sklearn.pipeline import Pipeline 20 | from sklearn.preprocessing import StandardScaler 21 | from sklearn.svm import SVC 22 | from sklearn.utils._testing import assert_array_almost_equal 23 | from sklearn.utils.estimator_checks import parametrize_with_checks 24 | 25 | iris = datasets.load_iris() 26 | cancer = datasets.load_breast_cancer() 27 | 28 | JUNK_FOOD_DOCS = ( 29 | "the pizza pizza beer copyright", 30 | "the pizza burger beer copyright", 31 | "the the pizza beer beer copyright", 32 | "the burger beer beer copyright", 33 | "the coke burger coke copyright", 34 | "the coke burger burger", 35 | ) 36 | 37 | 38 | def test_dag_invalid_parameters(): 39 | # Test the various init parameters of the dag in fit 40 | # method 41 | with pytest.raises(KeyError): 42 | dag = DAGBuilder().from_pipeline([(1, 1)]).make_dag() 43 | 44 | # Check that we can't fit DAGs with objects without fit 45 | # method 46 | msg = ( 47 | "Leaf nodes of a DAG should implement fit or be the string 'passthrough'" 48 | ".*NoFit.*" 49 | ) 50 | dag = DAGBuilder().from_pipeline([("clf", NoFit())]).make_dag() 51 | with pytest.raises(TypeError, match=msg): 52 | dag.fit([[1]], [1]) 53 | 54 | # Smoke test with only an estimator 55 | clf = NoTrans() 56 | dag = DAGBuilder().from_pipeline([("svc", clf)]).make_dag() 57 | assert dag.get_params(deep=True) == dict( 58 | svc__a=None, svc__b=None, svc=clf, **dag.get_params(deep=False) 59 | ) 60 | 61 | # Check that params are set 62 | dag.set_params(svc__a=0.1) 63 | assert clf.a == 0.1 64 | assert clf.b is None 65 | # Smoke test the repr: 66 | repr(dag) 67 | 68 | # Test with two objects 69 | clf = SVC() 70 | filter1 = SelectKBest(f_classif) 71 | dag = DAGBuilder().from_pipeline([("anova", filter1), ("svc", clf)]).make_dag() 72 | 73 | # Check that estimators are not cloned on pipeline construction 74 | assert dag.named_steps["anova"] is filter1 75 | assert dag.named_steps["svc"] is clf 76 | 77 | # Check that we can't fit with non-transformers on the way 78 | # Note that NoTrans implements fit, but not transform 79 | msg = "All intermediate steps should be transformers.*\\bNoTrans\\b.*" 80 | dag2 = DAGBuilder().from_pipeline([("t", NoTrans()), ("svc", clf)]).make_dag() 81 | with pytest.raises(TypeError, match=msg): 82 | 
dag2.fit([[1]], [1]) 83 | 84 | # Check that params are set 85 | dag.set_params(svc__C=0.1) 86 | assert clf.C == 0.1 87 | # Smoke test the repr: 88 | repr(dag) 89 | 90 | # Check that params are not set when naming them wrong 91 | msg = re.escape( 92 | "Invalid parameter 'C' for estimator SelectKBest(). Valid parameters are: ['k'," 93 | " 'score_func']." 94 | ) 95 | with pytest.raises(ValueError, match=msg): 96 | dag.set_params(anova__C=0.1) 97 | 98 | # Test clone 99 | dag2 = clone(dag) 100 | assert dag.named_steps["svc"] is not dag2.named_steps["svc"] 101 | 102 | # Check that apart from estimators, the parameters are the same 103 | params = dag.get_params(deep=True) 104 | params2 = dag2.get_params(deep=True) 105 | 106 | for x in dag.get_params(deep=False): 107 | params.pop(x) 108 | 109 | for x in dag.get_params(deep=False): 110 | params2.pop(x) 111 | 112 | # Remove estimators that were copied 113 | params.pop("svc") 114 | params.pop("anova") 115 | params2.pop("svc") 116 | params2.pop("anova") 117 | assert params == params2 118 | 119 | 120 | def test_dag_simple(): 121 | # Build a simple DAG of one node 122 | rng = np.random.default_rng(1) 123 | X = rng.random(size=(20, 5)) 124 | 125 | dag = DAGBuilder().add_step("clf", FitParamT()).make_dag() 126 | dag.fit(X) 127 | dag.predict(X) 128 | 129 | 130 | def test_dag_pipeline_init(): 131 | # Build a dag from a pipeline 132 | X = np.array([[1, 2]]) 133 | steps = (("transf", Transf()), ("clf", FitParamT())) 134 | pipe = Pipeline(steps, verbose=False) 135 | for inp in [pipe, steps]: 136 | dag = DAGBuilder().from_pipeline(inp).make_dag() 137 | dag.fit(X, y=None) 138 | dag.score(X) 139 | 140 | dag.set_params(transf="passthrough") 141 | dag.fit(X, y=None) 142 | dag.score(X) 143 | 144 | 145 | def test_dag_methods_anova(): 146 | # Test the various methods of the dag (anova).
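# The same two-step graph is built twice below, once via from_pipeline and
# once via explicit add_step dependencies; the two constructions should
# behave identically.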
147 | X = iris.data 148 | y = iris.target 149 | # Test with Anova + LogisticRegression 150 | clf = LogisticRegression() 151 | filter1 = SelectKBest(f_classif, k=2) 152 | dag1 = ( 153 | DAGBuilder().from_pipeline([("anova", filter1), ("logistic", clf)]).make_dag() 154 | ) 155 | dag2 = ( 156 | DAGBuilder() 157 | .add_step("anova", filter1) 158 | .add_step("logistic", clf, deps=["anova"]) 159 | .make_dag() 160 | ) 161 | dag1.fit(X, y) 162 | dag2.fit(X, y) 163 | assert_array_almost_equal(dag1.predict(X), dag2.predict(X)) 164 | assert_array_almost_equal(dag1.predict_proba(X), dag2.predict_proba(X)) 165 | assert_array_almost_equal(dag1.predict_log_proba(X), dag2.predict_log_proba(X)) 166 | assert_array_almost_equal(dag1.score(X, y), dag2.score(X, y)) 167 | 168 | 169 | def test_dag_fit_params(): 170 | # Test that the DAG can take fit parameters 171 | dag = ( 172 | DAGBuilder() 173 | .from_pipeline([("transf", Transf()), ("clf", FitParamT())]) 174 | .make_dag() 175 | ) 176 | dag.fit(X=None, y=None, clf__should_succeed=True) 177 | # classifier should return True 178 | assert dag.predict(None) 179 | # and transformer params should not be changed 180 | assert dag.named_steps["transf"].a is None 181 | assert dag.named_steps["transf"].b is None 182 | # invalid parameters should raise an error message 183 | 184 | msg = re.escape("fit() got an unexpected keyword argument 'bad'") 185 | with pytest.raises(TypeError, match=msg): 186 | dag.fit(None, None, clf__bad=True) 187 | 188 | 189 | def test_dag_sample_weight_supported(): 190 | # DAG should pass sample_weight 191 | X = np.array([[1, 2]]) 192 | dag = ( 193 | DAGBuilder() 194 | .from_pipeline([("transf", Transf()), ("clf", FitParamT())]) 195 | .make_dag() 196 | ) 197 | dag.fit(X, y=None) 198 | assert dag.score(X) == 3 199 | assert dag.score(X, y=None) == 3 200 | assert dag.score(X, y=None, sample_weight=None) == 3 201 | assert dag.score(X, sample_weight=np.array([2, 3])) == 8 202 | 203 | 204 | def test_dag_sample_weight_unsupported(): 205 | # When sample_weight is None it shouldn't be passed 206 | X = np.array([[1, 2]]) 207 | dag = DAGBuilder().from_pipeline([("transf", Transf()), ("clf", Mult())]).make_dag() 208 | dag.fit(X, y=None) 209 | assert dag.score(X) == 3 210 | assert dag.score(X, sample_weight=None) == 3 211 | 212 | msg = re.escape("score() got an unexpected keyword argument 'sample_weight'") 213 | with pytest.raises(TypeError, match=msg): 214 | dag.score(X, sample_weight=np.array([2, 3])) 215 | 216 | 217 | def test_dag_raise_set_params_error(): 218 | # Test dag raises set params error message for nested models. 219 | dag = DAGBuilder().from_pipeline([("cls", LinearRegression())]).make_dag() 220 | 221 | # expected error message 222 | error_msg = ( 223 | r"Invalid parameter 'fake' for estimator DAG\(graph=<[^>]*>" 224 | r"\)\. Valid parameters are: \['graph', 'memory', 'n_jobs', 'verbose'\]." 225 | ) 226 | with pytest.raises(ValueError, match=error_msg): 227 | dag.set_params(fake="nope") 228 | 229 | # invalid outer parameter name for compound parameter: the expected error message 230 | # is the same as above. 231 | with pytest.raises(ValueError, match=error_msg): 232 | dag.set_params(fake__estimator="nope") 233 | 234 | # expected error message for invalid inner parameter 235 | error_msg = ( 236 | r"Invalid parameter 'invalid_param' for estimator LinearRegression\(\)\. Valid" 237 | r" parameters are: \['copy_X', 'fit_intercept', 'n_jobs', 'normalize'," 238 | r" 'positive'\].
239 | ) 240 | with pytest.raises(ValueError, match=error_msg): 241 | dag.set_params(cls__invalid_param="nope") 242 | 243 | 244 | @pytest.mark.parametrize("idx", [1, [1]]) 245 | def test_dag_stacking_pca_svm_rf(idx): 246 | # Test the various methods of a stacking DAG (pca + svc + rf). 247 | X = cancer.data 248 | y = cancer.target 249 | # Build a simple model stack with some preprocessing. 250 | pca = PCA(svd_solver="full", n_components="mle", whiten=True) 251 | svc = SVC(probability=True, random_state=0) 252 | rf = RandomForestClassifier(random_state=0) 253 | log = LogisticRegression() 254 | 255 | dag = ( 256 | DAGBuilder() 257 | .add_step("pca", pca) 258 | .add_step("svc", svc, deps=["pca"]) 259 | .add_step("rf", rf, deps=["pca"]) 260 | .add_step("log", log, deps={"svc": idx, "rf": idx}) 261 | .make_dag() 262 | ) 263 | dag.fit(X, y) 264 | 265 | prob_shape = len(cancer.target), len(cancer.target_names) 266 | tgt_shape = cancer.target.shape 267 | 268 | assert dag.predict_proba(X).shape == prob_shape 269 | assert dag.predict(X).shape == tgt_shape 270 | assert dag.predict_log_proba(X).shape == prob_shape 271 | assert isinstance(dag.score(X, y), (float, np.floating)) 272 | 273 | leaf = dag["log"] 274 | for attr in ["n_features_in_", "feature_names_in_"]: 275 | if hasattr(leaf, attr): 276 | assert hasattr(dag, attr) 277 | 278 | 279 | def test_dag_draw(): 280 | txt = DAGBuilder().make_dag().draw(format="txt") 281 | assert "[empty]" in txt 282 | 283 | svg = DAGBuilder().make_dag().draw(format="svg") 284 | assert "[empty]" in svg 285 | 286 | # Build a simple model stack with some preprocessing. 287 | pca = PCA(svd_solver="full", n_components="mle", whiten=True) 288 | svc = SVC(probability=True, random_state=0) 289 | rf = RandomForestClassifier(random_state=0) 290 | log = LogisticRegression() 291 | 292 | dag = ( 293 | DAGBuilder() 294 | .add_step("pca", pca) 295 | .add_step("svc1", svc, deps={"pca": slice(4)}) 296 | .add_step("svc2", svc, deps={"pca": [0, 1, 2]}) 297 | .add_step( 298 | "svc3", svc, deps={"pca": lambda X: [c for c in X if c.startswith("foo")]} 299 | ) 300 | .add_step("rf1", rf, deps={"pca": [0, 1, 2, 3, 4, 5, 6, 7, 8]}) 301 | .add_step("rf2", rf, deps={"pca": make_column_selector(pattern="^pca.*")}) 302 | .add_step("log", log, deps=["svc1", "svc2", "rf1", "rf2"]) 303 | .make_dag() 304 | ) 305 | 306 | for repr_method in [fn for fn in dir(dag) if fn.startswith("_repr_")]: 307 | if repr_method == "_repr_pretty_": 308 | try: 309 | from IPython.lib.pretty import PrettyPrinter 310 | except ImportError: # pragma: no cover 311 | continue 312 | from io import StringIO 313 | 314 | sio = StringIO() 315 | getattr(dag, repr_method)(PrettyPrinter(sio), False) 316 | sio.seek(0) 317 | out = sio.read() 318 | else: 319 | out = getattr(dag, repr_method)() 320 | 321 | if repr_method == "_repr_mimebundle_": 322 | for mimetype, data in out.items(): 323 | if mimetype in ("image/png", "image/jpeg"): 324 | expected = bytes 325 | else: 326 | expected = str 327 | assert isinstance( 328 | data, expected 329 | ), f"{repr_method} {mimetype} returns unexpected type {data}" 330 | elif repr_method in ("_repr_jpeg_", "_repr_png_"): 331 | assert isinstance( 332 | out, bytes 333 | ), f"{repr_method} returns unexpected type {out}" 334 | else: 335 | assert isinstance(out, str), f"{repr_method} returns unexpected type {out}" 336 | 337 | txt = dag.draw(format="txt") 338 | for step in dag.step_names: 339 | assert step in txt 340 | 341 | svg = dag.draw(format="svg") 342 | for step in dag.step_names: 343 | assert f"{step}" in svg
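# The same node labels should also survive a change of colour scheme and
# the detailed label mode, both exercised below.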
344 | 345 | svg = dag.draw(format="svg", style="dark") 346 | for step in dag.step_names: 347 | assert f"{step}" in svg 348 | 349 | with pytest.raises(ValueError): 350 | badstyle = "foo" 351 | dag.draw(format="svg", style=badstyle) 352 | 353 | svg = dag.draw(format="svg", detailed=True) 354 | for step, est in dag.steps_: 355 | assert f"{step}" in svg 356 | assert f"{type(est).__name__}" in svg 357 | 358 | 359 | def _dag_from_steplist(steps, **builder_opts): 360 | builder = DAGBuilder(**builder_opts) 361 | for step in steps: 362 | builder.add_step(**step) 363 | return builder.make_dag() 364 | 365 | 366 | @pytest.mark.parametrize( 367 | "steps", 368 | [ 369 | [ 370 | { 371 | "name": "pca", 372 | "est": PCA(n_components=1), 373 | }, 374 | { 375 | "name": "svc", 376 | "est": SVC(probability=True, random_state=0), 377 | "deps": ["pca"], 378 | }, 379 | { 380 | "name": "rf", 381 | "est": RandomForestClassifier(random_state=0), 382 | "deps": ["pca"], 383 | }, 384 | { 385 | "name": "log", 386 | "est": LogisticRegression(), 387 | "deps": ["svc", "rf"], 388 | }, 389 | ], 390 | ], 391 | ) 392 | @pytest.mark.parametrize( 393 | "X,y", 394 | [datasets.make_blobs(n_samples=200, n_features=10, centers=3, random_state=0)], 395 | ) 396 | def test_pandas(X, y, steps): 397 | dag_np = _dag_from_steplist(steps, infer_dataframe=False) 398 | dag_pd = _dag_from_steplist(steps, infer_dataframe=True) 399 | 400 | dag_np.fit(X, y) 401 | dag_pd.fit(X, y) 402 | 403 | y_pred_np = dag_np.predict_proba(X) 404 | y_pred_pd = dag_pd.predict_proba(X) 405 | assert isinstance(y_pred_np, np.ndarray) 406 | assert isinstance(y_pred_pd, pd.DataFrame) 407 | assert np.allclose(y_pred_np, y_pred_pd) 408 | 409 | 410 | @pytest.mark.parametrize("input_passthrough", [False, True]) 411 | def test_pandas_indexing(input_passthrough): 412 | X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) 413 | 414 | passcols = ["age", "sex", "bmi", "bp"] 415 | 416 | builder = DAGBuilder(infer_dataframe=True) 417 | if input_passthrough: 418 | builder.add_step("inp", "passthrough") 419 | 420 | preprocessing = ( 421 | builder.add_step( 422 | "imp", SimpleImputer(), deps=["inp"] if input_passthrough else None 423 | ) 424 | .add_step("vitals", "passthrough", deps={"imp": passcols}) 425 | .add_step( 426 | "blood", 427 | PCA(n_components=2, random_state=0), 428 | deps={"imp": make_column_selector("s[0-9]+")}, 429 | ) 430 | .add_step("out", "passthrough", deps=["vitals", "blood"]) 431 | .make_dag() 432 | ) 433 | 434 | X_tr = preprocessing.fit_transform(X, y) 435 | assert isinstance(X_tr, pd.DataFrame) 436 | assert (X_tr.index == X.index).all() 437 | assert X_tr.columns.tolist() == [f"imp__{col}" for col in passcols] + [ 438 | "blood__pca0", 439 | "blood__pca1", 440 | ] 441 | 442 | predictor = ( 443 | DAGBuilder(infer_dataframe=True) 444 | .add_step("rf", RandomForestRegressor(random_state=0)) 445 | .make_dag() 446 | ) 447 | 448 | dag = preprocessing.join( 449 | predictor, 450 | edges=[("out", "rf")], 451 | ) 452 | 453 | y_pred = dag.fit_predict(X, y) 454 | 455 | assert isinstance(y_pred, pd.Series) 456 | assert (y_pred.index == y.index).all() 457 | assert y_pred.name == dag.leaves_[0].name 458 | 459 | 460 | @parametrize_with_checks( 461 | [ 462 | DAGBuilder().from_pipeline([("ss", StandardScaler())]).make_dag(), 463 | DAGBuilder().from_pipeline([("lr", LinearRegression())]).make_dag(), 464 | ( 465 | DAGBuilder() 466 | .add_step("pca", PCA(n_components=1)) 467 | .add_step("svc", SVC(probability=True, random_state=0), deps=["pca"]) 468 | .add_step("rf", 
RandomForestClassifier(random_state=0), deps=["pca"]) 469 | .add_step("log", LogisticRegression(), deps=["svc", "rf"]) 470 | .make_dag() 471 | ), 472 | ] 473 | ) 474 | def test_dag_check_estimator(estimator, check): 475 | # Since some parameters are estimators, we expect them to be modified during fit(), 476 | # which is why we skip these checks (in line with checks on other metaestimators 477 | # like Pipeline) 478 | if check.func.__name__ in [ 479 | "check_estimators_overwrite_params", 480 | "check_dont_overwrite_parameters", 481 | ]: 482 | # we don't clone in pipeline or feature union 483 | return 484 | check(estimator) 485 | -------------------------------------------------------------------------------- /skdag/dag/tests/utils.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | from sklearn.base import BaseEstimator 5 | 6 | 7 | class NoFit: # pragma: no cover 8 | """Small class to test parameter dispatching.""" 9 | 10 | def __init__(self, a=None, b=None): 11 | self.a = a 12 | self.b = b 13 | 14 | 15 | class NoTrans(NoFit): # pragma: no cover 16 | def fit(self, X, y): 17 | return self 18 | 19 | def get_params(self, deep=False): 20 | return {"a": self.a, "b": self.b} 21 | 22 | def set_params(self, **params): 23 | self.a = params["a"] 24 | return self 25 | 26 | 27 | class NoInvTransf(NoTrans): # pragma: no cover 28 | def transform(self, X): 29 | return X 30 | 31 | 32 | class Transf(NoInvTransf): # pragma: no cover 33 | def inverse_transform(self, X): 34 | return X 35 | 36 | 37 | class TransfFitParams(Transf): # pragma: no cover 38 | def fit(self, X, y, **fit_params): 39 | self.fit_params = fit_params 40 | return self 41 | 42 | 43 | class Mult(BaseEstimator): # pragma: no cover 44 | def __init__(self, mult=1): 45 | self.mult = mult 46 | 47 | def fit(self, X, y): 48 | return self 49 | 50 | def transform(self, X): 51 | return np.asarray(X) * self.mult 52 | 53 | def inverse_transform(self, X): 54 | return np.asarray(X) / self.mult 55 | 56 | def predict(self, X): 57 | return (np.asarray(X) * self.mult).sum(axis=1) 58 | 59 | predict_proba = predict_log_proba = decision_function = predict 60 | 61 | def score(self, X, y=None): 62 | return np.sum(X) 63 | 64 | 65 | class FitParamT(BaseEstimator): # pragma: no cover 66 | """Mock classifier""" 67 | 68 | def __init__(self): 69 | self.successful = False 70 | 71 | def fit(self, X, y, should_succeed=False): 72 | self.successful = should_succeed 73 | return self 74 | 75 | def predict(self, X): 76 | return self.successful 77 | 78 | def fit_predict(self, X, y, should_succeed=False): 79 | self.fit(X, y, should_succeed=should_succeed) 80 | return self.predict(X) 81 | 82 | def score(self, X, y=None, sample_weight=None): 83 | if sample_weight is not None: 84 | X = X * sample_weight 85 | return np.sum(X) 86 | 87 | 88 | class DummyTransf(Transf): # pragma: no cover 89 | """Transformer which store the column means""" 90 | 91 | def fit(self, X, y): 92 | self.means_ = np.mean(X, axis=0) 93 | # store timestamp to figure out whether the result of 'fit' has been 94 | # cached or not 95 | self.timestamp_ = time() 96 | return self 97 | 98 | 99 | class DummyEstimatorParams(BaseEstimator): # pragma: no cover 100 | """Mock classifier that takes params on predict""" 101 | 102 | def fit(self, X, y): 103 | return self 104 | 105 | def predict(self, X, got_attribute=False): 106 | self.got_attribute = got_attribute 107 | return self 108 | 109 | def predict_proba(self, X, 
got_attribute=False): 110 | self.got_attribute = got_attribute 111 | return self 112 | 113 | def predict_log_proba(self, X, got_attribute=False): 114 | self.got_attribute = got_attribute 115 | return self 116 | -------------------------------------------------------------------------------- /skdag/exceptions.py: -------------------------------------------------------------------------------- 1 | class DAGError(Exception): 2 | "An exception indicating an error in constructing the requested DAG." --------------------------------------------------------------------------------